merge: 合并 upstream/main 并保留本地图片计费功能

This commit is contained in:
song
2026-01-06 10:49:26 +08:00
187 changed files with 17081 additions and 19407 deletions

View File

@@ -11,6 +11,7 @@ import (
type Account struct {
ID int64
Name string
Notes *string
Platform string
Type string
Credentials map[string]any
@@ -262,6 +263,17 @@ func parseTempUnschedStrings(value any) []string {
return out
}
func normalizeAccountNotes(value *string) *string {
if value == nil {
return nil
}
trimmed := strings.TrimSpace(*value)
if trimmed == "" {
return nil
}
return &trimmed
}
func parseTempUnschedInt(value any) int {
switch v := value.(type) {
case int:

View File

@@ -72,6 +72,7 @@ type AccountBulkUpdate struct {
// CreateAccountRequest 创建账号请求
type CreateAccountRequest struct {
Name string `json:"name"`
Notes *string `json:"notes"`
Platform string `json:"platform"`
Type string `json:"type"`
Credentials map[string]any `json:"credentials"`
@@ -85,6 +86,7 @@ type CreateAccountRequest struct {
// UpdateAccountRequest 更新账号请求
type UpdateAccountRequest struct {
Name *string `json:"name"`
Notes *string `json:"notes"`
Credentials *map[string]any `json:"credentials"`
Extra *map[string]any `json:"extra"`
ProxyID *int64 `json:"proxy_id"`
@@ -123,6 +125,7 @@ func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) (
// 创建账号
account := &Account{
Name: req.Name,
Notes: normalizeAccountNotes(req.Notes),
Platform: req.Platform,
Type: req.Type,
Credentials: req.Credentials,
@@ -194,6 +197,9 @@ func (s *AccountService) Update(ctx context.Context, id int64, req UpdateAccount
if req.Name != nil {
account.Name = *req.Name
}
if req.Notes != nil {
account.Notes = normalizeAccountNotes(req.Notes)
}
if req.Credentials != nil {
account.Credentials = *req.Credentials

View File

@@ -7,6 +7,7 @@ import (
"crypto/rand"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"log"
@@ -14,9 +15,11 @@ import (
"regexp"
"strings"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
@@ -45,6 +48,7 @@ type AccountTestService struct {
geminiTokenProvider *GeminiTokenProvider
antigravityGatewayService *AntigravityGatewayService
httpUpstream HTTPUpstream
cfg *config.Config
}
// NewAccountTestService creates a new AccountTestService
@@ -53,15 +57,35 @@ func NewAccountTestService(
geminiTokenProvider *GeminiTokenProvider,
antigravityGatewayService *AntigravityGatewayService,
httpUpstream HTTPUpstream,
cfg *config.Config,
) *AccountTestService {
return &AccountTestService{
accountRepo: accountRepo,
geminiTokenProvider: geminiTokenProvider,
antigravityGatewayService: antigravityGatewayService,
httpUpstream: httpUpstream,
cfg: cfg,
}
}
func (s *AccountTestService) validateUpstreamBaseURL(raw string) (string, error) {
if s.cfg == nil {
return "", errors.New("config is not available")
}
if !s.cfg.Security.URLAllowlist.Enabled {
return urlvalidator.ValidateURLFormat(raw, s.cfg.Security.URLAllowlist.AllowInsecureHTTP)
}
normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{
AllowedHosts: s.cfg.Security.URLAllowlist.UpstreamHosts,
RequireAllowlist: true,
AllowPrivate: s.cfg.Security.URLAllowlist.AllowPrivateHosts,
})
if err != nil {
return "", err
}
return normalized, nil
}
// generateSessionString generates a Claude Code style session string
func generateSessionString() (string, error) {
bytes := make([]byte, 32)
@@ -183,11 +207,15 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
return s.sendErrorAndEnd(c, "No API key available")
}
apiURL = account.GetBaseURL()
if apiURL == "" {
apiURL = "https://api.anthropic.com"
baseURL := account.GetBaseURL()
if baseURL == "" {
baseURL = "https://api.anthropic.com"
}
apiURL = strings.TrimSuffix(apiURL, "/") + "/v1/messages"
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Invalid base URL: %s", err.Error()))
}
apiURL = strings.TrimSuffix(normalizedBaseURL, "/") + "/v1/messages"
} else {
return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type))
}
@@ -300,7 +328,11 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
if baseURL == "" {
baseURL = "https://api.openai.com"
}
apiURL = strings.TrimSuffix(baseURL, "/") + "/responses"
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Invalid base URL: %s", err.Error()))
}
apiURL = strings.TrimSuffix(normalizedBaseURL, "/") + "/responses"
} else {
return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type))
}
@@ -480,10 +512,14 @@ func (s *AccountTestService) buildGeminiAPIKeyRequest(ctx context.Context, accou
if baseURL == "" {
baseURL = geminicli.AIStudioBaseURL
}
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, err
}
// Use streamGenerateContent for real-time feedback
fullURL := fmt.Sprintf("%s/v1beta/models/%s:streamGenerateContent?alt=sse",
strings.TrimRight(baseURL, "/"), modelID)
strings.TrimRight(normalizedBaseURL, "/"), modelID)
req, err := http.NewRequestWithContext(ctx, "POST", fullURL, bytes.NewReader(payload))
if err != nil {
@@ -515,7 +551,11 @@ func (s *AccountTestService) buildGeminiOAuthRequest(ctx context.Context, accoun
if strings.TrimSpace(baseURL) == "" {
baseURL = geminicli.AIStudioBaseURL
}
fullURL := fmt.Sprintf("%s/v1beta/models/%s:streamGenerateContent?alt=sse", strings.TrimRight(baseURL, "/"), modelID)
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, err
}
fullURL := fmt.Sprintf("%s/v1beta/models/%s:streamGenerateContent?alt=sse", strings.TrimRight(normalizedBaseURL, "/"), modelID)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(payload))
if err != nil {
@@ -544,7 +584,11 @@ func (s *AccountTestService) buildCodeAssistRequest(ctx context.Context, accessT
}
wrappedBytes, _ := json.Marshal(wrapped)
fullURL := fmt.Sprintf("%s/v1internal:streamGenerateContent?alt=sse", geminicli.GeminiCliBaseURL)
normalizedBaseURL, err := s.validateUpstreamBaseURL(geminicli.GeminiCliBaseURL)
if err != nil {
return nil, err
}
fullURL := fmt.Sprintf("%s/v1internal:streamGenerateContent?alt=sse", normalizedBaseURL)
req, err := http.NewRequestWithContext(ctx, "POST", fullURL, bytes.NewReader(wrappedBytes))
if err != nil {

View File

@@ -123,6 +123,7 @@ type UpdateGroupInput struct {
type CreateAccountInput struct {
Name string
Notes *string
Platform string
Type string
Credentials map[string]any
@@ -138,6 +139,7 @@ type CreateAccountInput struct {
type UpdateAccountInput struct {
Name string
Notes *string
Type string // Account type: oauth, setup-token, apikey
Credentials map[string]any
Extra map[string]any
@@ -687,6 +689,7 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
account := &Account{
Name: input.Name,
Notes: normalizeAccountNotes(input.Notes),
Platform: input.Platform,
Type: input.Type,
Credentials: input.Credentials,
@@ -723,6 +726,9 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
if input.Type != "" {
account.Type = input.Type
}
if input.Notes != nil {
account.Notes = normalizeAccountNotes(input.Notes)
}
if len(input.Credentials) > 0 {
account.Credentials = input.Credentials
}
@@ -730,7 +736,12 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
account.Extra = input.Extra
}
if input.ProxyID != nil {
account.ProxyID = input.ProxyID
// 0 表示清除代理(前端发送 0 而不是 null 来表达清除意图)
if *input.ProxyID == 0 {
account.ProxyID = nil
} else {
account.ProxyID = input.ProxyID
}
account.Proxy = nil // 清除关联对象,防止 GORM Save 时根据 Proxy.ID 覆盖 ProxyID
}
// 只在指针非 nil 时更新 Concurrency支持设置为 0

View File

@@ -9,8 +9,10 @@ import (
"fmt"
"io"
"log"
mathrand "math/rand"
"net/http"
"strings"
"sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
@@ -255,6 +257,16 @@ func (s *AntigravityGatewayService) buildClaudeTestRequest(projectID, mappedMode
return antigravity.TransformClaudeToGemini(claudeReq, projectID, mappedModel)
}
func (s *AntigravityGatewayService) getClaudeTransformOptions(ctx context.Context) antigravity.TransformOptions {
opts := antigravity.DefaultTransformOptions()
if s.settingService == nil {
return opts
}
opts.EnableIdentityPatch = s.settingService.IsIdentityPatchEnabled(ctx)
opts.IdentityPatch = s.settingService.GetIdentityPatchPrompt(ctx)
return opts
}
// extractGeminiResponseText 从 Gemini 响应中提取文本
func extractGeminiResponseText(respBody []byte) string {
var resp map[string]any
@@ -380,7 +392,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
}
// 转换 Claude 请求为 Gemini 格式
geminiBody, err := antigravity.TransformClaudeToGemini(&claudeReq, projectID, mappedModel)
geminiBody, err := antigravity.TransformClaudeToGeminiWithOptions(&claudeReq, projectID, mappedModel, s.getClaudeTransformOptions(ctx))
if err != nil {
return nil, fmt.Errorf("transform request: %w", err)
}
@@ -394,6 +406,14 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
// 重试循环
var resp *http.Response
for attempt := 1; attempt <= antigravityMaxRetries; attempt++ {
// 检查 context 是否已取消(客户端断开连接)
select {
case <-ctx.Done():
log.Printf("%s status=context_canceled error=%v", prefix, ctx.Err())
return nil, ctx.Err()
default:
}
upstreamReq, err := antigravity.NewAPIRequest(ctx, action, accessToken, geminiBody)
if err != nil {
return nil, err
@@ -403,7 +423,10 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
if err != nil {
if attempt < antigravityMaxRetries {
log.Printf("%s status=request_failed retry=%d/%d error=%v", prefix, attempt, antigravityMaxRetries, err)
sleepAntigravityBackoff(attempt)
if !sleepAntigravityBackoffWithContext(ctx, attempt) {
log.Printf("%s status=context_canceled_during_backoff", prefix)
return nil, ctx.Err()
}
continue
}
log.Printf("%s status=request_failed retries_exhausted error=%v", prefix, err)
@@ -416,7 +439,10 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
if attempt < antigravityMaxRetries {
log.Printf("%s status=%d retry=%d/%d", prefix, resp.StatusCode, attempt, antigravityMaxRetries)
sleepAntigravityBackoff(attempt)
if !sleepAntigravityBackoffWithContext(ctx, attempt) {
log.Printf("%s status=context_canceled_during_backoff", prefix)
return nil, ctx.Err()
}
continue
}
// 所有重试都失败,标记限流状态
@@ -443,35 +469,70 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
// Antigravity /v1internal 链路在部分场景会对 thought/thinking signature 做严格校验,
// 当历史消息携带的 signature 不合法时会直接 400去除 thinking 后可继续完成请求。
if resp.StatusCode == http.StatusBadRequest && isSignatureRelatedError(respBody) {
retryClaudeReq := claudeReq
retryClaudeReq.Messages = append([]antigravity.ClaudeMessage(nil), claudeReq.Messages...)
// Conservative two-stage fallback:
// 1) Disable top-level thinking + thinking->text
// 2) Only if still signature-related 400: also downgrade tool_use/tool_result to text.
stripped, stripErr := stripThinkingFromClaudeRequest(&retryClaudeReq)
if stripErr == nil && stripped {
log.Printf("Antigravity account %d: detected signature-related 400, retrying once without thinking blocks", account.ID)
retryStages := []struct {
name string
strip func(*antigravity.ClaudeRequest) (bool, error)
}{
{name: "thinking-only", strip: stripThinkingFromClaudeRequest},
{name: "thinking+tools", strip: stripSignatureSensitiveBlocksFromClaudeRequest},
}
retryGeminiBody, txErr := antigravity.TransformClaudeToGemini(&retryClaudeReq, projectID, mappedModel)
if txErr == nil {
retryReq, buildErr := antigravity.NewAPIRequest(ctx, action, accessToken, retryGeminiBody)
if buildErr == nil {
retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency)
if retryErr == nil {
// Retry success: continue normal success flow with the new response.
if retryResp.StatusCode < 400 {
_ = resp.Body.Close()
resp = retryResp
respBody = nil
} else {
// Retry still errored: replace error context with retry response.
retryBody, _ := io.ReadAll(io.LimitReader(retryResp.Body, 2<<20))
_ = retryResp.Body.Close()
respBody = retryBody
resp = retryResp
}
} else {
log.Printf("Antigravity account %d: signature retry request failed: %v", account.ID, retryErr)
}
for _, stage := range retryStages {
retryClaudeReq := claudeReq
retryClaudeReq.Messages = append([]antigravity.ClaudeMessage(nil), claudeReq.Messages...)
stripped, stripErr := stage.strip(&retryClaudeReq)
if stripErr != nil || !stripped {
continue
}
log.Printf("Antigravity account %d: detected signature-related 400, retrying once (%s)", account.ID, stage.name)
retryGeminiBody, txErr := antigravity.TransformClaudeToGeminiWithOptions(&retryClaudeReq, projectID, mappedModel, s.getClaudeTransformOptions(ctx))
if txErr != nil {
continue
}
retryReq, buildErr := antigravity.NewAPIRequest(ctx, action, accessToken, retryGeminiBody)
if buildErr != nil {
continue
}
retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency)
if retryErr != nil {
log.Printf("Antigravity account %d: signature retry request failed (%s): %v", account.ID, stage.name, retryErr)
continue
}
if retryResp.StatusCode < 400 {
_ = resp.Body.Close()
resp = retryResp
respBody = nil
break
}
retryBody, _ := io.ReadAll(io.LimitReader(retryResp.Body, 2<<20))
_ = retryResp.Body.Close()
// If this stage fixed the signature issue, we stop; otherwise we may try the next stage.
if retryResp.StatusCode != http.StatusBadRequest || !isSignatureRelatedError(retryBody) {
respBody = retryBody
resp = &http.Response{
StatusCode: retryResp.StatusCode,
Header: retryResp.Header.Clone(),
Body: io.NopCloser(bytes.NewReader(retryBody)),
}
break
}
// Still signature-related; capture context and allow next stage.
respBody = retryBody
resp = &http.Response{
StatusCode: retryResp.StatusCode,
Header: retryResp.Header.Clone(),
Body: io.NopCloser(bytes.NewReader(retryBody)),
}
}
}
@@ -528,7 +589,17 @@ func isSignatureRelatedError(respBody []byte) bool {
}
// Keep this intentionally broad: different upstreams may use "signature" or "thought_signature".
return strings.Contains(msg, "thought_signature") || strings.Contains(msg, "signature")
if strings.Contains(msg, "thought_signature") || strings.Contains(msg, "signature") {
return true
}
// Also detect thinking block structural errors:
// "Expected `thinking` or `redacted_thinking`, but found `text`"
if strings.Contains(msg, "expected") && (strings.Contains(msg, "thinking") || strings.Contains(msg, "redacted_thinking")) {
return true
}
return false
}
func extractAntigravityErrorMessage(body []byte) string {
@@ -555,7 +626,7 @@ func extractAntigravityErrorMessage(body []byte) string {
// stripThinkingFromClaudeRequest converts thinking blocks to text blocks in a Claude Messages request.
// This preserves the thinking content while avoiding signature validation errors.
// Note: redacted_thinking blocks are removed because they cannot be converted to text.
// It also disables top-level `thinking` to prevent dummy-thought injection during retry.
// It also disables top-level `thinking` to avoid upstream structural constraints for thinking mode.
func stripThinkingFromClaudeRequest(req *antigravity.ClaudeRequest) (bool, error) {
if req == nil {
return false, nil
@@ -585,6 +656,92 @@ func stripThinkingFromClaudeRequest(req *antigravity.ClaudeRequest) (bool, error
continue
}
filtered := make([]map[string]any, 0, len(blocks))
modifiedAny := false
for _, block := range blocks {
t, _ := block["type"].(string)
switch t {
case "thinking":
thinkingText, _ := block["thinking"].(string)
if thinkingText != "" {
filtered = append(filtered, map[string]any{
"type": "text",
"text": thinkingText,
})
}
modifiedAny = true
case "redacted_thinking":
modifiedAny = true
case "":
if thinkingText, hasThinking := block["thinking"].(string); hasThinking {
if thinkingText != "" {
filtered = append(filtered, map[string]any{
"type": "text",
"text": thinkingText,
})
}
modifiedAny = true
} else {
filtered = append(filtered, block)
}
default:
filtered = append(filtered, block)
}
}
if !modifiedAny {
continue
}
if len(filtered) == 0 {
filtered = append(filtered, map[string]any{
"type": "text",
"text": "(content removed)",
})
}
newRaw, err := json.Marshal(filtered)
if err != nil {
return changed, err
}
req.Messages[i].Content = newRaw
changed = true
}
return changed, nil
}
// stripSignatureSensitiveBlocksFromClaudeRequest is a stronger retry degradation that additionally converts
// tool blocks to plain text. Use this only after a thinking-only retry still fails with signature errors.
func stripSignatureSensitiveBlocksFromClaudeRequest(req *antigravity.ClaudeRequest) (bool, error) {
if req == nil {
return false, nil
}
changed := false
if req.Thinking != nil {
req.Thinking = nil
changed = true
}
for i := range req.Messages {
raw := req.Messages[i].Content
if len(raw) == 0 {
continue
}
// If content is a string, nothing to strip.
var str string
if json.Unmarshal(raw, &str) == nil {
continue
}
// Otherwise treat as an array of blocks and convert signature-sensitive blocks to text.
var blocks []map[string]any
if err := json.Unmarshal(raw, &blocks); err != nil {
continue
}
filtered := make([]map[string]any, 0, len(blocks))
modifiedAny := false
for _, block := range blocks {
@@ -603,6 +760,49 @@ func stripThinkingFromClaudeRequest(req *antigravity.ClaudeRequest) (bool, error
case "redacted_thinking":
// Remove redacted_thinking (cannot convert encrypted content)
modifiedAny = true
case "tool_use":
// Convert tool_use to text to avoid upstream signature/thought_signature validation errors.
// This is a retry-only degradation path, so we prioritise request validity over tool semantics.
name, _ := block["name"].(string)
id, _ := block["id"].(string)
input := block["input"]
inputJSON, _ := json.Marshal(input)
text := "(tool_use)"
if name != "" {
text += " name=" + name
}
if id != "" {
text += " id=" + id
}
if len(inputJSON) > 0 && string(inputJSON) != "null" {
text += " input=" + string(inputJSON)
}
filtered = append(filtered, map[string]any{
"type": "text",
"text": text,
})
modifiedAny = true
case "tool_result":
// Convert tool_result to text so it stays consistent when tool_use is downgraded.
toolUseID, _ := block["tool_use_id"].(string)
isError, _ := block["is_error"].(bool)
content := block["content"]
contentJSON, _ := json.Marshal(content)
text := "(tool_result)"
if toolUseID != "" {
text += " tool_use_id=" + toolUseID
}
if isError {
text += " is_error=true"
}
if len(contentJSON) > 0 && string(contentJSON) != "null" {
text += "\n" + string(contentJSON)
}
filtered = append(filtered, map[string]any{
"type": "text",
"text": text,
})
modifiedAny = true
case "":
// Handle untyped block with "thinking" field
if thinkingText, hasThinking := block["thinking"].(string); hasThinking {
@@ -625,6 +825,14 @@ func stripThinkingFromClaudeRequest(req *antigravity.ClaudeRequest) (bool, error
continue
}
if len(filtered) == 0 {
// Keep request valid: upstream rejects empty content arrays.
filtered = append(filtered, map[string]any{
"type": "text",
"text": "(content removed)",
})
}
newRaw, err := json.Marshal(filtered)
if err != nil {
return changed, err
@@ -711,6 +919,14 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
// 重试循环
var resp *http.Response
for attempt := 1; attempt <= antigravityMaxRetries; attempt++ {
// 检查 context 是否已取消(客户端断开连接)
select {
case <-ctx.Done():
log.Printf("%s status=context_canceled error=%v", prefix, ctx.Err())
return nil, ctx.Err()
default:
}
upstreamReq, err := antigravity.NewAPIRequest(ctx, upstreamAction, accessToken, wrappedBody)
if err != nil {
return nil, err
@@ -720,7 +936,10 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
if err != nil {
if attempt < antigravityMaxRetries {
log.Printf("%s status=request_failed retry=%d/%d error=%v", prefix, attempt, antigravityMaxRetries, err)
sleepAntigravityBackoff(attempt)
if !sleepAntigravityBackoffWithContext(ctx, attempt) {
log.Printf("%s status=context_canceled_during_backoff", prefix)
return nil, ctx.Err()
}
continue
}
log.Printf("%s status=request_failed retries_exhausted error=%v", prefix, err)
@@ -733,7 +952,10 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
if attempt < antigravityMaxRetries {
log.Printf("%s status=%d retry=%d/%d", prefix, resp.StatusCode, attempt, antigravityMaxRetries)
sleepAntigravityBackoff(attempt)
if !sleepAntigravityBackoffWithContext(ctx, attempt) {
log.Printf("%s status=context_canceled_during_backoff", prefix)
return nil, ctx.Err()
}
continue
}
// 所有重试都失败,标记限流状态
@@ -750,11 +972,18 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
break
}
defer func() { _ = resp.Body.Close() }()
defer func() {
if resp != nil && resp.Body != nil {
_ = resp.Body.Close()
}
}()
// 处理错误响应
if resp.StatusCode >= 400 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
// 尽早关闭原始响应体,释放连接;后续逻辑仍可能需要读取 body因此用内存副本重新包装。
_ = resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(respBody))
// 模型兜底:模型不存在且开启 fallback 时,自动用 fallback 模型重试一次
if s.settingService != nil && s.settingService.IsModelFallbackEnabled(ctx) &&
@@ -763,15 +992,13 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
if fallbackModel != "" && fallbackModel != mappedModel {
log.Printf("[Antigravity] Model not found (%s), retrying with fallback model %s (account: %s)", mappedModel, fallbackModel, account.Name)
// 关闭原始响应释放连接respBody 已读取到内存)
_ = resp.Body.Close()
fallbackWrapped, err := s.wrapV1InternalRequest(projectID, fallbackModel, body)
if err == nil {
fallbackReq, err := antigravity.NewAPIRequest(ctx, upstreamAction, accessToken, fallbackWrapped)
if err == nil {
fallbackResp, err := s.httpUpstream.Do(fallbackReq, proxyURL, account.ID, account.Concurrency)
if err == nil && fallbackResp.StatusCode < 400 {
_ = resp.Body.Close()
resp = fallbackResp
} else if fallbackResp != nil {
_ = fallbackResp.Body.Close()
@@ -872,8 +1099,28 @@ func (s *AntigravityGatewayService) shouldFailoverUpstreamError(statusCode int)
}
}
func sleepAntigravityBackoff(attempt int) {
sleepGeminiBackoff(attempt) // 复用 Gemini 的退避逻辑
// sleepAntigravityBackoffWithContext 带 context 取消检查的退避等待
// 返回 true 表示正常完成等待false 表示 context 已取消
func sleepAntigravityBackoffWithContext(ctx context.Context, attempt int) bool {
delay := geminiRetryBaseDelay * time.Duration(1<<uint(attempt-1))
if delay > geminiRetryMaxDelay {
delay = geminiRetryMaxDelay
}
// +/- 20% jitter
r := mathrand.New(mathrand.NewSource(time.Now().UnixNano()))
jitter := time.Duration(float64(delay) * 0.2 * (r.Float64()*2 - 1))
sleepFor := delay + jitter
if sleepFor < 0 {
sleepFor = 0
}
select {
case <-ctx.Done():
return false
case <-time.After(sleepFor):
return true
}
}
func (s *AntigravityGatewayService) handleUpstreamError(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte) {
@@ -928,57 +1175,145 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
return nil, errors.New("streaming not supported")
}
reader := bufio.NewReader(resp.Body)
// 使用 Scanner 并限制单行大小,避免 ReadString 无上限导致 OOM
scanner := bufio.NewScanner(resp.Body)
maxLineSize := defaultMaxLineSize
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.settingService.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
usage := &ClaudeUsage{}
var firstTokenMs *int
type scanEvent struct {
line string
err error
}
// 独立 goroutine 读取上游,避免读取阻塞影响超时处理
events := make(chan scanEvent, 16)
done := make(chan struct{})
sendEvent := func(ev scanEvent) bool {
select {
case events <- ev:
return true
case <-done:
return false
}
}
var lastReadAt int64
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
go func() {
defer close(events)
for scanner.Scan() {
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
if !sendEvent(scanEvent{line: scanner.Text()}) {
return
}
}
if err := scanner.Err(); err != nil {
_ = sendEvent(scanEvent{err: err})
}
}()
defer close(done)
// 上游数据间隔超时保护(防止上游挂起长期占用连接)
streamInterval := time.Duration(0)
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.StreamDataIntervalTimeout > 0 {
streamInterval = time.Duration(s.settingService.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
}
var intervalTicker *time.Ticker
if streamInterval > 0 {
intervalTicker = time.NewTicker(streamInterval)
defer intervalTicker.Stop()
}
var intervalCh <-chan time.Time
if intervalTicker != nil {
intervalCh = intervalTicker.C
}
// 仅发送一次错误事件,避免多次写入导致协议混乱
errorEventSent := false
sendErrorEvent := func(reason string) {
if errorEventSent {
return
}
errorEventSent = true
_, _ = fmt.Fprintf(c.Writer, "event: error\ndata: {\"error\":\"%s\"}\n\n", reason)
flusher.Flush()
}
for {
line, err := reader.ReadString('\n')
if len(line) > 0 {
select {
case ev, ok := <-events:
if !ok {
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, nil
}
if ev.err != nil {
if errors.Is(ev.err, bufio.ErrTooLong) {
log.Printf("SSE line too long (antigravity): max_size=%d error=%v", maxLineSize, ev.err)
sendErrorEvent("response_too_large")
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err
}
sendErrorEvent("stream_read_error")
return nil, ev.err
}
line := ev.line
trimmed := strings.TrimRight(line, "\r\n")
if strings.HasPrefix(trimmed, "data:") {
payload := strings.TrimSpace(strings.TrimPrefix(trimmed, "data:"))
if payload == "" || payload == "[DONE]" {
_, _ = io.WriteString(c.Writer, line)
flusher.Flush()
} else {
// 解包 v1internal 响应
inner, parseErr := s.unwrapV1InternalResponse([]byte(payload))
if parseErr == nil && inner != nil {
payload = string(inner)
}
// 解析 usage
var parsed map[string]any
if json.Unmarshal(inner, &parsed) == nil {
if u := extractGeminiUsage(parsed); u != nil {
usage = u
}
}
if firstTokenMs == nil {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", payload)
if _, err := fmt.Fprintf(c.Writer, "%s\n", line); err != nil {
sendErrorEvent("write_failed")
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, err
}
flusher.Flush()
continue
}
} else {
_, _ = io.WriteString(c.Writer, line)
flusher.Flush()
}
}
if errors.Is(err, io.EOF) {
break
}
if err != nil {
return nil, err
// 解包 v1internal 响应
inner, parseErr := s.unwrapV1InternalResponse([]byte(payload))
if parseErr == nil && inner != nil {
payload = string(inner)
}
// 解析 usage
var parsed map[string]any
if json.Unmarshal(inner, &parsed) == nil {
if u := extractGeminiUsage(parsed); u != nil {
usage = u
}
}
if firstTokenMs == nil {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
if _, err := fmt.Fprintf(c.Writer, "data: %s\n\n", payload); err != nil {
sendErrorEvent("write_failed")
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, err
}
flusher.Flush()
continue
}
if _, err := fmt.Fprintf(c.Writer, "%s\n", line); err != nil {
sendErrorEvent("write_failed")
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, err
}
flusher.Flush()
case <-intervalCh:
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
if time.Since(lastRead) < streamInterval {
continue
}
log.Printf("Stream data interval timeout (antigravity)")
sendErrorEvent("stream_timeout")
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
}
}
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, nil
}
func (s *AntigravityGatewayService) handleGeminiNonStreamingResponse(c *gin.Context, resp *http.Response) (*ClaudeUsage, error) {
@@ -1117,7 +1452,13 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
processor := antigravity.NewStreamingProcessor(originalModel)
var firstTokenMs *int
reader := bufio.NewReader(resp.Body)
// 使用 Scanner 并限制单行大小,避免 ReadString 无上限导致 OOM
scanner := bufio.NewScanner(resp.Body)
maxLineSize := defaultMaxLineSize
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.settingService.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
// 辅助函数:转换 antigravity.ClaudeUsage 到 service.ClaudeUsage
convertUsage := func(agUsage *antigravity.ClaudeUsage) *ClaudeUsage {
@@ -1132,13 +1473,85 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
}
}
for {
line, err := reader.ReadString('\n')
if err != nil && !errors.Is(err, io.EOF) {
return nil, fmt.Errorf("stream read error: %w", err)
type scanEvent struct {
line string
err error
}
// 独立 goroutine 读取上游,避免读取阻塞影响超时处理
events := make(chan scanEvent, 16)
done := make(chan struct{})
sendEvent := func(ev scanEvent) bool {
select {
case events <- ev:
return true
case <-done:
return false
}
}
var lastReadAt int64
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
go func() {
defer close(events)
for scanner.Scan() {
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
if !sendEvent(scanEvent{line: scanner.Text()}) {
return
}
}
if err := scanner.Err(); err != nil {
_ = sendEvent(scanEvent{err: err})
}
}()
defer close(done)
if len(line) > 0 {
streamInterval := time.Duration(0)
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.StreamDataIntervalTimeout > 0 {
streamInterval = time.Duration(s.settingService.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
}
var intervalTicker *time.Ticker
if streamInterval > 0 {
intervalTicker = time.NewTicker(streamInterval)
defer intervalTicker.Stop()
}
var intervalCh <-chan time.Time
if intervalTicker != nil {
intervalCh = intervalTicker.C
}
// 仅发送一次错误事件,避免多次写入导致协议混乱
errorEventSent := false
sendErrorEvent := func(reason string) {
if errorEventSent {
return
}
errorEventSent = true
_, _ = fmt.Fprintf(c.Writer, "event: error\ndata: {\"error\":\"%s\"}\n\n", reason)
flusher.Flush()
}
for {
select {
case ev, ok := <-events:
if !ok {
// 发送结束事件
finalEvents, agUsage := processor.Finish()
if len(finalEvents) > 0 {
_, _ = c.Writer.Write(finalEvents)
flusher.Flush()
}
return &antigravityStreamResult{usage: convertUsage(agUsage), firstTokenMs: firstTokenMs}, nil
}
if ev.err != nil {
if errors.Is(ev.err, bufio.ErrTooLong) {
log.Printf("SSE line too long (antigravity): max_size=%d error=%v", maxLineSize, ev.err)
sendErrorEvent("response_too_large")
return &antigravityStreamResult{usage: convertUsage(nil), firstTokenMs: firstTokenMs}, ev.err
}
sendErrorEvent("stream_read_error")
return nil, fmt.Errorf("stream read error: %w", ev.err)
}
line := ev.line
// 处理 SSE 行,转换为 Claude 格式
claudeEvents := processor.ProcessLine(strings.TrimRight(line, "\r\n"))
@@ -1153,25 +1566,23 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
if len(finalEvents) > 0 {
_, _ = c.Writer.Write(finalEvents)
}
sendErrorEvent("write_failed")
return &antigravityStreamResult{usage: convertUsage(agUsage), firstTokenMs: firstTokenMs}, writeErr
}
flusher.Flush()
}
}
if errors.Is(err, io.EOF) {
break
case <-intervalCh:
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
if time.Since(lastRead) < streamInterval {
continue
}
log.Printf("Stream data interval timeout (antigravity)")
sendErrorEvent("stream_timeout")
return &antigravityStreamResult{usage: convertUsage(nil), firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
}
}
// 发送结束事件
finalEvents, agUsage := processor.Finish()
if len(finalEvents) > 0 {
_, _ = c.Writer.Write(finalEvents)
flusher.Flush()
}
return &antigravityStreamResult{usage: convertUsage(agUsage), firstTokenMs: firstTokenMs}, nil
}
// extractImageSize 从 Gemini 请求中提取 image_size 参数

View File

@@ -0,0 +1,83 @@
package service
import (
"encoding/json"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/stretchr/testify/require"
)
func TestStripSignatureSensitiveBlocksFromClaudeRequest(t *testing.T) {
req := &antigravity.ClaudeRequest{
Model: "claude-sonnet-4-5",
Thinking: &antigravity.ThinkingConfig{
Type: "enabled",
BudgetTokens: 1024,
},
Messages: []antigravity.ClaudeMessage{
{
Role: "assistant",
Content: json.RawMessage(`[
{"type":"thinking","thinking":"secret plan","signature":""},
{"type":"tool_use","id":"t1","name":"Bash","input":{"command":"ls"}}
]`),
},
{
Role: "user",
Content: json.RawMessage(`[
{"type":"tool_result","tool_use_id":"t1","content":"ok","is_error":false},
{"type":"redacted_thinking","data":"..."}
]`),
},
},
}
changed, err := stripSignatureSensitiveBlocksFromClaudeRequest(req)
require.NoError(t, err)
require.True(t, changed)
require.Nil(t, req.Thinking)
require.Len(t, req.Messages, 2)
var blocks0 []map[string]any
require.NoError(t, json.Unmarshal(req.Messages[0].Content, &blocks0))
require.Len(t, blocks0, 2)
require.Equal(t, "text", blocks0[0]["type"])
require.Equal(t, "secret plan", blocks0[0]["text"])
require.Equal(t, "text", blocks0[1]["type"])
var blocks1 []map[string]any
require.NoError(t, json.Unmarshal(req.Messages[1].Content, &blocks1))
require.Len(t, blocks1, 1)
require.Equal(t, "text", blocks1[0]["type"])
require.NotEmpty(t, blocks1[0]["text"])
}
func TestStripThinkingFromClaudeRequest_DoesNotDowngradeTools(t *testing.T) {
req := &antigravity.ClaudeRequest{
Model: "claude-sonnet-4-5",
Thinking: &antigravity.ThinkingConfig{
Type: "enabled",
BudgetTokens: 1024,
},
Messages: []antigravity.ClaudeMessage{
{
Role: "assistant",
Content: json.RawMessage(`[{"type":"thinking","thinking":"secret plan"},{"type":"tool_use","id":"t1","name":"Bash","input":{"command":"ls"}}]`),
},
},
}
changed, err := stripThinkingFromClaudeRequest(req)
require.NoError(t, err)
require.True(t, changed)
require.Nil(t, req.Thinking)
var blocks []map[string]any
require.NoError(t, json.Unmarshal(req.Messages[0].Content, &blocks))
require.Len(t, blocks, 2)
require.Equal(t, "text", blocks[0]["type"])
require.Equal(t, "secret plan", blocks[0]["text"])
require.Equal(t, "tool_use", blocks[1]["type"])
}

View File

@@ -221,9 +221,33 @@ func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*S
// VerifyTurnstile 验证Turnstile token
func (s *AuthService) VerifyTurnstile(ctx context.Context, token string, remoteIP string) error {
required := s.cfg != nil && s.cfg.Server.Mode == "release" && s.cfg.Turnstile.Required
if required {
if s.settingService == nil {
log.Println("[Auth] Turnstile required but settings service is not configured")
return ErrTurnstileNotConfigured
}
enabled := s.settingService.IsTurnstileEnabled(ctx)
secretConfigured := s.settingService.GetTurnstileSecretKey(ctx) != ""
if !enabled || !secretConfigured {
log.Printf("[Auth] Turnstile required but not configured (enabled=%v, secret_configured=%v)", enabled, secretConfigured)
return ErrTurnstileNotConfigured
}
}
if s.turnstileService == nil {
if required {
log.Println("[Auth] Turnstile required but service not configured")
return ErrTurnstileNotConfigured
}
return nil // 服务未配置则跳过验证
}
if !required && s.settingService != nil && s.settingService.IsTurnstileEnabled(ctx) && s.settingService.GetTurnstileSecretKey(ctx) == "" {
log.Println("[Auth] Turnstile enabled but secret key not configured")
}
return s.turnstileService.VerifyToken(ctx, token, remoteIP)
}

View File

@@ -16,7 +16,8 @@ import (
// 注ErrInsufficientBalance在redeem_service.go中定义
// 注ErrDailyLimitExceeded/ErrWeeklyLimitExceeded/ErrMonthlyLimitExceeded在subscription_service.go中定义
var (
ErrSubscriptionInvalid = infraerrors.Forbidden("SUBSCRIPTION_INVALID", "subscription is invalid or expired")
ErrSubscriptionInvalid = infraerrors.Forbidden("SUBSCRIPTION_INVALID", "subscription is invalid or expired")
ErrBillingServiceUnavailable = infraerrors.ServiceUnavailable("BILLING_SERVICE_ERROR", "Billing service temporarily unavailable. Please retry later.")
)
// subscriptionCacheData 订阅缓存数据结构(内部使用)
@@ -72,10 +73,11 @@ type cacheWriteTask struct {
// BillingCacheService 计费缓存服务
// 负责余额和订阅数据的缓存管理,提供高性能的计费资格检查
type BillingCacheService struct {
cache BillingCache
userRepo UserRepository
subRepo UserSubscriptionRepository
cfg *config.Config
cache BillingCache
userRepo UserRepository
subRepo UserSubscriptionRepository
cfg *config.Config
circuitBreaker *billingCircuitBreaker
cacheWriteChan chan cacheWriteTask
cacheWriteWg sync.WaitGroup
@@ -95,6 +97,7 @@ func NewBillingCacheService(cache BillingCache, userRepo UserRepository, subRepo
subRepo: subRepo,
cfg: cfg,
}
svc.circuitBreaker = newBillingCircuitBreaker(cfg.Billing.CircuitBreaker)
svc.startCacheWriteWorkers()
return svc
}
@@ -450,6 +453,9 @@ func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user
if s.cfg.RunMode == config.RunModeSimple {
return nil
}
if s.circuitBreaker != nil && !s.circuitBreaker.Allow() {
return ErrBillingServiceUnavailable
}
// 判断计费模式
isSubscriptionMode := group != nil && group.IsSubscriptionType() && subscription != nil
@@ -465,9 +471,14 @@ func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user
func (s *BillingCacheService) checkBalanceEligibility(ctx context.Context, userID int64) error {
balance, err := s.GetUserBalance(ctx, userID)
if err != nil {
// 缓存/数据库错误,允许通过(降级处理)
log.Printf("Warning: get user balance failed, allowing request: %v", err)
return nil
if s.circuitBreaker != nil {
s.circuitBreaker.OnFailure(err)
}
log.Printf("ALERT: billing balance check failed for user %d: %v", userID, err)
return ErrBillingServiceUnavailable.WithCause(err)
}
if s.circuitBreaker != nil {
s.circuitBreaker.OnSuccess()
}
if balance <= 0 {
@@ -482,9 +493,14 @@ func (s *BillingCacheService) checkSubscriptionEligibility(ctx context.Context,
// 获取订阅缓存数据
subData, err := s.GetSubscriptionStatus(ctx, userID, group.ID)
if err != nil {
// 缓存/数据库错误降级使用传入的subscription进行检查
log.Printf("Warning: get subscription cache failed, using fallback: %v", err)
return s.checkSubscriptionLimitsFallback(subscription, group)
if s.circuitBreaker != nil {
s.circuitBreaker.OnFailure(err)
}
log.Printf("ALERT: billing subscription check failed for user %d group %d: %v", userID, group.ID, err)
return ErrBillingServiceUnavailable.WithCause(err)
}
if s.circuitBreaker != nil {
s.circuitBreaker.OnSuccess()
}
// 检查订阅状态
@@ -513,27 +529,133 @@ func (s *BillingCacheService) checkSubscriptionEligibility(ctx context.Context,
return nil
}
// checkSubscriptionLimitsFallback 降级检查订阅限额
func (s *BillingCacheService) checkSubscriptionLimitsFallback(subscription *UserSubscription, group *Group) error {
if subscription == nil {
return ErrSubscriptionInvalid
}
type billingCircuitBreakerState int
if !subscription.IsActive() {
return ErrSubscriptionInvalid
}
const (
billingCircuitClosed billingCircuitBreakerState = iota
billingCircuitOpen
billingCircuitHalfOpen
)
if !subscription.CheckDailyLimit(group, 0) {
return ErrDailyLimitExceeded
}
if !subscription.CheckWeeklyLimit(group, 0) {
return ErrWeeklyLimitExceeded
}
if !subscription.CheckMonthlyLimit(group, 0) {
return ErrMonthlyLimitExceeded
}
return nil
type billingCircuitBreaker struct {
mu sync.Mutex
state billingCircuitBreakerState
failures int
openedAt time.Time
failureThreshold int
resetTimeout time.Duration
halfOpenRequests int
halfOpenRemaining int
}
func newBillingCircuitBreaker(cfg config.CircuitBreakerConfig) *billingCircuitBreaker {
if !cfg.Enabled {
return nil
}
resetTimeout := time.Duration(cfg.ResetTimeoutSeconds) * time.Second
if resetTimeout <= 0 {
resetTimeout = 30 * time.Second
}
halfOpen := cfg.HalfOpenRequests
if halfOpen <= 0 {
halfOpen = 1
}
threshold := cfg.FailureThreshold
if threshold <= 0 {
threshold = 5
}
return &billingCircuitBreaker{
state: billingCircuitClosed,
failureThreshold: threshold,
resetTimeout: resetTimeout,
halfOpenRequests: halfOpen,
}
}
func (b *billingCircuitBreaker) Allow() bool {
b.mu.Lock()
defer b.mu.Unlock()
switch b.state {
case billingCircuitClosed:
return true
case billingCircuitOpen:
if time.Since(b.openedAt) < b.resetTimeout {
return false
}
b.state = billingCircuitHalfOpen
b.halfOpenRemaining = b.halfOpenRequests
log.Printf("ALERT: billing circuit breaker entering half-open state")
fallthrough
case billingCircuitHalfOpen:
if b.halfOpenRemaining <= 0 {
return false
}
b.halfOpenRemaining--
return true
default:
return false
}
}
func (b *billingCircuitBreaker) OnFailure(err error) {
if b == nil {
return
}
b.mu.Lock()
defer b.mu.Unlock()
switch b.state {
case billingCircuitOpen:
return
case billingCircuitHalfOpen:
b.state = billingCircuitOpen
b.openedAt = time.Now()
b.halfOpenRemaining = 0
log.Printf("ALERT: billing circuit breaker opened after half-open failure: %v", err)
return
default:
b.failures++
if b.failures >= b.failureThreshold {
b.state = billingCircuitOpen
b.openedAt = time.Now()
b.halfOpenRemaining = 0
log.Printf("ALERT: billing circuit breaker opened after %d failures: %v", b.failures, err)
}
}
}
func (b *billingCircuitBreaker) OnSuccess() {
if b == nil {
return
}
b.mu.Lock()
defer b.mu.Unlock()
previousState := b.state
previousFailures := b.failures
b.state = billingCircuitClosed
b.failures = 0
b.halfOpenRemaining = 0
// 只有状态真正发生变化时才记录日志
if previousState != billingCircuitClosed {
log.Printf("ALERT: billing circuit breaker closed (was %s)", circuitStateString(previousState))
} else if previousFailures > 0 {
log.Printf("INFO: billing circuit breaker failures reset from %d", previousFailures)
}
}
func circuitStateString(state billingCircuitBreakerState) string {
switch state {
case billingCircuitClosed:
return "closed"
case billingCircuitOpen:
return "open"
case billingCircuitHalfOpen:
return "half-open"
default:
return "unknown"
}
}

View File

@@ -8,12 +8,13 @@ import (
"fmt"
"io"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
)
type CRSSyncService struct {
@@ -22,6 +23,7 @@ type CRSSyncService struct {
oauthService *OAuthService
openaiOAuthService *OpenAIOAuthService
geminiOAuthService *GeminiOAuthService
cfg *config.Config
}
func NewCRSSyncService(
@@ -30,6 +32,7 @@ func NewCRSSyncService(
oauthService *OAuthService,
openaiOAuthService *OpenAIOAuthService,
geminiOAuthService *GeminiOAuthService,
cfg *config.Config,
) *CRSSyncService {
return &CRSSyncService{
accountRepo: accountRepo,
@@ -37,6 +40,7 @@ func NewCRSSyncService(
oauthService: oauthService,
openaiOAuthService: openaiOAuthService,
geminiOAuthService: geminiOAuthService,
cfg: cfg,
}
}
@@ -187,16 +191,31 @@ type crsGeminiAPIKeyAccount struct {
}
func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput) (*SyncFromCRSResult, error) {
baseURL, err := normalizeBaseURL(input.BaseURL)
if err != nil {
return nil, err
if s.cfg == nil {
return nil, errors.New("config is not available")
}
baseURL := strings.TrimSpace(input.BaseURL)
if s.cfg.Security.URLAllowlist.Enabled {
normalized, err := normalizeBaseURL(baseURL, s.cfg.Security.URLAllowlist.CRSHosts, s.cfg.Security.URLAllowlist.AllowPrivateHosts)
if err != nil {
return nil, err
}
baseURL = normalized
} else {
normalized, err := urlvalidator.ValidateURLFormat(baseURL, s.cfg.Security.URLAllowlist.AllowInsecureHTTP)
if err != nil {
return nil, fmt.Errorf("invalid base_url: %w", err)
}
baseURL = normalized
}
if strings.TrimSpace(input.Username) == "" || strings.TrimSpace(input.Password) == "" {
return nil, errors.New("username and password are required")
}
client, err := httpclient.GetClient(httpclient.Options{
Timeout: 20 * time.Second,
Timeout: 20 * time.Second,
ValidateResolvedIP: s.cfg.Security.URLAllowlist.Enabled,
AllowPrivateHosts: s.cfg.Security.URLAllowlist.AllowPrivateHosts,
})
if err != nil {
client = &http.Client{Timeout: 20 * time.Second}
@@ -1055,17 +1074,18 @@ func mapCRSStatus(isActive bool, status string) string {
return "active"
}
func normalizeBaseURL(raw string) (string, error) {
trimmed := strings.TrimSpace(raw)
if trimmed == "" {
return "", errors.New("base_url is required")
func normalizeBaseURL(raw string, allowlist []string, allowPrivate bool) (string, error) {
// 当 allowlist 为空时,不强制要求白名单(只进行基本的 URL 和 SSRF 验证)
requireAllowlist := len(allowlist) > 0
normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{
AllowedHosts: allowlist,
RequireAllowlist: requireAllowlist,
AllowPrivate: allowPrivate,
})
if err != nil {
return "", fmt.Errorf("invalid base_url: %w", err)
}
u, err := url.Parse(trimmed)
if err != nil || u.Scheme == "" || u.Host == "" {
return "", fmt.Errorf("invalid base_url: %s", trimmed)
}
u.Path = strings.TrimRight(u.Path, "/")
return strings.TrimRight(u.String(), "/"), nil
return normalized, nil
}
// cleanBaseURL removes trailing suffix from base_url in credentials

View File

@@ -101,6 +101,10 @@ const (
SettingKeyFallbackModelOpenAI = "fallback_model_openai"
SettingKeyFallbackModelGemini = "fallback_model_gemini"
SettingKeyFallbackModelAntigravity = "fallback_model_antigravity"
// Request identity patch (Claude -> Gemini systemInstruction injection)
SettingKeyEnableIdentityPatch = "enable_identity_patch"
SettingKeyIdentityPatchPrompt = "identity_patch_prompt"
)
// AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys).

View File

@@ -84,25 +84,37 @@ func FilterThinkingBlocks(body []byte) []byte {
return filterThinkingBlocksInternal(body, false)
}
// FilterThinkingBlocksForRetry removes thinking blocks from HISTORICAL messages for retry scenarios.
// This is used when upstream returns signature-related 400 errors.
// FilterThinkingBlocksForRetry strips thinking-related constructs for retry scenarios.
//
// Key insight:
// - User's thinking.type = "enabled" should be PRESERVED (user's intent)
// - Only HISTORICAL assistant messages have thinking blocks with signatures
// - These signatures may be invalid when switching accounts/platforms
// - New responses will generate fresh thinking blocks without signature issues
// Why:
// - Upstreams may reject historical `thinking`/`redacted_thinking` blocks due to invalid/missing signatures.
// - Anthropic extended thinking has a structural constraint: when top-level `thinking` is enabled and the
// final message is an assistant prefill, the assistant content must start with a thinking block.
// - If we remove thinking blocks but keep top-level `thinking` enabled, we can trigger:
// "Expected `thinking` or `redacted_thinking`, but found `text`"
//
// Strategy:
// - Keep thinking.type = "enabled" (preserve user intent)
// - Remove thinking/redacted_thinking blocks from historical assistant messages
// - Ensure no message has empty content after filtering
// Strategy (B: preserve content as text):
// - Disable top-level `thinking` (remove `thinking` field).
// - Convert `thinking` blocks to `text` blocks (preserve the thinking content).
// - Remove `redacted_thinking` blocks (cannot be converted to text).
// - Ensure no message ends up with empty content.
func FilterThinkingBlocksForRetry(body []byte) []byte {
// Fast path: check for presence of thinking-related keys in messages
if !bytes.Contains(body, []byte(`"type":"thinking"`)) &&
!bytes.Contains(body, []byte(`"type": "thinking"`)) &&
!bytes.Contains(body, []byte(`"type":"redacted_thinking"`)) &&
!bytes.Contains(body, []byte(`"type": "redacted_thinking"`)) {
hasThinkingContent := bytes.Contains(body, []byte(`"type":"thinking"`)) ||
bytes.Contains(body, []byte(`"type": "thinking"`)) ||
bytes.Contains(body, []byte(`"type":"redacted_thinking"`)) ||
bytes.Contains(body, []byte(`"type": "redacted_thinking"`)) ||
bytes.Contains(body, []byte(`"thinking":`)) ||
bytes.Contains(body, []byte(`"thinking" :`))
// Also check for empty content arrays that need fixing.
// Note: This is a heuristic check; the actual empty content handling is done below.
hasEmptyContent := bytes.Contains(body, []byte(`"content":[]`)) ||
bytes.Contains(body, []byte(`"content": []`)) ||
bytes.Contains(body, []byte(`"content" : []`)) ||
bytes.Contains(body, []byte(`"content" :[]`))
// Fast path: nothing to process
if !hasThinkingContent && !hasEmptyContent {
return body
}
@@ -111,15 +123,19 @@ func FilterThinkingBlocksForRetry(body []byte) []byte {
return body
}
// DO NOT modify thinking.type - preserve user's intent to use thinking mode
// The issue is with historical message signatures, not the thinking mode itself
modified := false
messages, ok := req["messages"].([]any)
if !ok {
return body
}
modified := false
// Disable top-level thinking mode for retry to avoid structural/signature constraints upstream.
if _, exists := req["thinking"]; exists {
delete(req, "thinking")
modified = true
}
newMessages := make([]any, 0, len(messages))
for _, msg := range messages {
@@ -149,33 +165,59 @@ func FilterThinkingBlocksForRetry(body []byte) []byte {
blockType, _ := blockMap["type"].(string)
// Remove thinking/redacted_thinking blocks from historical messages
// These have signatures that may be invalid across different accounts
if blockType == "thinking" || blockType == "redacted_thinking" {
// Convert thinking blocks to text (preserve content) and drop redacted_thinking.
switch blockType {
case "thinking":
modifiedThisMsg = true
thinkingText, _ := blockMap["thinking"].(string)
if thinkingText == "" {
continue
}
newContent = append(newContent, map[string]any{
"type": "text",
"text": thinkingText,
})
continue
case "redacted_thinking":
modifiedThisMsg = true
continue
}
// Handle blocks without type discriminator but with a "thinking" field.
if blockType == "" {
if rawThinking, hasThinking := blockMap["thinking"]; hasThinking {
modifiedThisMsg = true
switch v := rawThinking.(type) {
case string:
if v != "" {
newContent = append(newContent, map[string]any{"type": "text", "text": v})
}
default:
if b, err := json.Marshal(v); err == nil && len(b) > 0 {
newContent = append(newContent, map[string]any{"type": "text", "text": string(b)})
}
}
continue
}
}
newContent = append(newContent, block)
}
if modifiedThisMsg {
// Handle empty content: either from filtering or originally empty
if len(newContent) == 0 {
modified = true
// Handle empty content after filtering
if len(newContent) == 0 {
// For assistant messages, skip entirely (remove from conversation)
// For user messages, add placeholder to avoid empty content error
if role == "user" {
newContent = append(newContent, map[string]any{
"type": "text",
"text": "(content removed)",
})
msgMap["content"] = newContent
newMessages = append(newMessages, msgMap)
}
// Skip assistant messages with empty content (don't append)
continue
placeholder := "(content removed)"
if role == "assistant" {
placeholder = "(assistant content removed)"
}
newContent = append(newContent, map[string]any{
"type": "text",
"text": placeholder,
})
msgMap["content"] = newContent
} else if modifiedThisMsg {
modified = true
msgMap["content"] = newContent
}
newMessages = append(newMessages, msgMap)
@@ -183,6 +225,9 @@ func FilterThinkingBlocksForRetry(body []byte) []byte {
if modified {
req["messages"] = newMessages
} else {
// Avoid rewriting JSON when no changes are needed.
return body
}
newBody, err := json.Marshal(req)
@@ -192,6 +237,172 @@ func FilterThinkingBlocksForRetry(body []byte) []byte {
return newBody
}
// FilterSignatureSensitiveBlocksForRetry is a stronger retry filter for cases where upstream errors indicate
// signature/thought_signature validation issues involving tool blocks.
//
// This performs everything in FilterThinkingBlocksForRetry, plus:
// - Convert `tool_use` blocks to text (name/id/input) so we stop sending structured tool calls.
// - Convert `tool_result` blocks to text so we keep tool results visible without tool semantics.
//
// Use this only when needed: converting tool blocks to text changes model behaviour and can increase the
// risk of prompt injection (tool output becomes plain conversation text).
func FilterSignatureSensitiveBlocksForRetry(body []byte) []byte {
// Fast path: only run when we see likely relevant constructs.
if !bytes.Contains(body, []byte(`"type":"thinking"`)) &&
!bytes.Contains(body, []byte(`"type": "thinking"`)) &&
!bytes.Contains(body, []byte(`"type":"redacted_thinking"`)) &&
!bytes.Contains(body, []byte(`"type": "redacted_thinking"`)) &&
!bytes.Contains(body, []byte(`"type":"tool_use"`)) &&
!bytes.Contains(body, []byte(`"type": "tool_use"`)) &&
!bytes.Contains(body, []byte(`"type":"tool_result"`)) &&
!bytes.Contains(body, []byte(`"type": "tool_result"`)) &&
!bytes.Contains(body, []byte(`"thinking":`)) &&
!bytes.Contains(body, []byte(`"thinking" :`)) {
return body
}
var req map[string]any
if err := json.Unmarshal(body, &req); err != nil {
return body
}
modified := false
// Disable top-level thinking for retry to avoid structural/signature constraints upstream.
if _, exists := req["thinking"]; exists {
delete(req, "thinking")
modified = true
}
messages, ok := req["messages"].([]any)
if !ok {
return body
}
newMessages := make([]any, 0, len(messages))
for _, msg := range messages {
msgMap, ok := msg.(map[string]any)
if !ok {
newMessages = append(newMessages, msg)
continue
}
role, _ := msgMap["role"].(string)
content, ok := msgMap["content"].([]any)
if !ok {
newMessages = append(newMessages, msg)
continue
}
newContent := make([]any, 0, len(content))
modifiedThisMsg := false
for _, block := range content {
blockMap, ok := block.(map[string]any)
if !ok {
newContent = append(newContent, block)
continue
}
blockType, _ := blockMap["type"].(string)
switch blockType {
case "thinking":
modifiedThisMsg = true
thinkingText, _ := blockMap["thinking"].(string)
if thinkingText == "" {
continue
}
newContent = append(newContent, map[string]any{"type": "text", "text": thinkingText})
continue
case "redacted_thinking":
modifiedThisMsg = true
continue
case "tool_use":
modifiedThisMsg = true
name, _ := blockMap["name"].(string)
id, _ := blockMap["id"].(string)
input := blockMap["input"]
inputJSON, _ := json.Marshal(input)
text := "(tool_use)"
if name != "" {
text += " name=" + name
}
if id != "" {
text += " id=" + id
}
if len(inputJSON) > 0 && string(inputJSON) != "null" {
text += " input=" + string(inputJSON)
}
newContent = append(newContent, map[string]any{"type": "text", "text": text})
continue
case "tool_result":
modifiedThisMsg = true
toolUseID, _ := blockMap["tool_use_id"].(string)
isError, _ := blockMap["is_error"].(bool)
content := blockMap["content"]
contentJSON, _ := json.Marshal(content)
text := "(tool_result)"
if toolUseID != "" {
text += " tool_use_id=" + toolUseID
}
if isError {
text += " is_error=true"
}
if len(contentJSON) > 0 && string(contentJSON) != "null" {
text += "\n" + string(contentJSON)
}
newContent = append(newContent, map[string]any{"type": "text", "text": text})
continue
}
if blockType == "" {
if rawThinking, hasThinking := blockMap["thinking"]; hasThinking {
modifiedThisMsg = true
switch v := rawThinking.(type) {
case string:
if v != "" {
newContent = append(newContent, map[string]any{"type": "text", "text": v})
}
default:
if b, err := json.Marshal(v); err == nil && len(b) > 0 {
newContent = append(newContent, map[string]any{"type": "text", "text": string(b)})
}
}
continue
}
}
newContent = append(newContent, block)
}
if modifiedThisMsg {
modified = true
if len(newContent) == 0 {
placeholder := "(content removed)"
if role == "assistant" {
placeholder = "(assistant content removed)"
}
newContent = append(newContent, map[string]any{"type": "text", "text": placeholder})
}
msgMap["content"] = newContent
}
newMessages = append(newMessages, msgMap)
}
if !modified {
return body
}
req["messages"] = newMessages
newBody, err := json.Marshal(req)
if err != nil {
return body
}
return newBody
}
// filterThinkingBlocksInternal removes invalid thinking blocks from request
// Strategy:
// - When thinking.type != "enabled": Remove all thinking blocks

View File

@@ -151,3 +151,148 @@ func TestFilterThinkingBlocks(t *testing.T) {
})
}
}
func TestFilterThinkingBlocksForRetry_DisablesThinkingAndPreservesAsText(t *testing.T) {
input := []byte(`{
"model":"claude-3-5-sonnet-20241022",
"thinking":{"type":"enabled","budget_tokens":1024},
"messages":[
{"role":"user","content":[{"type":"text","text":"Hi"}]},
{"role":"assistant","content":[
{"type":"thinking","thinking":"Let me think...","signature":"bad_sig"},
{"type":"text","text":"Answer"}
]}
]
}`)
out := FilterThinkingBlocksForRetry(input)
var req map[string]any
require.NoError(t, json.Unmarshal(out, &req))
_, hasThinking := req["thinking"]
require.False(t, hasThinking)
msgs, ok := req["messages"].([]any)
require.True(t, ok)
require.Len(t, msgs, 2)
assistant, ok := msgs[1].(map[string]any)
require.True(t, ok)
content, ok := assistant["content"].([]any)
require.True(t, ok)
require.Len(t, content, 2)
first, ok := content[0].(map[string]any)
require.True(t, ok)
require.Equal(t, "text", first["type"])
require.Equal(t, "Let me think...", first["text"])
}
func TestFilterThinkingBlocksForRetry_DisablesThinkingEvenWithoutThinkingBlocks(t *testing.T) {
input := []byte(`{
"model":"claude-3-5-sonnet-20241022",
"thinking":{"type":"enabled","budget_tokens":1024},
"messages":[
{"role":"user","content":[{"type":"text","text":"Hi"}]},
{"role":"assistant","content":[{"type":"text","text":"Prefill"}]}
]
}`)
out := FilterThinkingBlocksForRetry(input)
var req map[string]any
require.NoError(t, json.Unmarshal(out, &req))
_, hasThinking := req["thinking"]
require.False(t, hasThinking)
}
func TestFilterThinkingBlocksForRetry_RemovesRedactedThinkingAndKeepsValidContent(t *testing.T) {
input := []byte(`{
"thinking":{"type":"enabled","budget_tokens":1024},
"messages":[
{"role":"assistant","content":[
{"type":"redacted_thinking","data":"..."},
{"type":"text","text":"Visible"}
]}
]
}`)
out := FilterThinkingBlocksForRetry(input)
var req map[string]any
require.NoError(t, json.Unmarshal(out, &req))
_, hasThinking := req["thinking"]
require.False(t, hasThinking)
msgs, ok := req["messages"].([]any)
require.True(t, ok)
msg0, ok := msgs[0].(map[string]any)
require.True(t, ok)
content, ok := msg0["content"].([]any)
require.True(t, ok)
require.Len(t, content, 1)
content0, ok := content[0].(map[string]any)
require.True(t, ok)
require.Equal(t, "text", content0["type"])
require.Equal(t, "Visible", content0["text"])
}
func TestFilterThinkingBlocksForRetry_EmptyContentGetsPlaceholder(t *testing.T) {
input := []byte(`{
"thinking":{"type":"enabled"},
"messages":[
{"role":"assistant","content":[{"type":"redacted_thinking","data":"..."}]}
]
}`)
out := FilterThinkingBlocksForRetry(input)
var req map[string]any
require.NoError(t, json.Unmarshal(out, &req))
msgs, ok := req["messages"].([]any)
require.True(t, ok)
msg0, ok := msgs[0].(map[string]any)
require.True(t, ok)
content, ok := msg0["content"].([]any)
require.True(t, ok)
require.Len(t, content, 1)
content0, ok := content[0].(map[string]any)
require.True(t, ok)
require.Equal(t, "text", content0["type"])
require.NotEmpty(t, content0["text"])
}
func TestFilterSignatureSensitiveBlocksForRetry_DowngradesTools(t *testing.T) {
input := []byte(`{
"thinking":{"type":"enabled","budget_tokens":1024},
"messages":[
{"role":"assistant","content":[
{"type":"tool_use","id":"t1","name":"Bash","input":{"command":"ls"}},
{"type":"tool_result","tool_use_id":"t1","content":"ok","is_error":false}
]}
]
}`)
out := FilterSignatureSensitiveBlocksForRetry(input)
var req map[string]any
require.NoError(t, json.Unmarshal(out, &req))
_, hasThinking := req["thinking"]
require.False(t, hasThinking)
msgs, ok := req["messages"].([]any)
require.True(t, ok)
msg0, ok := msgs[0].(map[string]any)
require.True(t, ok)
content, ok := msg0["content"].([]any)
require.True(t, ok)
require.Len(t, content, 2)
content0, ok := content[0].(map[string]any)
require.True(t, ok)
content1, ok := content[1].(map[string]any)
require.True(t, ok)
require.Equal(t, "text", content0["type"])
require.Equal(t, "text", content1["type"])
require.Contains(t, content0["text"], "tool_use")
require.Contains(t, content1["text"], "tool_result")
}

View File

@@ -15,11 +15,14 @@ import (
"regexp"
"sort"
"strings"
"sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
@@ -30,6 +33,7 @@ const (
claudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true"
claudeAPICountTokensURL = "https://api.anthropic.com/v1/messages/count_tokens?beta=true"
stickySessionTTL = time.Hour // 粘性会话TTL
defaultMaxLineSize = 10 * 1024 * 1024
claudeCodeSystemPrompt = "You are Claude Code, Anthropic's official CLI for Claude."
)
@@ -933,8 +937,16 @@ func (s *GatewayService) getOAuthToken(ctx context.Context, account *Account) (s
// 重试相关常量
const (
maxRetries = 10 // 最大试次数
retryDelay = 3 * time.Second // 重试等待时间
// 最大试次数(包含首次请求)。过多重试会导致请求堆积与资源耗尽。
maxRetryAttempts = 5
// 指数退避:第 N 次失败后的等待 = retryBaseDelay * 2^(N-1),并且上限为 retryMaxDelay。
retryBaseDelay = 300 * time.Millisecond
retryMaxDelay = 3 * time.Second
// 最大重试耗时(包含请求本身耗时 + 退避等待时间)。
// 用于防止极端情况下 goroutine 长时间堆积导致资源耗尽。
maxRetryElapsed = 10 * time.Second
)
func (s *GatewayService) shouldRetryUpstreamError(account *Account, statusCode int) bool {
@@ -957,6 +969,40 @@ func (s *GatewayService) shouldFailoverUpstreamError(statusCode int) bool {
}
}
func retryBackoffDelay(attempt int) time.Duration {
// attempt 从 1 开始,表示第 attempt 次请求刚失败,需要等待后进行第 attempt+1 次请求。
if attempt <= 0 {
return retryBaseDelay
}
delay := retryBaseDelay * time.Duration(1<<(attempt-1))
if delay > retryMaxDelay {
return retryMaxDelay
}
return delay
}
func sleepWithContext(ctx context.Context, d time.Duration) error {
if d <= 0 {
return nil
}
timer := time.NewTimer(d)
defer func() {
if !timer.Stop() {
select {
case <-timer.C:
default:
}
}
}()
select {
case <-ctx.Done():
return ctx.Err()
case <-timer.C:
return nil
}
}
// isClaudeCodeClient 判断请求是否来自 Claude Code 客户端
// 简化判断User-Agent 匹配 + metadata.user_id 存在
func isClaudeCodeClient(userAgent string, metadataUserID string) bool {
@@ -1073,7 +1119,8 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
// 重试循环
var resp *http.Response
for attempt := 1; attempt <= maxRetries; attempt++ {
retryStart := time.Now()
for attempt := 1; attempt <= maxRetryAttempts; attempt++ {
// 构建上游请求(每次重试需要重新构建,因为请求体需要重新读取)
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType, reqModel)
if err != nil {
@@ -1083,6 +1130,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
// 发送请求
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
if err != nil {
if resp != nil && resp.Body != nil {
_ = resp.Body.Close()
}
return nil, fmt.Errorf("upstream request failed: %w", err)
}
@@ -1093,28 +1143,80 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
_ = resp.Body.Close()
if s.isThinkingBlockSignatureError(respBody) {
looksLikeToolSignatureError := func(msg string) bool {
m := strings.ToLower(msg)
return strings.Contains(m, "tool_use") ||
strings.Contains(m, "tool_result") ||
strings.Contains(m, "functioncall") ||
strings.Contains(m, "function_call") ||
strings.Contains(m, "functionresponse") ||
strings.Contains(m, "function_response")
}
// 避免在重试预算已耗尽时再发起额外请求
if time.Since(retryStart) >= maxRetryElapsed {
resp.Body = io.NopCloser(bytes.NewReader(respBody))
break
}
log.Printf("Account %d: detected thinking block signature error, retrying with filtered thinking blocks", account.ID)
// 过滤thinking blocks并重试使用更激进的过滤
// Conservative two-stage fallback:
// 1) Disable thinking + thinking->text (preserve content)
// 2) Only if upstream still errors AND error message points to tool/function signature issues:
// also downgrade tool_use/tool_result blocks to text.
filteredBody := FilterThinkingBlocksForRetry(body)
retryReq, buildErr := s.buildUpstreamRequest(ctx, c, account, filteredBody, token, tokenType, reqModel)
if buildErr == nil {
retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency)
if retryErr == nil {
// 使用重试后的响应,继续后续处理
if retryResp.StatusCode < 400 {
log.Printf("Account %d: signature error retry succeeded", account.ID)
} else {
log.Printf("Account %d: signature error retry returned status %d", account.ID, retryResp.StatusCode)
log.Printf("Account %d: signature error retry succeeded (thinking downgraded)", account.ID)
resp = retryResp
break
}
retryRespBody, retryReadErr := io.ReadAll(io.LimitReader(retryResp.Body, 2<<20))
_ = retryResp.Body.Close()
if retryReadErr == nil && retryResp.StatusCode == 400 && s.isThinkingBlockSignatureError(retryRespBody) {
msg2 := extractUpstreamErrorMessage(retryRespBody)
if looksLikeToolSignatureError(msg2) && time.Since(retryStart) < maxRetryElapsed {
log.Printf("Account %d: signature retry still failing and looks tool-related, retrying with tool blocks downgraded", account.ID)
filteredBody2 := FilterSignatureSensitiveBlocksForRetry(body)
retryReq2, buildErr2 := s.buildUpstreamRequest(ctx, c, account, filteredBody2, token, tokenType, reqModel)
if buildErr2 == nil {
retryResp2, retryErr2 := s.httpUpstream.Do(retryReq2, proxyURL, account.ID, account.Concurrency)
if retryErr2 == nil {
resp = retryResp2
break
}
if retryResp2 != nil && retryResp2.Body != nil {
_ = retryResp2.Body.Close()
}
log.Printf("Account %d: tool-downgrade signature retry failed: %v", account.ID, retryErr2)
} else {
log.Printf("Account %d: tool-downgrade signature retry build failed: %v", account.ID, buildErr2)
}
}
}
// Fall back to the original retry response context.
resp = &http.Response{
StatusCode: retryResp.StatusCode,
Header: retryResp.Header.Clone(),
Body: io.NopCloser(bytes.NewReader(retryRespBody)),
}
resp = retryResp
break
}
if retryResp != nil && retryResp.Body != nil {
_ = retryResp.Body.Close()
}
log.Printf("Account %d: signature error retry failed: %v", account.ID, retryErr)
} else {
log.Printf("Account %d: signature error retry build request failed: %v", account.ID, buildErr)
}
// 重试失败,恢复原始响应体继续处理
// Retry failed: restore original response body and continue handling.
resp.Body = io.NopCloser(bytes.NewReader(respBody))
break
}
@@ -1125,11 +1227,27 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
// 检查是否需要通用重试排除400因为400已经在上面特殊处理过了
if resp.StatusCode >= 400 && resp.StatusCode != 400 && s.shouldRetryUpstreamError(account, resp.StatusCode) {
if attempt < maxRetries {
log.Printf("Account %d: upstream error %d, retry %d/%d after %v",
account.ID, resp.StatusCode, attempt, maxRetries, retryDelay)
if attempt < maxRetryAttempts {
elapsed := time.Since(retryStart)
if elapsed >= maxRetryElapsed {
break
}
delay := retryBackoffDelay(attempt)
remaining := maxRetryElapsed - elapsed
if delay > remaining {
delay = remaining
}
if delay <= 0 {
break
}
log.Printf("Account %d: upstream error %d, retry %d/%d after %v (elapsed=%v/%v)",
account.ID, resp.StatusCode, attempt, maxRetryAttempts, delay, elapsed, maxRetryElapsed)
_ = resp.Body.Close()
time.Sleep(retryDelay)
if err := sleepWithContext(ctx, delay); err != nil {
return nil, err
}
continue
}
// 最后一次尝试也失败,跳出循环处理重试耗尽
@@ -1146,6 +1264,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
}
break
}
if resp == nil || resp.Body == nil {
return nil, errors.New("upstream request failed: empty response")
}
defer func() { _ = resp.Body.Close() }()
// 处理重试耗尽的情况
@@ -1229,7 +1350,13 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
targetURL := claudeAPIURL
if account.Type == AccountTypeAPIKey {
baseURL := account.GetBaseURL()
targetURL = baseURL + "/v1/messages"
if baseURL != "" {
validatedURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, err
}
targetURL = validatedURL + "/v1/messages"
}
}
// OAuth账号应用统一指纹
@@ -1537,10 +1664,10 @@ func (s *GatewayService) handleRetryExhaustedSideEffects(ctx context.Context, re
// OAuth/Setup Token 账号的 403标记账号异常
if account.IsOAuth() && statusCode == 403 {
s.rateLimitService.HandleUpstreamError(ctx, account, statusCode, resp.Header, body)
log.Printf("Account %d: marked as error after %d retries for status %d", account.ID, maxRetries, statusCode)
log.Printf("Account %d: marked as error after %d retries for status %d", account.ID, maxRetryAttempts, statusCode)
} else {
// API Key 未配置错误码:不标记账号状态
log.Printf("Account %d: upstream error %d after %d retries (not marking account)", account.ID, statusCode, maxRetries)
log.Printf("Account %d: upstream error %d after %d retries (not marking account)", account.ID, statusCode, maxRetryAttempts)
}
}
@@ -1577,6 +1704,10 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
// 更新5h窗口状态
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
if s.cfg != nil {
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
}
// 设置SSE响应头
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
@@ -1598,51 +1729,133 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
var firstTokenMs *int
scanner := bufio.NewScanner(resp.Body)
// 设置更大的buffer以处理长行
scanner.Buffer(make([]byte, 64*1024), 1024*1024)
maxLineSize := defaultMaxLineSize
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
type scanEvent struct {
line string
err error
}
// 独立 goroutine 读取上游,避免读取阻塞导致超时/keepalive无法处理
events := make(chan scanEvent, 16)
done := make(chan struct{})
sendEvent := func(ev scanEvent) bool {
select {
case events <- ev:
return true
case <-done:
return false
}
}
var lastReadAt int64
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
go func() {
defer close(events)
for scanner.Scan() {
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
if !sendEvent(scanEvent{line: scanner.Text()}) {
return
}
}
if err := scanner.Err(); err != nil {
_ = sendEvent(scanEvent{err: err})
}
}()
defer close(done)
streamInterval := time.Duration(0)
if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 {
streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
}
// 仅监控上游数据间隔超时,避免下游写入阻塞导致误判
var intervalTicker *time.Ticker
if streamInterval > 0 {
intervalTicker = time.NewTicker(streamInterval)
defer intervalTicker.Stop()
}
var intervalCh <-chan time.Time
if intervalTicker != nil {
intervalCh = intervalTicker.C
}
// 仅发送一次错误事件,避免多次写入导致协议混乱(写失败时尽力通知客户端)
errorEventSent := false
sendErrorEvent := func(reason string) {
if errorEventSent {
return
}
errorEventSent = true
_, _ = fmt.Fprintf(w, "event: error\ndata: {\"error\":\"%s\"}\n\n", reason)
flusher.Flush()
}
needModelReplace := originalModel != mappedModel
for scanner.Scan() {
line := scanner.Text()
if line == "event: error" {
return nil, errors.New("have error in stream")
}
// Extract data from SSE line (supports both "data: " and "data:" formats)
if sseDataRe.MatchString(line) {
data := sseDataRe.ReplaceAllString(line, "")
// 如果有模型映射替换响应中的model字段
if needModelReplace {
line = s.replaceModelInSSELine(line, mappedModel, originalModel)
for {
select {
case ev, ok := <-events:
if !ok {
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
}
if ev.err != nil {
if errors.Is(ev.err, bufio.ErrTooLong) {
log.Printf("SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err)
sendErrorEvent("response_too_large")
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err
}
sendErrorEvent("stream_read_error")
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", ev.err)
}
line := ev.line
if line == "event: error" {
return nil, errors.New("have error in stream")
}
// 转发行
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
}
flusher.Flush()
// Extract data from SSE line (supports both "data: " and "data:" formats)
if sseDataRe.MatchString(line) {
data := sseDataRe.ReplaceAllString(line, "")
// 记录首字时间:第一个有效的 content_block_delta 或 message_start
if firstTokenMs == nil && data != "" && data != "[DONE]" {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
// 如果有模型映射替换响应中的model字段
if needModelReplace {
line = s.replaceModelInSSELine(line, mappedModel, originalModel)
}
// 转发行
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
sendErrorEvent("write_failed")
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
}
flusher.Flush()
// 记录首字时间:第一个有效的 content_block_delta 或 message_start
if firstTokenMs == nil && data != "" && data != "[DONE]" {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
s.parseSSEUsage(data, usage)
} else {
// 非 data 行直接转发
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
sendErrorEvent("write_failed")
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
}
flusher.Flush()
}
s.parseSSEUsage(data, usage)
} else {
// 非 data 行直接转发
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
case <-intervalCh:
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
if time.Since(lastRead) < streamInterval {
continue
}
flusher.Flush()
log.Printf("Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval)
sendErrorEvent("stream_timeout")
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
}
}
if err := scanner.Err(); err != nil {
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", err)
}
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
}
// replaceModelInSSELine 替换SSE数据行中的model字段
@@ -1747,15 +1960,17 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
}
// 透传响应头
for key, values := range resp.Header {
for _, value := range values {
c.Header(key, value)
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
contentType := "application/json"
if s.cfg != nil && !s.cfg.Security.ResponseHeaders.Enabled {
if upstreamType := resp.Header.Get("Content-Type"); upstreamType != "" {
contentType = upstreamType
}
}
// 写入响应
c.Data(resp.StatusCode, "application/json", body)
c.Data(resp.StatusCode, contentType, body)
return &response.Usage, nil
}
@@ -1989,7 +2204,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
if resp.StatusCode == 400 && s.isThinkingBlockSignatureError(respBody) {
log.Printf("Account %d: detected thinking block signature error on count_tokens, retrying with filtered thinking blocks", account.ID)
filteredBody := FilterThinkingBlocks(body)
filteredBody := FilterThinkingBlocksForRetry(body)
retryReq, buildErr := s.buildCountTokensRequest(ctx, c, account, filteredBody, token, tokenType, reqModel)
if buildErr == nil {
retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency)
@@ -2045,7 +2260,13 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
targetURL := claudeAPICountTokensURL
if account.Type == AccountTypeAPIKey {
baseURL := account.GetBaseURL()
targetURL = baseURL + "/v1/messages/count_tokens"
if baseURL != "" {
validatedURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, err
}
targetURL = validatedURL + "/v1/messages/count_tokens"
}
}
// OAuth 账号:应用统一指纹和重写 userID
@@ -2125,6 +2346,25 @@ func (s *GatewayService) countTokensError(c *gin.Context, status int, errType, m
})
}
func (s *GatewayService) validateUpstreamBaseURL(raw string) (string, error) {
if s.cfg != nil && !s.cfg.Security.URLAllowlist.Enabled {
normalized, err := urlvalidator.ValidateURLFormat(raw, s.cfg.Security.URLAllowlist.AllowInsecureHTTP)
if err != nil {
return "", fmt.Errorf("invalid base_url: %w", err)
}
return normalized, nil
}
normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{
AllowedHosts: s.cfg.Security.URLAllowlist.UpstreamHosts,
RequireAllowlist: true,
AllowPrivate: s.cfg.Security.URLAllowlist.AllowPrivateHosts,
})
if err != nil {
return "", fmt.Errorf("invalid base_url: %w", err)
}
return normalized, nil
}
// GetAvailableModels returns the list of models available for a group
// It aggregates model_mapping keys from all schedulable accounts in the group
func (s *GatewayService) GetAvailableModels(ctx context.Context, groupID *int64, platform string) []string {

View File

@@ -18,9 +18,12 @@ import (
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/gin-gonic/gin"
)
@@ -41,6 +44,7 @@ type GeminiMessagesCompatService struct {
rateLimitService *RateLimitService
httpUpstream HTTPUpstream
antigravityGatewayService *AntigravityGatewayService
cfg *config.Config
}
func NewGeminiMessagesCompatService(
@@ -51,6 +55,7 @@ func NewGeminiMessagesCompatService(
rateLimitService *RateLimitService,
httpUpstream HTTPUpstream,
antigravityGatewayService *AntigravityGatewayService,
cfg *config.Config,
) *GeminiMessagesCompatService {
return &GeminiMessagesCompatService{
accountRepo: accountRepo,
@@ -60,6 +65,7 @@ func NewGeminiMessagesCompatService(
rateLimitService: rateLimitService,
httpUpstream: httpUpstream,
antigravityGatewayService: antigravityGatewayService,
cfg: cfg,
}
}
@@ -230,6 +236,25 @@ func (s *GeminiMessagesCompatService) GetAntigravityGatewayService() *Antigravit
return s.antigravityGatewayService
}
func (s *GeminiMessagesCompatService) validateUpstreamBaseURL(raw string) (string, error) {
if s.cfg != nil && !s.cfg.Security.URLAllowlist.Enabled {
normalized, err := urlvalidator.ValidateURLFormat(raw, s.cfg.Security.URLAllowlist.AllowInsecureHTTP)
if err != nil {
return "", fmt.Errorf("invalid base_url: %w", err)
}
return normalized, nil
}
normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{
AllowedHosts: s.cfg.Security.URLAllowlist.UpstreamHosts,
RequireAllowlist: true,
AllowPrivate: s.cfg.Security.URLAllowlist.AllowPrivateHosts,
})
if err != nil {
return "", fmt.Errorf("invalid base_url: %w", err)
}
return normalized, nil
}
// HasAntigravityAccounts 检查是否有可用的 antigravity 账户
func (s *GeminiMessagesCompatService) HasAntigravityAccounts(ctx context.Context, groupID *int64) (bool, error) {
var accounts []Account
@@ -359,6 +384,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
if err != nil {
return nil, s.writeClaudeError(c, http.StatusBadRequest, "invalid_request_error", err.Error())
}
originalClaudeBody := body
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
@@ -381,16 +407,20 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
return nil, "", errors.New("gemini api_key not configured")
}
baseURL := strings.TrimRight(account.GetCredential("base_url"), "/")
baseURL := strings.TrimSpace(account.GetCredential("base_url"))
if baseURL == "" {
baseURL = geminicli.AIStudioBaseURL
}
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, "", err
}
action := "generateContent"
if req.Stream {
action = "streamGenerateContent"
}
fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(baseURL, "/"), mappedModel, action)
fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(normalizedBaseURL, "/"), mappedModel, action)
if req.Stream {
fullURL += "?alt=sse"
}
@@ -427,7 +457,11 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
// 2. Without project_id -> AI Studio API (direct OAuth, like API key but with Bearer token)
if projectID != "" {
// Mode 1: Code Assist API
fullURL := fmt.Sprintf("%s/v1internal:%s", geminicli.GeminiCliBaseURL, action)
baseURL, err := s.validateUpstreamBaseURL(geminicli.GeminiCliBaseURL)
if err != nil {
return nil, "", err
}
fullURL := fmt.Sprintf("%s/v1internal:%s", strings.TrimRight(baseURL, "/"), action)
if useUpstreamStream {
fullURL += "?alt=sse"
}
@@ -453,12 +487,16 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
return upstreamReq, "x-request-id", nil
} else {
// Mode 2: AI Studio API with OAuth (like API key mode, but using Bearer token)
baseURL := strings.TrimRight(account.GetCredential("base_url"), "/")
baseURL := strings.TrimSpace(account.GetCredential("base_url"))
if baseURL == "" {
baseURL = geminicli.AIStudioBaseURL
}
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, "", err
}
fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", baseURL, mappedModel, action)
fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(normalizedBaseURL, "/"), mappedModel, action)
if useUpstreamStream {
fullURL += "?alt=sse"
}
@@ -479,6 +517,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
}
var resp *http.Response
signatureRetryStage := 0
for attempt := 1; attempt <= geminiMaxRetries; attempt++ {
upstreamReq, idHeader, err := buildReq(ctx)
if err != nil {
@@ -503,6 +542,46 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries: "+sanitizeUpstreamErrorMessage(err.Error()))
}
// Special-case: signature/thought_signature validation errors are not transient, but may be fixed by
// downgrading Claude thinking/tool history to plain text (conservative two-stage retry).
if resp.StatusCode == http.StatusBadRequest && signatureRetryStage < 2 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
if isGeminiSignatureRelatedError(respBody) {
var strippedClaudeBody []byte
stageName := ""
switch signatureRetryStage {
case 0:
// Stage 1: disable thinking + thinking->text
strippedClaudeBody = FilterThinkingBlocksForRetry(originalClaudeBody)
stageName = "thinking-only"
signatureRetryStage = 1
default:
// Stage 2: additionally downgrade tool_use/tool_result blocks to text
strippedClaudeBody = FilterSignatureSensitiveBlocksForRetry(originalClaudeBody)
stageName = "thinking+tools"
signatureRetryStage = 2
}
retryGeminiReq, txErr := convertClaudeMessagesToGeminiGenerateContent(strippedClaudeBody)
if txErr == nil {
log.Printf("Gemini account %d: detected signature-related 400, retrying with downgraded Claude blocks (%s)", account.ID, stageName)
geminiReq = retryGeminiReq
// Consume one retry budget attempt and continue with the updated request payload.
sleepGeminiBackoff(1)
continue
}
}
// Restore body for downstream error handling.
resp = &http.Response{
StatusCode: http.StatusBadRequest,
Header: resp.Header.Clone(),
Body: io.NopCloser(bytes.NewReader(respBody)),
}
break
}
if resp.StatusCode >= 400 && s.shouldRetryGeminiUpstreamError(account, resp.StatusCode) {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
@@ -600,6 +679,14 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
}, nil
}
func isGeminiSignatureRelatedError(respBody []byte) bool {
msg := strings.ToLower(strings.TrimSpace(extractAntigravityErrorMessage(respBody)))
if msg == "" {
msg = strings.ToLower(string(respBody))
}
return strings.Contains(msg, "thought_signature") || strings.Contains(msg, "signature")
}
func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.Context, account *Account, originalModel string, action string, stream bool, body []byte) (*ForwardResult, error) {
startTime := time.Now()
@@ -650,12 +737,16 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
return nil, "", errors.New("gemini api_key not configured")
}
baseURL := strings.TrimRight(account.GetCredential("base_url"), "/")
baseURL := strings.TrimSpace(account.GetCredential("base_url"))
if baseURL == "" {
baseURL = geminicli.AIStudioBaseURL
}
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, "", err
}
fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(baseURL, "/"), mappedModel, upstreamAction)
fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(normalizedBaseURL, "/"), mappedModel, upstreamAction)
if useUpstreamStream {
fullURL += "?alt=sse"
}
@@ -687,7 +778,11 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
// 2. Without project_id -> AI Studio API (direct OAuth, like API key but with Bearer token)
if projectID != "" && !forceAIStudio {
// Mode 1: Code Assist API
fullURL := fmt.Sprintf("%s/v1internal:%s", geminicli.GeminiCliBaseURL, upstreamAction)
baseURL, err := s.validateUpstreamBaseURL(geminicli.GeminiCliBaseURL)
if err != nil {
return nil, "", err
}
fullURL := fmt.Sprintf("%s/v1internal:%s", strings.TrimRight(baseURL, "/"), upstreamAction)
if useUpstreamStream {
fullURL += "?alt=sse"
}
@@ -713,12 +808,16 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
return upstreamReq, "x-request-id", nil
} else {
// Mode 2: AI Studio API with OAuth (like API key mode, but using Bearer token)
baseURL := strings.TrimRight(account.GetCredential("base_url"), "/")
baseURL := strings.TrimSpace(account.GetCredential("base_url"))
if baseURL == "" {
baseURL = geminicli.AIStudioBaseURL
}
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, "", err
}
fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", baseURL, mappedModel, upstreamAction)
fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(normalizedBaseURL, "/"), mappedModel, upstreamAction)
if useUpstreamStream {
fullURL += "?alt=sse"
}
@@ -1652,6 +1751,8 @@ func (s *GeminiMessagesCompatService) handleNativeNonStreamingResponse(c *gin.Co
_ = json.Unmarshal(respBody, &parsed)
}
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
contentType := resp.Header.Get("Content-Type")
if contentType == "" {
contentType = "application/json"
@@ -1676,6 +1777,10 @@ func (s *GeminiMessagesCompatService) handleNativeStreamingResponse(c *gin.Conte
}
log.Printf("[GeminiAPI] ====================================================")
if s.cfg != nil {
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
}
c.Status(resp.StatusCode)
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
@@ -1773,11 +1878,15 @@ func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, ac
return nil, errors.New("invalid path")
}
baseURL := strings.TrimRight(account.GetCredential("base_url"), "/")
baseURL := strings.TrimSpace(account.GetCredential("base_url"))
if baseURL == "" {
baseURL = geminicli.AIStudioBaseURL
}
fullURL := strings.TrimRight(baseURL, "/") + path
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, err
}
fullURL := strings.TrimRight(normalizedBaseURL, "/") + path
var proxyURL string
if account.ProxyID != nil && account.Proxy != nil {
@@ -1816,9 +1925,14 @@ func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, ac
defer func() { _ = resp.Body.Close() }()
body, _ := io.ReadAll(io.LimitReader(resp.Body, 8<<20))
wwwAuthenticate := resp.Header.Get("Www-Authenticate")
filteredHeaders := responseheaders.FilterHeaders(resp.Header, s.cfg.Security.ResponseHeaders)
if wwwAuthenticate != "" {
filteredHeaders.Set("Www-Authenticate", wwwAuthenticate)
}
return &UpstreamHTTPResult{
StatusCode: resp.StatusCode,
Headers: resp.Header.Clone(),
Headers: filteredHeaders,
Body: body,
}, nil
}

View File

@@ -1000,8 +1000,9 @@ func fetchProjectIDFromResourceManager(ctx context.Context, accessToken, proxyUR
req.Header.Set("User-Agent", geminicli.GeminiCLIUserAgent)
client, err := httpclient.GetClient(httpclient.Options{
ProxyURL: strings.TrimSpace(proxyURL),
Timeout: 30 * time.Second,
ProxyURL: strings.TrimSpace(proxyURL),
Timeout: 30 * time.Second,
ValidateResolvedIP: true,
})
if err != nil {
client = &http.Client{Timeout: 30 * time.Second}

View File

@@ -16,9 +16,12 @@ import (
"sort"
"strconv"
"strings"
"sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/gin-gonic/gin"
)
@@ -630,10 +633,14 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
case AccountTypeAPIKey:
// API Key accounts use Platform API or custom base URL
baseURL := account.GetOpenAIBaseURL()
if baseURL != "" {
targetURL = baseURL + "/responses"
} else {
if baseURL == "" {
targetURL = openaiPlatformAPIURL
} else {
validatedURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, err
}
targetURL = validatedURL + "/responses"
}
default:
targetURL = openaiPlatformAPIURL
@@ -755,6 +762,10 @@ type openaiStreamingResult struct {
}
func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string) (*openaiStreamingResult, error) {
if s.cfg != nil {
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
}
// Set SSE response headers
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
@@ -775,48 +786,158 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
usage := &OpenAIUsage{}
var firstTokenMs *int
scanner := bufio.NewScanner(resp.Body)
scanner.Buffer(make([]byte, 64*1024), 1024*1024)
maxLineSize := defaultMaxLineSize
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
type scanEvent struct {
line string
err error
}
// 独立 goroutine 读取上游,避免读取阻塞影响 keepalive/超时处理
events := make(chan scanEvent, 16)
done := make(chan struct{})
sendEvent := func(ev scanEvent) bool {
select {
case events <- ev:
return true
case <-done:
return false
}
}
var lastReadAt int64
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
go func() {
defer close(events)
for scanner.Scan() {
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
if !sendEvent(scanEvent{line: scanner.Text()}) {
return
}
}
if err := scanner.Err(); err != nil {
_ = sendEvent(scanEvent{err: err})
}
}()
defer close(done)
streamInterval := time.Duration(0)
if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 {
streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
}
// 仅监控上游数据间隔超时,不被下游写入阻塞影响
var intervalTicker *time.Ticker
if streamInterval > 0 {
intervalTicker = time.NewTicker(streamInterval)
defer intervalTicker.Stop()
}
var intervalCh <-chan time.Time
if intervalTicker != nil {
intervalCh = intervalTicker.C
}
keepaliveInterval := time.Duration(0)
if s.cfg != nil && s.cfg.Gateway.StreamKeepaliveInterval > 0 {
keepaliveInterval = time.Duration(s.cfg.Gateway.StreamKeepaliveInterval) * time.Second
}
// 下游 keepalive 仅用于防止代理空闲断开
var keepaliveTicker *time.Ticker
if keepaliveInterval > 0 {
keepaliveTicker = time.NewTicker(keepaliveInterval)
defer keepaliveTicker.Stop()
}
var keepaliveCh <-chan time.Time
if keepaliveTicker != nil {
keepaliveCh = keepaliveTicker.C
}
// 记录上次收到上游数据的时间,用于控制 keepalive 发送频率
lastDataAt := time.Now()
// 仅发送一次错误事件,避免多次写入导致协议混乱(写失败时尽力通知客户端)
errorEventSent := false
sendErrorEvent := func(reason string) {
if errorEventSent {
return
}
errorEventSent = true
_, _ = fmt.Fprintf(w, "event: error\ndata: {\"error\":\"%s\"}\n\n", reason)
flusher.Flush()
}
needModelReplace := originalModel != mappedModel
for scanner.Scan() {
line := scanner.Text()
// Extract data from SSE line (supports both "data: " and "data:" formats)
if openaiSSEDataRe.MatchString(line) {
data := openaiSSEDataRe.ReplaceAllString(line, "")
// Replace model in response if needed
if needModelReplace {
line = s.replaceModelInSSELine(line, mappedModel, originalModel)
for {
select {
case ev, ok := <-events:
if !ok {
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
}
if ev.err != nil {
if errors.Is(ev.err, bufio.ErrTooLong) {
log.Printf("SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err)
sendErrorEvent("response_too_large")
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err
}
sendErrorEvent("stream_read_error")
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", ev.err)
}
// Forward line
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
}
flusher.Flush()
line := ev.line
lastDataAt = time.Now()
// Record first token time
if firstTokenMs == nil && data != "" && data != "[DONE]" {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
// Extract data from SSE line (supports both "data: " and "data:" formats)
if openaiSSEDataRe.MatchString(line) {
data := openaiSSEDataRe.ReplaceAllString(line, "")
// Replace model in response if needed
if needModelReplace {
line = s.replaceModelInSSELine(line, mappedModel, originalModel)
}
// Forward line
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
sendErrorEvent("write_failed")
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
}
flusher.Flush()
// Record first token time
if firstTokenMs == nil && data != "" && data != "[DONE]" {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
s.parseSSEUsage(data, usage)
} else {
// Forward non-data lines as-is
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
sendErrorEvent("write_failed")
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
}
flusher.Flush()
}
s.parseSSEUsage(data, usage)
} else {
// Forward non-data lines as-is
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
case <-intervalCh:
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
if time.Since(lastRead) < streamInterval {
continue
}
log.Printf("Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval)
sendErrorEvent("stream_timeout")
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
case <-keepaliveCh:
if time.Since(lastDataAt) < keepaliveInterval {
continue
}
if _, err := fmt.Fprint(w, ":\n\n"); err != nil {
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
}
flusher.Flush()
}
}
if err := scanner.Err(); err != nil {
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", err)
}
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
}
func (s *OpenAIGatewayService) replaceModelInSSELine(line, fromModel, toModel string) string {
@@ -911,18 +1032,39 @@ func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, r
body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
}
// Pass through headers
for key, values := range resp.Header {
for _, value := range values {
c.Header(key, value)
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
contentType := "application/json"
if s.cfg != nil && !s.cfg.Security.ResponseHeaders.Enabled {
if upstreamType := resp.Header.Get("Content-Type"); upstreamType != "" {
contentType = upstreamType
}
}
c.Data(resp.StatusCode, "application/json", body)
c.Data(resp.StatusCode, contentType, body)
return usage, nil
}
func (s *OpenAIGatewayService) validateUpstreamBaseURL(raw string) (string, error) {
if s.cfg != nil && !s.cfg.Security.URLAllowlist.Enabled {
normalized, err := urlvalidator.ValidateURLFormat(raw, s.cfg.Security.URLAllowlist.AllowInsecureHTTP)
if err != nil {
return "", fmt.Errorf("invalid base_url: %w", err)
}
return normalized, nil
}
normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{
AllowedHosts: s.cfg.Security.URLAllowlist.UpstreamHosts,
RequireAllowlist: true,
AllowPrivate: s.cfg.Security.URLAllowlist.AllowPrivateHosts,
})
if err != nil {
return "", fmt.Errorf("invalid base_url: %w", err)
}
return normalized, nil
}
func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel, toModel string) []byte {
var resp map[string]any
if err := json.Unmarshal(body, &resp); err != nil {

View File

@@ -0,0 +1,286 @@
package service
import (
"bufio"
"bytes"
"errors"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
)
func TestOpenAIStreamingTimeout(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
Gateway: config.GatewayConfig{
StreamDataIntervalTimeout: 1,
StreamKeepaliveInterval: 0,
MaxLineSize: defaultMaxLineSize,
},
}
svc := &OpenAIGatewayService{cfg: cfg}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
pr, pw := io.Pipe()
resp := &http.Response{
StatusCode: http.StatusOK,
Body: pr,
Header: http.Header{},
}
start := time.Now()
_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, start, "model", "model")
_ = pw.Close()
_ = pr.Close()
if err == nil || !strings.Contains(err.Error(), "stream data interval timeout") {
t.Fatalf("expected stream timeout error, got %v", err)
}
if !strings.Contains(rec.Body.String(), "stream_timeout") {
t.Fatalf("expected stream_timeout SSE error, got %q", rec.Body.String())
}
}
func TestOpenAIStreamingTooLong(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
Gateway: config.GatewayConfig{
StreamDataIntervalTimeout: 0,
StreamKeepaliveInterval: 0,
MaxLineSize: 64 * 1024,
},
}
svc := &OpenAIGatewayService{cfg: cfg}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
pr, pw := io.Pipe()
resp := &http.Response{
StatusCode: http.StatusOK,
Body: pr,
Header: http.Header{},
}
go func() {
defer func() { _ = pw.Close() }()
// 写入超过 MaxLineSize 的单行数据,触发 ErrTooLong
payload := "data: " + strings.Repeat("a", 128*1024) + "\n"
_, _ = pw.Write([]byte(payload))
}()
_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 2}, time.Now(), "model", "model")
_ = pr.Close()
if !errors.Is(err, bufio.ErrTooLong) {
t.Fatalf("expected ErrTooLong, got %v", err)
}
if !strings.Contains(rec.Body.String(), "response_too_large") {
t.Fatalf("expected response_too_large SSE error, got %q", rec.Body.String())
}
}
func TestOpenAINonStreamingContentTypePassThrough(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
Security: config.SecurityConfig{
ResponseHeaders: config.ResponseHeaderConfig{Enabled: false},
},
}
svc := &OpenAIGatewayService{cfg: cfg}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
body := []byte(`{"usage":{"input_tokens":1,"output_tokens":2,"input_tokens_details":{"cached_tokens":0}}}`)
resp := &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(bytes.NewReader(body)),
Header: http.Header{"Content-Type": []string{"application/vnd.test+json"}},
}
_, err := svc.handleNonStreamingResponse(c.Request.Context(), resp, c, &Account{}, "model", "model")
if err != nil {
t.Fatalf("handleNonStreamingResponse error: %v", err)
}
if !strings.Contains(rec.Header().Get("Content-Type"), "application/vnd.test+json") {
t.Fatalf("expected Content-Type passthrough, got %q", rec.Header().Get("Content-Type"))
}
}
func TestOpenAINonStreamingContentTypeDefault(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
Security: config.SecurityConfig{
ResponseHeaders: config.ResponseHeaderConfig{Enabled: false},
},
}
svc := &OpenAIGatewayService{cfg: cfg}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
body := []byte(`{"usage":{"input_tokens":1,"output_tokens":2,"input_tokens_details":{"cached_tokens":0}}}`)
resp := &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(bytes.NewReader(body)),
Header: http.Header{},
}
_, err := svc.handleNonStreamingResponse(c.Request.Context(), resp, c, &Account{}, "model", "model")
if err != nil {
t.Fatalf("handleNonStreamingResponse error: %v", err)
}
if !strings.Contains(rec.Header().Get("Content-Type"), "application/json") {
t.Fatalf("expected default Content-Type, got %q", rec.Header().Get("Content-Type"))
}
}
func TestOpenAIStreamingHeadersOverride(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
Security: config.SecurityConfig{
ResponseHeaders: config.ResponseHeaderConfig{Enabled: false},
},
Gateway: config.GatewayConfig{
StreamDataIntervalTimeout: 0,
StreamKeepaliveInterval: 0,
MaxLineSize: defaultMaxLineSize,
},
}
svc := &OpenAIGatewayService{cfg: cfg}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
pr, pw := io.Pipe()
resp := &http.Response{
StatusCode: http.StatusOK,
Body: pr,
Header: http.Header{
"Cache-Control": []string{"upstream"},
"X-Request-Id": []string{"req-123"},
"Content-Type": []string{"application/custom"},
},
}
go func() {
defer func() { _ = pw.Close() }()
_, _ = pw.Write([]byte("data: {}\n\n"))
}()
_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model")
_ = pr.Close()
if err != nil {
t.Fatalf("handleStreamingResponse error: %v", err)
}
if rec.Header().Get("Cache-Control") != "no-cache" {
t.Fatalf("expected Cache-Control override, got %q", rec.Header().Get("Cache-Control"))
}
if rec.Header().Get("Content-Type") != "text/event-stream" {
t.Fatalf("expected Content-Type override, got %q", rec.Header().Get("Content-Type"))
}
if rec.Header().Get("X-Request-Id") != "req-123" {
t.Fatalf("expected X-Request-Id passthrough, got %q", rec.Header().Get("X-Request-Id"))
}
}
func TestOpenAIInvalidBaseURLWhenAllowlistDisabled(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
Security: config.SecurityConfig{
URLAllowlist: config.URLAllowlistConfig{Enabled: false},
},
}
svc := &OpenAIGatewayService{cfg: cfg}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
account := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Credentials: map[string]any{"base_url": "://invalid-url"},
}
_, err := svc.buildUpstreamRequest(c.Request.Context(), c, account, []byte("{}"), "token", false)
if err == nil {
t.Fatalf("expected error for invalid base_url when allowlist disabled")
}
}
func TestOpenAIValidateUpstreamBaseURLDisabledRequiresHTTPS(t *testing.T) {
cfg := &config.Config{
Security: config.SecurityConfig{
URLAllowlist: config.URLAllowlistConfig{Enabled: false},
},
}
svc := &OpenAIGatewayService{cfg: cfg}
if _, err := svc.validateUpstreamBaseURL("http://not-https.example.com"); err == nil {
t.Fatalf("expected http to be rejected when allow_insecure_http is false")
}
normalized, err := svc.validateUpstreamBaseURL("https://example.com")
if err != nil {
t.Fatalf("expected https to be allowed when allowlist disabled, got %v", err)
}
if normalized != "https://example.com" {
t.Fatalf("expected raw url passthrough, got %q", normalized)
}
}
func TestOpenAIValidateUpstreamBaseURLDisabledAllowsHTTP(t *testing.T) {
cfg := &config.Config{
Security: config.SecurityConfig{
URLAllowlist: config.URLAllowlistConfig{
Enabled: false,
AllowInsecureHTTP: true,
},
},
}
svc := &OpenAIGatewayService{cfg: cfg}
normalized, err := svc.validateUpstreamBaseURL("http://not-https.example.com")
if err != nil {
t.Fatalf("expected http allowed when allow_insecure_http is true, got %v", err)
}
if normalized != "http://not-https.example.com" {
t.Fatalf("expected raw url passthrough, got %q", normalized)
}
}
func TestOpenAIValidateUpstreamBaseURLEnabledEnforcesAllowlist(t *testing.T) {
cfg := &config.Config{
Security: config.SecurityConfig{
URLAllowlist: config.URLAllowlistConfig{
Enabled: true,
UpstreamHosts: []string{"example.com"},
},
},
}
svc := &OpenAIGatewayService{cfg: cfg}
if _, err := svc.validateUpstreamBaseURL("https://example.com"); err != nil {
t.Fatalf("expected allowlisted host to pass, got %v", err)
}
if _, err := svc.validateUpstreamBaseURL("https://evil.com"); err == nil {
t.Fatalf("expected non-allowlisted host to fail")
}
}

View File

@@ -16,6 +16,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
)
var (
@@ -213,16 +214,35 @@ func (s *PricingService) syncWithRemote() error {
// downloadPricingData 从远程下载价格数据
func (s *PricingService) downloadPricingData() error {
log.Printf("[Pricing] Downloading from %s", s.cfg.Pricing.RemoteURL)
remoteURL, err := s.validatePricingURL(s.cfg.Pricing.RemoteURL)
if err != nil {
return err
}
log.Printf("[Pricing] Downloading from %s", remoteURL)
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
body, err := s.remoteClient.FetchPricingJSON(ctx, s.cfg.Pricing.RemoteURL)
var expectedHash string
if strings.TrimSpace(s.cfg.Pricing.HashURL) != "" {
expectedHash, err = s.fetchRemoteHash()
if err != nil {
return fmt.Errorf("fetch remote hash: %w", err)
}
}
body, err := s.remoteClient.FetchPricingJSON(ctx, remoteURL)
if err != nil {
return fmt.Errorf("download failed: %w", err)
}
if expectedHash != "" {
actualHash := sha256.Sum256(body)
if !strings.EqualFold(expectedHash, hex.EncodeToString(actualHash[:])) {
return fmt.Errorf("pricing hash mismatch")
}
}
// 解析JSON数据使用灵活的解析方式
data, err := s.parsePricingData(body)
if err != nil {
@@ -378,10 +398,38 @@ func (s *PricingService) useFallbackPricing() error {
// fetchRemoteHash 从远程获取哈希值
func (s *PricingService) fetchRemoteHash() (string, error) {
hashURL, err := s.validatePricingURL(s.cfg.Pricing.HashURL)
if err != nil {
return "", err
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
return s.remoteClient.FetchHashText(ctx, s.cfg.Pricing.HashURL)
hash, err := s.remoteClient.FetchHashText(ctx, hashURL)
if err != nil {
return "", err
}
return strings.TrimSpace(hash), nil
}
func (s *PricingService) validatePricingURL(raw string) (string, error) {
if s.cfg != nil && !s.cfg.Security.URLAllowlist.Enabled {
normalized, err := urlvalidator.ValidateURLFormat(raw, s.cfg.Security.URLAllowlist.AllowInsecureHTTP)
if err != nil {
return "", fmt.Errorf("invalid pricing url: %w", err)
}
return normalized, nil
}
normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{
AllowedHosts: s.cfg.Security.URLAllowlist.PricingHosts,
RequireAllowlist: true,
AllowPrivate: s.cfg.Security.URLAllowlist.AllowPrivateHosts,
})
if err != nil {
return "", fmt.Errorf("invalid pricing url: %w", err)
}
return normalized, nil
}
// computeFileHash 计算文件哈希

View File

@@ -130,6 +130,10 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates[SettingKeyFallbackModelGemini] = settings.FallbackModelGemini
updates[SettingKeyFallbackModelAntigravity] = settings.FallbackModelAntigravity
// Identity patch configuration (Claude -> Gemini)
updates[SettingKeyEnableIdentityPatch] = strconv.FormatBool(settings.EnableIdentityPatch)
updates[SettingKeyIdentityPatchPrompt] = settings.IdentityPatchPrompt
return s.settingRepo.SetMultiple(ctx, updates)
}
@@ -213,6 +217,9 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
SettingKeyFallbackModelOpenAI: "gpt-4o",
SettingKeyFallbackModelGemini: "gemini-2.5-pro",
SettingKeyFallbackModelAntigravity: "gemini-2.5-pro",
// Identity patch defaults
SettingKeyEnableIdentityPatch: "true",
SettingKeyIdentityPatchPrompt: "",
}
return s.settingRepo.SetMultiple(ctx, defaults)
@@ -221,21 +228,23 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
// parseSettings 解析设置到结构体
func (s *SettingService) parseSettings(settings map[string]string) *SystemSettings {
result := &SystemSettings{
RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true",
EmailVerifyEnabled: settings[SettingKeyEmailVerifyEnabled] == "true",
SMTPHost: settings[SettingKeySMTPHost],
SMTPUsername: settings[SettingKeySMTPUsername],
SMTPFrom: settings[SettingKeySMTPFrom],
SMTPFromName: settings[SettingKeySMTPFromName],
SMTPUseTLS: settings[SettingKeySMTPUseTLS] == "true",
TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true",
TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey],
SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"),
SiteLogo: settings[SettingKeySiteLogo],
SiteSubtitle: s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"),
APIBaseURL: settings[SettingKeyAPIBaseURL],
ContactInfo: settings[SettingKeyContactInfo],
DocURL: settings[SettingKeyDocURL],
RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true",
EmailVerifyEnabled: settings[SettingKeyEmailVerifyEnabled] == "true",
SMTPHost: settings[SettingKeySMTPHost],
SMTPUsername: settings[SettingKeySMTPUsername],
SMTPFrom: settings[SettingKeySMTPFrom],
SMTPFromName: settings[SettingKeySMTPFromName],
SMTPUseTLS: settings[SettingKeySMTPUseTLS] == "true",
SMTPPasswordConfigured: settings[SettingKeySMTPPassword] != "",
TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true",
TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey],
TurnstileSecretKeyConfigured: settings[SettingKeyTurnstileSecretKey] != "",
SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"),
SiteLogo: settings[SettingKeySiteLogo],
SiteSubtitle: s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"),
APIBaseURL: settings[SettingKeyAPIBaseURL],
ContactInfo: settings[SettingKeyContactInfo],
DocURL: settings[SettingKeyDocURL],
}
// 解析整数类型
@@ -269,6 +278,14 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
result.FallbackModelGemini = s.getStringOrDefault(settings, SettingKeyFallbackModelGemini, "gemini-2.5-pro")
result.FallbackModelAntigravity = s.getStringOrDefault(settings, SettingKeyFallbackModelAntigravity, "gemini-2.5-pro")
// Identity patch settings (default: enabled, to preserve existing behavior)
if v, ok := settings[SettingKeyEnableIdentityPatch]; ok && v != "" {
result.EnableIdentityPatch = v == "true"
} else {
result.EnableIdentityPatch = true
}
result.IdentityPatchPrompt = settings[SettingKeyIdentityPatchPrompt]
return result
}
@@ -298,6 +315,25 @@ func (s *SettingService) GetTurnstileSecretKey(ctx context.Context) string {
return value
}
// IsIdentityPatchEnabled 检查是否启用身份补丁Claude -> Gemini systemInstruction 注入)
func (s *SettingService) IsIdentityPatchEnabled(ctx context.Context) bool {
value, err := s.settingRepo.GetValue(ctx, SettingKeyEnableIdentityPatch)
if err != nil {
// 默认开启,保持兼容
return true
}
return value == "true"
}
// GetIdentityPatchPrompt 获取自定义身份补丁提示词(为空表示使用内置默认模板)
func (s *SettingService) GetIdentityPatchPrompt(ctx context.Context) string {
value, err := s.settingRepo.GetValue(ctx, SettingKeyIdentityPatchPrompt)
if err != nil {
return ""
}
return value
}
// GenerateAdminAPIKey 生成新的管理员 API Key
func (s *SettingService) GenerateAdminAPIKey(ctx context.Context) (string, error) {
// 生成 32 字节随机数 = 64 位十六进制字符

View File

@@ -4,17 +4,19 @@ type SystemSettings struct {
RegistrationEnabled bool
EmailVerifyEnabled bool
SMTPHost string
SMTPPort int
SMTPUsername string
SMTPPassword string
SMTPFrom string
SMTPFromName string
SMTPUseTLS bool
SMTPHost string
SMTPPort int
SMTPUsername string
SMTPPassword string
SMTPPasswordConfigured bool
SMTPFrom string
SMTPFromName string
SMTPUseTLS bool
TurnstileEnabled bool
TurnstileSiteKey string
TurnstileSecretKey string
TurnstileEnabled bool
TurnstileSiteKey string
TurnstileSecretKey string
TurnstileSecretKeyConfigured bool
SiteName string
SiteLogo string
@@ -32,6 +34,10 @@ type SystemSettings struct {
FallbackModelOpenAI string `json:"fallback_model_openai"`
FallbackModelGemini string `json:"fallback_model_gemini"`
FallbackModelAntigravity string `json:"fallback_model_antigravity"`
// Identity patch configuration (Claude -> Gemini)
EnableIdentityPatch bool `json:"enable_identity_patch"`
IdentityPatchPrompt string `json:"identity_patch_prompt"`
}
type PublicSettings struct {