merge: 合并 upstream/main 并保留本地图片计费功能
This commit is contained in:
@@ -11,6 +11,7 @@ import (
|
||||
type Account struct {
|
||||
ID int64
|
||||
Name string
|
||||
Notes *string
|
||||
Platform string
|
||||
Type string
|
||||
Credentials map[string]any
|
||||
@@ -262,6 +263,17 @@ func parseTempUnschedStrings(value any) []string {
|
||||
return out
|
||||
}
|
||||
|
||||
// normalizeAccountNotes canonicalizes an optional notes value.
//
// A nil pointer, or a string that is empty once surrounding whitespace has
// been trimmed, yields nil. Otherwise it returns a pointer to a freshly
// trimmed copy; the caller's string is never modified.
func normalizeAccountNotes(value *string) *string {
	var cleaned string
	if value != nil {
		cleaned = strings.TrimSpace(*value)
	}
	if len(cleaned) == 0 {
		return nil
	}
	return &cleaned
}
|
||||
|
||||
func parseTempUnschedInt(value any) int {
|
||||
switch v := value.(type) {
|
||||
case int:
|
||||
|
||||
@@ -72,6 +72,7 @@ type AccountBulkUpdate struct {
|
||||
// CreateAccountRequest 创建账号请求
|
||||
type CreateAccountRequest struct {
|
||||
Name string `json:"name"`
|
||||
Notes *string `json:"notes"`
|
||||
Platform string `json:"platform"`
|
||||
Type string `json:"type"`
|
||||
Credentials map[string]any `json:"credentials"`
|
||||
@@ -85,6 +86,7 @@ type CreateAccountRequest struct {
|
||||
// UpdateAccountRequest 更新账号请求
|
||||
type UpdateAccountRequest struct {
|
||||
Name *string `json:"name"`
|
||||
Notes *string `json:"notes"`
|
||||
Credentials *map[string]any `json:"credentials"`
|
||||
Extra *map[string]any `json:"extra"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
@@ -123,6 +125,7 @@ func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) (
|
||||
// 创建账号
|
||||
account := &Account{
|
||||
Name: req.Name,
|
||||
Notes: normalizeAccountNotes(req.Notes),
|
||||
Platform: req.Platform,
|
||||
Type: req.Type,
|
||||
Credentials: req.Credentials,
|
||||
@@ -194,6 +197,9 @@ func (s *AccountService) Update(ctx context.Context, id int64, req UpdateAccount
|
||||
if req.Name != nil {
|
||||
account.Name = *req.Name
|
||||
}
|
||||
if req.Notes != nil {
|
||||
account.Notes = normalizeAccountNotes(req.Notes)
|
||||
}
|
||||
|
||||
if req.Credentials != nil {
|
||||
account.Credentials = *req.Credentials
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
@@ -14,9 +15,11 @@ import (
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
@@ -45,6 +48,7 @@ type AccountTestService struct {
|
||||
geminiTokenProvider *GeminiTokenProvider
|
||||
antigravityGatewayService *AntigravityGatewayService
|
||||
httpUpstream HTTPUpstream
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
// NewAccountTestService creates a new AccountTestService
|
||||
@@ -53,15 +57,35 @@ func NewAccountTestService(
|
||||
geminiTokenProvider *GeminiTokenProvider,
|
||||
antigravityGatewayService *AntigravityGatewayService,
|
||||
httpUpstream HTTPUpstream,
|
||||
cfg *config.Config,
|
||||
) *AccountTestService {
|
||||
return &AccountTestService{
|
||||
accountRepo: accountRepo,
|
||||
geminiTokenProvider: geminiTokenProvider,
|
||||
antigravityGatewayService: antigravityGatewayService,
|
||||
httpUpstream: httpUpstream,
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AccountTestService) validateUpstreamBaseURL(raw string) (string, error) {
|
||||
if s.cfg == nil {
|
||||
return "", errors.New("config is not available")
|
||||
}
|
||||
if !s.cfg.Security.URLAllowlist.Enabled {
|
||||
return urlvalidator.ValidateURLFormat(raw, s.cfg.Security.URLAllowlist.AllowInsecureHTTP)
|
||||
}
|
||||
normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{
|
||||
AllowedHosts: s.cfg.Security.URLAllowlist.UpstreamHosts,
|
||||
RequireAllowlist: true,
|
||||
AllowPrivate: s.cfg.Security.URLAllowlist.AllowPrivateHosts,
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return normalized, nil
|
||||
}
|
||||
|
||||
// generateSessionString generates a Claude Code style session string
|
||||
func generateSessionString() (string, error) {
|
||||
bytes := make([]byte, 32)
|
||||
@@ -183,11 +207,15 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
|
||||
return s.sendErrorAndEnd(c, "No API key available")
|
||||
}
|
||||
|
||||
apiURL = account.GetBaseURL()
|
||||
if apiURL == "" {
|
||||
apiURL = "https://api.anthropic.com"
|
||||
baseURL := account.GetBaseURL()
|
||||
if baseURL == "" {
|
||||
baseURL = "https://api.anthropic.com"
|
||||
}
|
||||
apiURL = strings.TrimSuffix(apiURL, "/") + "/v1/messages"
|
||||
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Invalid base URL: %s", err.Error()))
|
||||
}
|
||||
apiURL = strings.TrimSuffix(normalizedBaseURL, "/") + "/v1/messages"
|
||||
} else {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type))
|
||||
}
|
||||
@@ -300,7 +328,11 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
|
||||
if baseURL == "" {
|
||||
baseURL = "https://api.openai.com"
|
||||
}
|
||||
apiURL = strings.TrimSuffix(baseURL, "/") + "/responses"
|
||||
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Invalid base URL: %s", err.Error()))
|
||||
}
|
||||
apiURL = strings.TrimSuffix(normalizedBaseURL, "/") + "/responses"
|
||||
} else {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type))
|
||||
}
|
||||
@@ -480,10 +512,14 @@ func (s *AccountTestService) buildGeminiAPIKeyRequest(ctx context.Context, accou
|
||||
if baseURL == "" {
|
||||
baseURL = geminicli.AIStudioBaseURL
|
||||
}
|
||||
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Use streamGenerateContent for real-time feedback
|
||||
fullURL := fmt.Sprintf("%s/v1beta/models/%s:streamGenerateContent?alt=sse",
|
||||
strings.TrimRight(baseURL, "/"), modelID)
|
||||
strings.TrimRight(normalizedBaseURL, "/"), modelID)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", fullURL, bytes.NewReader(payload))
|
||||
if err != nil {
|
||||
@@ -515,7 +551,11 @@ func (s *AccountTestService) buildGeminiOAuthRequest(ctx context.Context, accoun
|
||||
if strings.TrimSpace(baseURL) == "" {
|
||||
baseURL = geminicli.AIStudioBaseURL
|
||||
}
|
||||
fullURL := fmt.Sprintf("%s/v1beta/models/%s:streamGenerateContent?alt=sse", strings.TrimRight(baseURL, "/"), modelID)
|
||||
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fullURL := fmt.Sprintf("%s/v1beta/models/%s:streamGenerateContent?alt=sse", strings.TrimRight(normalizedBaseURL, "/"), modelID)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(payload))
|
||||
if err != nil {
|
||||
@@ -544,7 +584,11 @@ func (s *AccountTestService) buildCodeAssistRequest(ctx context.Context, accessT
|
||||
}
|
||||
wrappedBytes, _ := json.Marshal(wrapped)
|
||||
|
||||
fullURL := fmt.Sprintf("%s/v1internal:streamGenerateContent?alt=sse", geminicli.GeminiCliBaseURL)
|
||||
normalizedBaseURL, err := s.validateUpstreamBaseURL(geminicli.GeminiCliBaseURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fullURL := fmt.Sprintf("%s/v1internal:streamGenerateContent?alt=sse", normalizedBaseURL)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", fullURL, bytes.NewReader(wrappedBytes))
|
||||
if err != nil {
|
||||
|
||||
@@ -123,6 +123,7 @@ type UpdateGroupInput struct {
|
||||
|
||||
type CreateAccountInput struct {
|
||||
Name string
|
||||
Notes *string
|
||||
Platform string
|
||||
Type string
|
||||
Credentials map[string]any
|
||||
@@ -138,6 +139,7 @@ type CreateAccountInput struct {
|
||||
|
||||
type UpdateAccountInput struct {
|
||||
Name string
|
||||
Notes *string
|
||||
Type string // Account type: oauth, setup-token, apikey
|
||||
Credentials map[string]any
|
||||
Extra map[string]any
|
||||
@@ -687,6 +689,7 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
|
||||
|
||||
account := &Account{
|
||||
Name: input.Name,
|
||||
Notes: normalizeAccountNotes(input.Notes),
|
||||
Platform: input.Platform,
|
||||
Type: input.Type,
|
||||
Credentials: input.Credentials,
|
||||
@@ -723,6 +726,9 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
|
||||
if input.Type != "" {
|
||||
account.Type = input.Type
|
||||
}
|
||||
if input.Notes != nil {
|
||||
account.Notes = normalizeAccountNotes(input.Notes)
|
||||
}
|
||||
if len(input.Credentials) > 0 {
|
||||
account.Credentials = input.Credentials
|
||||
}
|
||||
@@ -730,7 +736,12 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
|
||||
account.Extra = input.Extra
|
||||
}
|
||||
if input.ProxyID != nil {
|
||||
account.ProxyID = input.ProxyID
|
||||
// 0 表示清除代理(前端发送 0 而不是 null 来表达清除意图)
|
||||
if *input.ProxyID == 0 {
|
||||
account.ProxyID = nil
|
||||
} else {
|
||||
account.ProxyID = input.ProxyID
|
||||
}
|
||||
account.Proxy = nil // 清除关联对象,防止 GORM Save 时根据 Proxy.ID 覆盖 ProxyID
|
||||
}
|
||||
// 只在指针非 nil 时更新 Concurrency(支持设置为 0)
|
||||
|
||||
@@ -9,8 +9,10 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
mathrand "math/rand"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
@@ -255,6 +257,16 @@ func (s *AntigravityGatewayService) buildClaudeTestRequest(projectID, mappedMode
|
||||
return antigravity.TransformClaudeToGemini(claudeReq, projectID, mappedModel)
|
||||
}
|
||||
|
||||
func (s *AntigravityGatewayService) getClaudeTransformOptions(ctx context.Context) antigravity.TransformOptions {
|
||||
opts := antigravity.DefaultTransformOptions()
|
||||
if s.settingService == nil {
|
||||
return opts
|
||||
}
|
||||
opts.EnableIdentityPatch = s.settingService.IsIdentityPatchEnabled(ctx)
|
||||
opts.IdentityPatch = s.settingService.GetIdentityPatchPrompt(ctx)
|
||||
return opts
|
||||
}
|
||||
|
||||
// extractGeminiResponseText 从 Gemini 响应中提取文本
|
||||
func extractGeminiResponseText(respBody []byte) string {
|
||||
var resp map[string]any
|
||||
@@ -380,7 +392,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
}
|
||||
|
||||
// 转换 Claude 请求为 Gemini 格式
|
||||
geminiBody, err := antigravity.TransformClaudeToGemini(&claudeReq, projectID, mappedModel)
|
||||
geminiBody, err := antigravity.TransformClaudeToGeminiWithOptions(&claudeReq, projectID, mappedModel, s.getClaudeTransformOptions(ctx))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("transform request: %w", err)
|
||||
}
|
||||
@@ -394,6 +406,14 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
// 重试循环
|
||||
var resp *http.Response
|
||||
for attempt := 1; attempt <= antigravityMaxRetries; attempt++ {
|
||||
// 检查 context 是否已取消(客户端断开连接)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Printf("%s status=context_canceled error=%v", prefix, ctx.Err())
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
upstreamReq, err := antigravity.NewAPIRequest(ctx, action, accessToken, geminiBody)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -403,7 +423,10 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
if err != nil {
|
||||
if attempt < antigravityMaxRetries {
|
||||
log.Printf("%s status=request_failed retry=%d/%d error=%v", prefix, attempt, antigravityMaxRetries, err)
|
||||
sleepAntigravityBackoff(attempt)
|
||||
if !sleepAntigravityBackoffWithContext(ctx, attempt) {
|
||||
log.Printf("%s status=context_canceled_during_backoff", prefix)
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
continue
|
||||
}
|
||||
log.Printf("%s status=request_failed retries_exhausted error=%v", prefix, err)
|
||||
@@ -416,7 +439,10 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
|
||||
if attempt < antigravityMaxRetries {
|
||||
log.Printf("%s status=%d retry=%d/%d", prefix, resp.StatusCode, attempt, antigravityMaxRetries)
|
||||
sleepAntigravityBackoff(attempt)
|
||||
if !sleepAntigravityBackoffWithContext(ctx, attempt) {
|
||||
log.Printf("%s status=context_canceled_during_backoff", prefix)
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
continue
|
||||
}
|
||||
// 所有重试都失败,标记限流状态
|
||||
@@ -443,35 +469,70 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
// Antigravity /v1internal 链路在部分场景会对 thought/thinking signature 做严格校验,
|
||||
// 当历史消息携带的 signature 不合法时会直接 400;去除 thinking 后可继续完成请求。
|
||||
if resp.StatusCode == http.StatusBadRequest && isSignatureRelatedError(respBody) {
|
||||
retryClaudeReq := claudeReq
|
||||
retryClaudeReq.Messages = append([]antigravity.ClaudeMessage(nil), claudeReq.Messages...)
|
||||
// Conservative two-stage fallback:
|
||||
// 1) Disable top-level thinking + thinking->text
|
||||
// 2) Only if still signature-related 400: also downgrade tool_use/tool_result to text.
|
||||
|
||||
stripped, stripErr := stripThinkingFromClaudeRequest(&retryClaudeReq)
|
||||
if stripErr == nil && stripped {
|
||||
log.Printf("Antigravity account %d: detected signature-related 400, retrying once without thinking blocks", account.ID)
|
||||
retryStages := []struct {
|
||||
name string
|
||||
strip func(*antigravity.ClaudeRequest) (bool, error)
|
||||
}{
|
||||
{name: "thinking-only", strip: stripThinkingFromClaudeRequest},
|
||||
{name: "thinking+tools", strip: stripSignatureSensitiveBlocksFromClaudeRequest},
|
||||
}
|
||||
|
||||
retryGeminiBody, txErr := antigravity.TransformClaudeToGemini(&retryClaudeReq, projectID, mappedModel)
|
||||
if txErr == nil {
|
||||
retryReq, buildErr := antigravity.NewAPIRequest(ctx, action, accessToken, retryGeminiBody)
|
||||
if buildErr == nil {
|
||||
retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency)
|
||||
if retryErr == nil {
|
||||
// Retry success: continue normal success flow with the new response.
|
||||
if retryResp.StatusCode < 400 {
|
||||
_ = resp.Body.Close()
|
||||
resp = retryResp
|
||||
respBody = nil
|
||||
} else {
|
||||
// Retry still errored: replace error context with retry response.
|
||||
retryBody, _ := io.ReadAll(io.LimitReader(retryResp.Body, 2<<20))
|
||||
_ = retryResp.Body.Close()
|
||||
respBody = retryBody
|
||||
resp = retryResp
|
||||
}
|
||||
} else {
|
||||
log.Printf("Antigravity account %d: signature retry request failed: %v", account.ID, retryErr)
|
||||
}
|
||||
for _, stage := range retryStages {
|
||||
retryClaudeReq := claudeReq
|
||||
retryClaudeReq.Messages = append([]antigravity.ClaudeMessage(nil), claudeReq.Messages...)
|
||||
|
||||
stripped, stripErr := stage.strip(&retryClaudeReq)
|
||||
if stripErr != nil || !stripped {
|
||||
continue
|
||||
}
|
||||
|
||||
log.Printf("Antigravity account %d: detected signature-related 400, retrying once (%s)", account.ID, stage.name)
|
||||
|
||||
retryGeminiBody, txErr := antigravity.TransformClaudeToGeminiWithOptions(&retryClaudeReq, projectID, mappedModel, s.getClaudeTransformOptions(ctx))
|
||||
if txErr != nil {
|
||||
continue
|
||||
}
|
||||
retryReq, buildErr := antigravity.NewAPIRequest(ctx, action, accessToken, retryGeminiBody)
|
||||
if buildErr != nil {
|
||||
continue
|
||||
}
|
||||
retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency)
|
||||
if retryErr != nil {
|
||||
log.Printf("Antigravity account %d: signature retry request failed (%s): %v", account.ID, stage.name, retryErr)
|
||||
continue
|
||||
}
|
||||
|
||||
if retryResp.StatusCode < 400 {
|
||||
_ = resp.Body.Close()
|
||||
resp = retryResp
|
||||
respBody = nil
|
||||
break
|
||||
}
|
||||
|
||||
retryBody, _ := io.ReadAll(io.LimitReader(retryResp.Body, 2<<20))
|
||||
_ = retryResp.Body.Close()
|
||||
|
||||
// If this stage fixed the signature issue, we stop; otherwise we may try the next stage.
|
||||
if retryResp.StatusCode != http.StatusBadRequest || !isSignatureRelatedError(retryBody) {
|
||||
respBody = retryBody
|
||||
resp = &http.Response{
|
||||
StatusCode: retryResp.StatusCode,
|
||||
Header: retryResp.Header.Clone(),
|
||||
Body: io.NopCloser(bytes.NewReader(retryBody)),
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
// Still signature-related; capture context and allow next stage.
|
||||
respBody = retryBody
|
||||
resp = &http.Response{
|
||||
StatusCode: retryResp.StatusCode,
|
||||
Header: retryResp.Header.Clone(),
|
||||
Body: io.NopCloser(bytes.NewReader(retryBody)),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -528,7 +589,17 @@ func isSignatureRelatedError(respBody []byte) bool {
|
||||
}
|
||||
|
||||
// Keep this intentionally broad: different upstreams may use "signature" or "thought_signature".
|
||||
return strings.Contains(msg, "thought_signature") || strings.Contains(msg, "signature")
|
||||
if strings.Contains(msg, "thought_signature") || strings.Contains(msg, "signature") {
|
||||
return true
|
||||
}
|
||||
|
||||
// Also detect thinking block structural errors:
|
||||
// "Expected `thinking` or `redacted_thinking`, but found `text`"
|
||||
if strings.Contains(msg, "expected") && (strings.Contains(msg, "thinking") || strings.Contains(msg, "redacted_thinking")) {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func extractAntigravityErrorMessage(body []byte) string {
|
||||
@@ -555,7 +626,7 @@ func extractAntigravityErrorMessage(body []byte) string {
|
||||
// stripThinkingFromClaudeRequest converts thinking blocks to text blocks in a Claude Messages request.
|
||||
// This preserves the thinking content while avoiding signature validation errors.
|
||||
// Note: redacted_thinking blocks are removed because they cannot be converted to text.
|
||||
// It also disables top-level `thinking` to prevent dummy-thought injection during retry.
|
||||
// It also disables top-level `thinking` to avoid upstream structural constraints for thinking mode.
|
||||
func stripThinkingFromClaudeRequest(req *antigravity.ClaudeRequest) (bool, error) {
|
||||
if req == nil {
|
||||
return false, nil
|
||||
@@ -585,6 +656,92 @@ func stripThinkingFromClaudeRequest(req *antigravity.ClaudeRequest) (bool, error
|
||||
continue
|
||||
}
|
||||
|
||||
filtered := make([]map[string]any, 0, len(blocks))
|
||||
modifiedAny := false
|
||||
for _, block := range blocks {
|
||||
t, _ := block["type"].(string)
|
||||
switch t {
|
||||
case "thinking":
|
||||
thinkingText, _ := block["thinking"].(string)
|
||||
if thinkingText != "" {
|
||||
filtered = append(filtered, map[string]any{
|
||||
"type": "text",
|
||||
"text": thinkingText,
|
||||
})
|
||||
}
|
||||
modifiedAny = true
|
||||
case "redacted_thinking":
|
||||
modifiedAny = true
|
||||
case "":
|
||||
if thinkingText, hasThinking := block["thinking"].(string); hasThinking {
|
||||
if thinkingText != "" {
|
||||
filtered = append(filtered, map[string]any{
|
||||
"type": "text",
|
||||
"text": thinkingText,
|
||||
})
|
||||
}
|
||||
modifiedAny = true
|
||||
} else {
|
||||
filtered = append(filtered, block)
|
||||
}
|
||||
default:
|
||||
filtered = append(filtered, block)
|
||||
}
|
||||
}
|
||||
|
||||
if !modifiedAny {
|
||||
continue
|
||||
}
|
||||
|
||||
if len(filtered) == 0 {
|
||||
filtered = append(filtered, map[string]any{
|
||||
"type": "text",
|
||||
"text": "(content removed)",
|
||||
})
|
||||
}
|
||||
|
||||
newRaw, err := json.Marshal(filtered)
|
||||
if err != nil {
|
||||
return changed, err
|
||||
}
|
||||
req.Messages[i].Content = newRaw
|
||||
changed = true
|
||||
}
|
||||
|
||||
return changed, nil
|
||||
}
|
||||
|
||||
// stripSignatureSensitiveBlocksFromClaudeRequest is a stronger retry degradation that additionally converts
|
||||
// tool blocks to plain text. Use this only after a thinking-only retry still fails with signature errors.
|
||||
func stripSignatureSensitiveBlocksFromClaudeRequest(req *antigravity.ClaudeRequest) (bool, error) {
|
||||
if req == nil {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
changed := false
|
||||
if req.Thinking != nil {
|
||||
req.Thinking = nil
|
||||
changed = true
|
||||
}
|
||||
|
||||
for i := range req.Messages {
|
||||
raw := req.Messages[i].Content
|
||||
if len(raw) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// If content is a string, nothing to strip.
|
||||
var str string
|
||||
if json.Unmarshal(raw, &str) == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Otherwise treat as an array of blocks and convert signature-sensitive blocks to text.
|
||||
var blocks []map[string]any
|
||||
if err := json.Unmarshal(raw, &blocks); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
filtered := make([]map[string]any, 0, len(blocks))
|
||||
modifiedAny := false
|
||||
for _, block := range blocks {
|
||||
@@ -603,6 +760,49 @@ func stripThinkingFromClaudeRequest(req *antigravity.ClaudeRequest) (bool, error
|
||||
case "redacted_thinking":
|
||||
// Remove redacted_thinking (cannot convert encrypted content)
|
||||
modifiedAny = true
|
||||
case "tool_use":
|
||||
// Convert tool_use to text to avoid upstream signature/thought_signature validation errors.
|
||||
// This is a retry-only degradation path, so we prioritise request validity over tool semantics.
|
||||
name, _ := block["name"].(string)
|
||||
id, _ := block["id"].(string)
|
||||
input := block["input"]
|
||||
inputJSON, _ := json.Marshal(input)
|
||||
text := "(tool_use)"
|
||||
if name != "" {
|
||||
text += " name=" + name
|
||||
}
|
||||
if id != "" {
|
||||
text += " id=" + id
|
||||
}
|
||||
if len(inputJSON) > 0 && string(inputJSON) != "null" {
|
||||
text += " input=" + string(inputJSON)
|
||||
}
|
||||
filtered = append(filtered, map[string]any{
|
||||
"type": "text",
|
||||
"text": text,
|
||||
})
|
||||
modifiedAny = true
|
||||
case "tool_result":
|
||||
// Convert tool_result to text so it stays consistent when tool_use is downgraded.
|
||||
toolUseID, _ := block["tool_use_id"].(string)
|
||||
isError, _ := block["is_error"].(bool)
|
||||
content := block["content"]
|
||||
contentJSON, _ := json.Marshal(content)
|
||||
text := "(tool_result)"
|
||||
if toolUseID != "" {
|
||||
text += " tool_use_id=" + toolUseID
|
||||
}
|
||||
if isError {
|
||||
text += " is_error=true"
|
||||
}
|
||||
if len(contentJSON) > 0 && string(contentJSON) != "null" {
|
||||
text += "\n" + string(contentJSON)
|
||||
}
|
||||
filtered = append(filtered, map[string]any{
|
||||
"type": "text",
|
||||
"text": text,
|
||||
})
|
||||
modifiedAny = true
|
||||
case "":
|
||||
// Handle untyped block with "thinking" field
|
||||
if thinkingText, hasThinking := block["thinking"].(string); hasThinking {
|
||||
@@ -625,6 +825,14 @@ func stripThinkingFromClaudeRequest(req *antigravity.ClaudeRequest) (bool, error
|
||||
continue
|
||||
}
|
||||
|
||||
if len(filtered) == 0 {
|
||||
// Keep request valid: upstream rejects empty content arrays.
|
||||
filtered = append(filtered, map[string]any{
|
||||
"type": "text",
|
||||
"text": "(content removed)",
|
||||
})
|
||||
}
|
||||
|
||||
newRaw, err := json.Marshal(filtered)
|
||||
if err != nil {
|
||||
return changed, err
|
||||
@@ -711,6 +919,14 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
|
||||
// 重试循环
|
||||
var resp *http.Response
|
||||
for attempt := 1; attempt <= antigravityMaxRetries; attempt++ {
|
||||
// 检查 context 是否已取消(客户端断开连接)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Printf("%s status=context_canceled error=%v", prefix, ctx.Err())
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
upstreamReq, err := antigravity.NewAPIRequest(ctx, upstreamAction, accessToken, wrappedBody)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -720,7 +936,10 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
|
||||
if err != nil {
|
||||
if attempt < antigravityMaxRetries {
|
||||
log.Printf("%s status=request_failed retry=%d/%d error=%v", prefix, attempt, antigravityMaxRetries, err)
|
||||
sleepAntigravityBackoff(attempt)
|
||||
if !sleepAntigravityBackoffWithContext(ctx, attempt) {
|
||||
log.Printf("%s status=context_canceled_during_backoff", prefix)
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
continue
|
||||
}
|
||||
log.Printf("%s status=request_failed retries_exhausted error=%v", prefix, err)
|
||||
@@ -733,7 +952,10 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
|
||||
|
||||
if attempt < antigravityMaxRetries {
|
||||
log.Printf("%s status=%d retry=%d/%d", prefix, resp.StatusCode, attempt, antigravityMaxRetries)
|
||||
sleepAntigravityBackoff(attempt)
|
||||
if !sleepAntigravityBackoffWithContext(ctx, attempt) {
|
||||
log.Printf("%s status=context_canceled_during_backoff", prefix)
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
continue
|
||||
}
|
||||
// 所有重试都失败,标记限流状态
|
||||
@@ -750,11 +972,18 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
|
||||
|
||||
break
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
defer func() {
|
||||
if resp != nil && resp.Body != nil {
|
||||
_ = resp.Body.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
// 处理错误响应
|
||||
if resp.StatusCode >= 400 {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
// 尽早关闭原始响应体,释放连接;后续逻辑仍可能需要读取 body,因此用内存副本重新包装。
|
||||
_ = resp.Body.Close()
|
||||
resp.Body = io.NopCloser(bytes.NewReader(respBody))
|
||||
|
||||
// 模型兜底:模型不存在且开启 fallback 时,自动用 fallback 模型重试一次
|
||||
if s.settingService != nil && s.settingService.IsModelFallbackEnabled(ctx) &&
|
||||
@@ -763,15 +992,13 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
|
||||
if fallbackModel != "" && fallbackModel != mappedModel {
|
||||
log.Printf("[Antigravity] Model not found (%s), retrying with fallback model %s (account: %s)", mappedModel, fallbackModel, account.Name)
|
||||
|
||||
// 关闭原始响应,释放连接(respBody 已读取到内存)
|
||||
_ = resp.Body.Close()
|
||||
|
||||
fallbackWrapped, err := s.wrapV1InternalRequest(projectID, fallbackModel, body)
|
||||
if err == nil {
|
||||
fallbackReq, err := antigravity.NewAPIRequest(ctx, upstreamAction, accessToken, fallbackWrapped)
|
||||
if err == nil {
|
||||
fallbackResp, err := s.httpUpstream.Do(fallbackReq, proxyURL, account.ID, account.Concurrency)
|
||||
if err == nil && fallbackResp.StatusCode < 400 {
|
||||
_ = resp.Body.Close()
|
||||
resp = fallbackResp
|
||||
} else if fallbackResp != nil {
|
||||
_ = fallbackResp.Body.Close()
|
||||
@@ -872,8 +1099,28 @@ func (s *AntigravityGatewayService) shouldFailoverUpstreamError(statusCode int)
|
||||
}
|
||||
}
|
||||
|
||||
func sleepAntigravityBackoff(attempt int) {
|
||||
sleepGeminiBackoff(attempt) // 复用 Gemini 的退避逻辑
|
||||
// sleepAntigravityBackoffWithContext 带 context 取消检查的退避等待
|
||||
// 返回 true 表示正常完成等待,false 表示 context 已取消
|
||||
func sleepAntigravityBackoffWithContext(ctx context.Context, attempt int) bool {
|
||||
delay := geminiRetryBaseDelay * time.Duration(1<<uint(attempt-1))
|
||||
if delay > geminiRetryMaxDelay {
|
||||
delay = geminiRetryMaxDelay
|
||||
}
|
||||
|
||||
// +/- 20% jitter
|
||||
r := mathrand.New(mathrand.NewSource(time.Now().UnixNano()))
|
||||
jitter := time.Duration(float64(delay) * 0.2 * (r.Float64()*2 - 1))
|
||||
sleepFor := delay + jitter
|
||||
if sleepFor < 0 {
|
||||
sleepFor = 0
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return false
|
||||
case <-time.After(sleepFor):
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AntigravityGatewayService) handleUpstreamError(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte) {
|
||||
@@ -928,57 +1175,145 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
|
||||
return nil, errors.New("streaming not supported")
|
||||
}
|
||||
|
||||
reader := bufio.NewReader(resp.Body)
|
||||
// 使用 Scanner 并限制单行大小,避免 ReadString 无上限导致 OOM
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
maxLineSize := defaultMaxLineSize
|
||||
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 {
|
||||
maxLineSize = s.settingService.cfg.Gateway.MaxLineSize
|
||||
}
|
||||
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
|
||||
usage := &ClaudeUsage{}
|
||||
var firstTokenMs *int
|
||||
|
||||
type scanEvent struct {
|
||||
line string
|
||||
err error
|
||||
}
|
||||
// 独立 goroutine 读取上游,避免读取阻塞影响超时处理
|
||||
events := make(chan scanEvent, 16)
|
||||
done := make(chan struct{})
|
||||
sendEvent := func(ev scanEvent) bool {
|
||||
select {
|
||||
case events <- ev:
|
||||
return true
|
||||
case <-done:
|
||||
return false
|
||||
}
|
||||
}
|
||||
var lastReadAt int64
|
||||
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
|
||||
go func() {
|
||||
defer close(events)
|
||||
for scanner.Scan() {
|
||||
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
|
||||
if !sendEvent(scanEvent{line: scanner.Text()}) {
|
||||
return
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
_ = sendEvent(scanEvent{err: err})
|
||||
}
|
||||
}()
|
||||
defer close(done)
|
||||
|
||||
// 上游数据间隔超时保护(防止上游挂起长期占用连接)
|
||||
streamInterval := time.Duration(0)
|
||||
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.StreamDataIntervalTimeout > 0 {
|
||||
streamInterval = time.Duration(s.settingService.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
|
||||
}
|
||||
var intervalTicker *time.Ticker
|
||||
if streamInterval > 0 {
|
||||
intervalTicker = time.NewTicker(streamInterval)
|
||||
defer intervalTicker.Stop()
|
||||
}
|
||||
var intervalCh <-chan time.Time
|
||||
if intervalTicker != nil {
|
||||
intervalCh = intervalTicker.C
|
||||
}
|
||||
|
||||
// 仅发送一次错误事件,避免多次写入导致协议混乱
|
||||
errorEventSent := false
|
||||
sendErrorEvent := func(reason string) {
|
||||
if errorEventSent {
|
||||
return
|
||||
}
|
||||
errorEventSent = true
|
||||
_, _ = fmt.Fprintf(c.Writer, "event: error\ndata: {\"error\":\"%s\"}\n\n", reason)
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
for {
|
||||
line, err := reader.ReadString('\n')
|
||||
if len(line) > 0 {
|
||||
select {
|
||||
case ev, ok := <-events:
|
||||
if !ok {
|
||||
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, nil
|
||||
}
|
||||
if ev.err != nil {
|
||||
if errors.Is(ev.err, bufio.ErrTooLong) {
|
||||
log.Printf("SSE line too long (antigravity): max_size=%d error=%v", maxLineSize, ev.err)
|
||||
sendErrorEvent("response_too_large")
|
||||
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err
|
||||
}
|
||||
sendErrorEvent("stream_read_error")
|
||||
return nil, ev.err
|
||||
}
|
||||
|
||||
line := ev.line
|
||||
trimmed := strings.TrimRight(line, "\r\n")
|
||||
if strings.HasPrefix(trimmed, "data:") {
|
||||
payload := strings.TrimSpace(strings.TrimPrefix(trimmed, "data:"))
|
||||
if payload == "" || payload == "[DONE]" {
|
||||
_, _ = io.WriteString(c.Writer, line)
|
||||
flusher.Flush()
|
||||
} else {
|
||||
// 解包 v1internal 响应
|
||||
inner, parseErr := s.unwrapV1InternalResponse([]byte(payload))
|
||||
if parseErr == nil && inner != nil {
|
||||
payload = string(inner)
|
||||
}
|
||||
|
||||
// 解析 usage
|
||||
var parsed map[string]any
|
||||
if json.Unmarshal(inner, &parsed) == nil {
|
||||
if u := extractGeminiUsage(parsed); u != nil {
|
||||
usage = u
|
||||
}
|
||||
}
|
||||
|
||||
if firstTokenMs == nil {
|
||||
ms := int(time.Since(startTime).Milliseconds())
|
||||
firstTokenMs = &ms
|
||||
}
|
||||
|
||||
_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", payload)
|
||||
if _, err := fmt.Fprintf(c.Writer, "%s\n", line); err != nil {
|
||||
sendErrorEvent("write_failed")
|
||||
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, err
|
||||
}
|
||||
flusher.Flush()
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
_, _ = io.WriteString(c.Writer, line)
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
// 解包 v1internal 响应
|
||||
inner, parseErr := s.unwrapV1InternalResponse([]byte(payload))
|
||||
if parseErr == nil && inner != nil {
|
||||
payload = string(inner)
|
||||
}
|
||||
|
||||
// 解析 usage
|
||||
var parsed map[string]any
|
||||
if json.Unmarshal(inner, &parsed) == nil {
|
||||
if u := extractGeminiUsage(parsed); u != nil {
|
||||
usage = u
|
||||
}
|
||||
}
|
||||
|
||||
if firstTokenMs == nil {
|
||||
ms := int(time.Since(startTime).Milliseconds())
|
||||
firstTokenMs = &ms
|
||||
}
|
||||
|
||||
if _, err := fmt.Fprintf(c.Writer, "data: %s\n\n", payload); err != nil {
|
||||
sendErrorEvent("write_failed")
|
||||
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, err
|
||||
}
|
||||
flusher.Flush()
|
||||
continue
|
||||
}
|
||||
|
||||
if _, err := fmt.Fprintf(c.Writer, "%s\n", line); err != nil {
|
||||
sendErrorEvent("write_failed")
|
||||
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, err
|
||||
}
|
||||
flusher.Flush()
|
||||
|
||||
case <-intervalCh:
|
||||
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
|
||||
if time.Since(lastRead) < streamInterval {
|
||||
continue
|
||||
}
|
||||
log.Printf("Stream data interval timeout (antigravity)")
|
||||
sendErrorEvent("stream_timeout")
|
||||
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
|
||||
}
|
||||
}
|
||||
|
||||
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, nil
|
||||
}
|
||||
|
||||
func (s *AntigravityGatewayService) handleGeminiNonStreamingResponse(c *gin.Context, resp *http.Response) (*ClaudeUsage, error) {
|
||||
@@ -1117,7 +1452,13 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
|
||||
|
||||
processor := antigravity.NewStreamingProcessor(originalModel)
|
||||
var firstTokenMs *int
|
||||
reader := bufio.NewReader(resp.Body)
|
||||
// 使用 Scanner 并限制单行大小,避免 ReadString 无上限导致 OOM
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
maxLineSize := defaultMaxLineSize
|
||||
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 {
|
||||
maxLineSize = s.settingService.cfg.Gateway.MaxLineSize
|
||||
}
|
||||
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
|
||||
|
||||
// 辅助函数:转换 antigravity.ClaudeUsage 到 service.ClaudeUsage
|
||||
convertUsage := func(agUsage *antigravity.ClaudeUsage) *ClaudeUsage {
|
||||
@@ -1132,13 +1473,85 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
|
||||
}
|
||||
}
|
||||
|
||||
for {
|
||||
line, err := reader.ReadString('\n')
|
||||
if err != nil && !errors.Is(err, io.EOF) {
|
||||
return nil, fmt.Errorf("stream read error: %w", err)
|
||||
type scanEvent struct {
|
||||
line string
|
||||
err error
|
||||
}
|
||||
// 独立 goroutine 读取上游,避免读取阻塞影响超时处理
|
||||
events := make(chan scanEvent, 16)
|
||||
done := make(chan struct{})
|
||||
sendEvent := func(ev scanEvent) bool {
|
||||
select {
|
||||
case events <- ev:
|
||||
return true
|
||||
case <-done:
|
||||
return false
|
||||
}
|
||||
}
|
||||
var lastReadAt int64
|
||||
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
|
||||
go func() {
|
||||
defer close(events)
|
||||
for scanner.Scan() {
|
||||
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
|
||||
if !sendEvent(scanEvent{line: scanner.Text()}) {
|
||||
return
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
_ = sendEvent(scanEvent{err: err})
|
||||
}
|
||||
}()
|
||||
defer close(done)
|
||||
|
||||
if len(line) > 0 {
|
||||
streamInterval := time.Duration(0)
|
||||
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.StreamDataIntervalTimeout > 0 {
|
||||
streamInterval = time.Duration(s.settingService.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
|
||||
}
|
||||
var intervalTicker *time.Ticker
|
||||
if streamInterval > 0 {
|
||||
intervalTicker = time.NewTicker(streamInterval)
|
||||
defer intervalTicker.Stop()
|
||||
}
|
||||
var intervalCh <-chan time.Time
|
||||
if intervalTicker != nil {
|
||||
intervalCh = intervalTicker.C
|
||||
}
|
||||
|
||||
// 仅发送一次错误事件,避免多次写入导致协议混乱
|
||||
errorEventSent := false
|
||||
sendErrorEvent := func(reason string) {
|
||||
if errorEventSent {
|
||||
return
|
||||
}
|
||||
errorEventSent = true
|
||||
_, _ = fmt.Fprintf(c.Writer, "event: error\ndata: {\"error\":\"%s\"}\n\n", reason)
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case ev, ok := <-events:
|
||||
if !ok {
|
||||
// 发送结束事件
|
||||
finalEvents, agUsage := processor.Finish()
|
||||
if len(finalEvents) > 0 {
|
||||
_, _ = c.Writer.Write(finalEvents)
|
||||
flusher.Flush()
|
||||
}
|
||||
return &antigravityStreamResult{usage: convertUsage(agUsage), firstTokenMs: firstTokenMs}, nil
|
||||
}
|
||||
if ev.err != nil {
|
||||
if errors.Is(ev.err, bufio.ErrTooLong) {
|
||||
log.Printf("SSE line too long (antigravity): max_size=%d error=%v", maxLineSize, ev.err)
|
||||
sendErrorEvent("response_too_large")
|
||||
return &antigravityStreamResult{usage: convertUsage(nil), firstTokenMs: firstTokenMs}, ev.err
|
||||
}
|
||||
sendErrorEvent("stream_read_error")
|
||||
return nil, fmt.Errorf("stream read error: %w", ev.err)
|
||||
}
|
||||
|
||||
line := ev.line
|
||||
// 处理 SSE 行,转换为 Claude 格式
|
||||
claudeEvents := processor.ProcessLine(strings.TrimRight(line, "\r\n"))
|
||||
|
||||
@@ -1153,25 +1566,23 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
|
||||
if len(finalEvents) > 0 {
|
||||
_, _ = c.Writer.Write(finalEvents)
|
||||
}
|
||||
sendErrorEvent("write_failed")
|
||||
return &antigravityStreamResult{usage: convertUsage(agUsage), firstTokenMs: firstTokenMs}, writeErr
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
case <-intervalCh:
|
||||
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
|
||||
if time.Since(lastRead) < streamInterval {
|
||||
continue
|
||||
}
|
||||
log.Printf("Stream data interval timeout (antigravity)")
|
||||
sendErrorEvent("stream_timeout")
|
||||
return &antigravityStreamResult{usage: convertUsage(nil), firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
|
||||
}
|
||||
}
|
||||
|
||||
// 发送结束事件
|
||||
finalEvents, agUsage := processor.Finish()
|
||||
if len(finalEvents) > 0 {
|
||||
_, _ = c.Writer.Write(finalEvents)
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
return &antigravityStreamResult{usage: convertUsage(agUsage), firstTokenMs: firstTokenMs}, nil
|
||||
}
|
||||
|
||||
// extractImageSize 从 Gemini 请求中提取 image_size 参数
|
||||
|
||||
83
backend/internal/service/antigravity_gateway_service_test.go
Normal file
83
backend/internal/service/antigravity_gateway_service_test.go
Normal file
@@ -0,0 +1,83 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestStripSignatureSensitiveBlocksFromClaudeRequest(t *testing.T) {
|
||||
req := &antigravity.ClaudeRequest{
|
||||
Model: "claude-sonnet-4-5",
|
||||
Thinking: &antigravity.ThinkingConfig{
|
||||
Type: "enabled",
|
||||
BudgetTokens: 1024,
|
||||
},
|
||||
Messages: []antigravity.ClaudeMessage{
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: json.RawMessage(`[
|
||||
{"type":"thinking","thinking":"secret plan","signature":""},
|
||||
{"type":"tool_use","id":"t1","name":"Bash","input":{"command":"ls"}}
|
||||
]`),
|
||||
},
|
||||
{
|
||||
Role: "user",
|
||||
Content: json.RawMessage(`[
|
||||
{"type":"tool_result","tool_use_id":"t1","content":"ok","is_error":false},
|
||||
{"type":"redacted_thinking","data":"..."}
|
||||
]`),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
changed, err := stripSignatureSensitiveBlocksFromClaudeRequest(req)
|
||||
require.NoError(t, err)
|
||||
require.True(t, changed)
|
||||
require.Nil(t, req.Thinking)
|
||||
|
||||
require.Len(t, req.Messages, 2)
|
||||
|
||||
var blocks0 []map[string]any
|
||||
require.NoError(t, json.Unmarshal(req.Messages[0].Content, &blocks0))
|
||||
require.Len(t, blocks0, 2)
|
||||
require.Equal(t, "text", blocks0[0]["type"])
|
||||
require.Equal(t, "secret plan", blocks0[0]["text"])
|
||||
require.Equal(t, "text", blocks0[1]["type"])
|
||||
|
||||
var blocks1 []map[string]any
|
||||
require.NoError(t, json.Unmarshal(req.Messages[1].Content, &blocks1))
|
||||
require.Len(t, blocks1, 1)
|
||||
require.Equal(t, "text", blocks1[0]["type"])
|
||||
require.NotEmpty(t, blocks1[0]["text"])
|
||||
}
|
||||
|
||||
func TestStripThinkingFromClaudeRequest_DoesNotDowngradeTools(t *testing.T) {
|
||||
req := &antigravity.ClaudeRequest{
|
||||
Model: "claude-sonnet-4-5",
|
||||
Thinking: &antigravity.ThinkingConfig{
|
||||
Type: "enabled",
|
||||
BudgetTokens: 1024,
|
||||
},
|
||||
Messages: []antigravity.ClaudeMessage{
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: json.RawMessage(`[{"type":"thinking","thinking":"secret plan"},{"type":"tool_use","id":"t1","name":"Bash","input":{"command":"ls"}}]`),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
changed, err := stripThinkingFromClaudeRequest(req)
|
||||
require.NoError(t, err)
|
||||
require.True(t, changed)
|
||||
require.Nil(t, req.Thinking)
|
||||
|
||||
var blocks []map[string]any
|
||||
require.NoError(t, json.Unmarshal(req.Messages[0].Content, &blocks))
|
||||
require.Len(t, blocks, 2)
|
||||
require.Equal(t, "text", blocks[0]["type"])
|
||||
require.Equal(t, "secret plan", blocks[0]["text"])
|
||||
require.Equal(t, "tool_use", blocks[1]["type"])
|
||||
}
|
||||
@@ -221,9 +221,33 @@ func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*S
|
||||
|
||||
// VerifyTurnstile 验证Turnstile token
|
||||
func (s *AuthService) VerifyTurnstile(ctx context.Context, token string, remoteIP string) error {
|
||||
required := s.cfg != nil && s.cfg.Server.Mode == "release" && s.cfg.Turnstile.Required
|
||||
|
||||
if required {
|
||||
if s.settingService == nil {
|
||||
log.Println("[Auth] Turnstile required but settings service is not configured")
|
||||
return ErrTurnstileNotConfigured
|
||||
}
|
||||
enabled := s.settingService.IsTurnstileEnabled(ctx)
|
||||
secretConfigured := s.settingService.GetTurnstileSecretKey(ctx) != ""
|
||||
if !enabled || !secretConfigured {
|
||||
log.Printf("[Auth] Turnstile required but not configured (enabled=%v, secret_configured=%v)", enabled, secretConfigured)
|
||||
return ErrTurnstileNotConfigured
|
||||
}
|
||||
}
|
||||
|
||||
if s.turnstileService == nil {
|
||||
if required {
|
||||
log.Println("[Auth] Turnstile required but service not configured")
|
||||
return ErrTurnstileNotConfigured
|
||||
}
|
||||
return nil // 服务未配置则跳过验证
|
||||
}
|
||||
|
||||
if !required && s.settingService != nil && s.settingService.IsTurnstileEnabled(ctx) && s.settingService.GetTurnstileSecretKey(ctx) == "" {
|
||||
log.Println("[Auth] Turnstile enabled but secret key not configured")
|
||||
}
|
||||
|
||||
return s.turnstileService.VerifyToken(ctx, token, remoteIP)
|
||||
}
|
||||
|
||||
|
||||
@@ -16,7 +16,8 @@ import (
|
||||
// 注:ErrInsufficientBalance在redeem_service.go中定义
|
||||
// 注:ErrDailyLimitExceeded/ErrWeeklyLimitExceeded/ErrMonthlyLimitExceeded在subscription_service.go中定义
|
||||
var (
|
||||
ErrSubscriptionInvalid = infraerrors.Forbidden("SUBSCRIPTION_INVALID", "subscription is invalid or expired")
|
||||
ErrSubscriptionInvalid = infraerrors.Forbidden("SUBSCRIPTION_INVALID", "subscription is invalid or expired")
|
||||
ErrBillingServiceUnavailable = infraerrors.ServiceUnavailable("BILLING_SERVICE_ERROR", "Billing service temporarily unavailable. Please retry later.")
|
||||
)
|
||||
|
||||
// subscriptionCacheData 订阅缓存数据结构(内部使用)
|
||||
@@ -72,10 +73,11 @@ type cacheWriteTask struct {
|
||||
// BillingCacheService 计费缓存服务
|
||||
// 负责余额和订阅数据的缓存管理,提供高性能的计费资格检查
|
||||
type BillingCacheService struct {
|
||||
cache BillingCache
|
||||
userRepo UserRepository
|
||||
subRepo UserSubscriptionRepository
|
||||
cfg *config.Config
|
||||
cache BillingCache
|
||||
userRepo UserRepository
|
||||
subRepo UserSubscriptionRepository
|
||||
cfg *config.Config
|
||||
circuitBreaker *billingCircuitBreaker
|
||||
|
||||
cacheWriteChan chan cacheWriteTask
|
||||
cacheWriteWg sync.WaitGroup
|
||||
@@ -95,6 +97,7 @@ func NewBillingCacheService(cache BillingCache, userRepo UserRepository, subRepo
|
||||
subRepo: subRepo,
|
||||
cfg: cfg,
|
||||
}
|
||||
svc.circuitBreaker = newBillingCircuitBreaker(cfg.Billing.CircuitBreaker)
|
||||
svc.startCacheWriteWorkers()
|
||||
return svc
|
||||
}
|
||||
@@ -450,6 +453,9 @@ func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user
|
||||
if s.cfg.RunMode == config.RunModeSimple {
|
||||
return nil
|
||||
}
|
||||
if s.circuitBreaker != nil && !s.circuitBreaker.Allow() {
|
||||
return ErrBillingServiceUnavailable
|
||||
}
|
||||
|
||||
// 判断计费模式
|
||||
isSubscriptionMode := group != nil && group.IsSubscriptionType() && subscription != nil
|
||||
@@ -465,9 +471,14 @@ func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user
|
||||
func (s *BillingCacheService) checkBalanceEligibility(ctx context.Context, userID int64) error {
|
||||
balance, err := s.GetUserBalance(ctx, userID)
|
||||
if err != nil {
|
||||
// 缓存/数据库错误,允许通过(降级处理)
|
||||
log.Printf("Warning: get user balance failed, allowing request: %v", err)
|
||||
return nil
|
||||
if s.circuitBreaker != nil {
|
||||
s.circuitBreaker.OnFailure(err)
|
||||
}
|
||||
log.Printf("ALERT: billing balance check failed for user %d: %v", userID, err)
|
||||
return ErrBillingServiceUnavailable.WithCause(err)
|
||||
}
|
||||
if s.circuitBreaker != nil {
|
||||
s.circuitBreaker.OnSuccess()
|
||||
}
|
||||
|
||||
if balance <= 0 {
|
||||
@@ -482,9 +493,14 @@ func (s *BillingCacheService) checkSubscriptionEligibility(ctx context.Context,
|
||||
// 获取订阅缓存数据
|
||||
subData, err := s.GetSubscriptionStatus(ctx, userID, group.ID)
|
||||
if err != nil {
|
||||
// 缓存/数据库错误,降级使用传入的subscription进行检查
|
||||
log.Printf("Warning: get subscription cache failed, using fallback: %v", err)
|
||||
return s.checkSubscriptionLimitsFallback(subscription, group)
|
||||
if s.circuitBreaker != nil {
|
||||
s.circuitBreaker.OnFailure(err)
|
||||
}
|
||||
log.Printf("ALERT: billing subscription check failed for user %d group %d: %v", userID, group.ID, err)
|
||||
return ErrBillingServiceUnavailable.WithCause(err)
|
||||
}
|
||||
if s.circuitBreaker != nil {
|
||||
s.circuitBreaker.OnSuccess()
|
||||
}
|
||||
|
||||
// 检查订阅状态
|
||||
@@ -513,27 +529,133 @@ func (s *BillingCacheService) checkSubscriptionEligibility(ctx context.Context,
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkSubscriptionLimitsFallback 降级检查订阅限额
|
||||
func (s *BillingCacheService) checkSubscriptionLimitsFallback(subscription *UserSubscription, group *Group) error {
|
||||
if subscription == nil {
|
||||
return ErrSubscriptionInvalid
|
||||
}
|
||||
type billingCircuitBreakerState int
|
||||
|
||||
if !subscription.IsActive() {
|
||||
return ErrSubscriptionInvalid
|
||||
}
|
||||
const (
|
||||
billingCircuitClosed billingCircuitBreakerState = iota
|
||||
billingCircuitOpen
|
||||
billingCircuitHalfOpen
|
||||
)
|
||||
|
||||
if !subscription.CheckDailyLimit(group, 0) {
|
||||
return ErrDailyLimitExceeded
|
||||
}
|
||||
|
||||
if !subscription.CheckWeeklyLimit(group, 0) {
|
||||
return ErrWeeklyLimitExceeded
|
||||
}
|
||||
|
||||
if !subscription.CheckMonthlyLimit(group, 0) {
|
||||
return ErrMonthlyLimitExceeded
|
||||
}
|
||||
|
||||
return nil
|
||||
type billingCircuitBreaker struct {
|
||||
mu sync.Mutex
|
||||
state billingCircuitBreakerState
|
||||
failures int
|
||||
openedAt time.Time
|
||||
failureThreshold int
|
||||
resetTimeout time.Duration
|
||||
halfOpenRequests int
|
||||
halfOpenRemaining int
|
||||
}
|
||||
|
||||
func newBillingCircuitBreaker(cfg config.CircuitBreakerConfig) *billingCircuitBreaker {
|
||||
if !cfg.Enabled {
|
||||
return nil
|
||||
}
|
||||
resetTimeout := time.Duration(cfg.ResetTimeoutSeconds) * time.Second
|
||||
if resetTimeout <= 0 {
|
||||
resetTimeout = 30 * time.Second
|
||||
}
|
||||
halfOpen := cfg.HalfOpenRequests
|
||||
if halfOpen <= 0 {
|
||||
halfOpen = 1
|
||||
}
|
||||
threshold := cfg.FailureThreshold
|
||||
if threshold <= 0 {
|
||||
threshold = 5
|
||||
}
|
||||
return &billingCircuitBreaker{
|
||||
state: billingCircuitClosed,
|
||||
failureThreshold: threshold,
|
||||
resetTimeout: resetTimeout,
|
||||
halfOpenRequests: halfOpen,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *billingCircuitBreaker) Allow() bool {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
switch b.state {
|
||||
case billingCircuitClosed:
|
||||
return true
|
||||
case billingCircuitOpen:
|
||||
if time.Since(b.openedAt) < b.resetTimeout {
|
||||
return false
|
||||
}
|
||||
b.state = billingCircuitHalfOpen
|
||||
b.halfOpenRemaining = b.halfOpenRequests
|
||||
log.Printf("ALERT: billing circuit breaker entering half-open state")
|
||||
fallthrough
|
||||
case billingCircuitHalfOpen:
|
||||
if b.halfOpenRemaining <= 0 {
|
||||
return false
|
||||
}
|
||||
b.halfOpenRemaining--
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (b *billingCircuitBreaker) OnFailure(err error) {
|
||||
if b == nil {
|
||||
return
|
||||
}
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
switch b.state {
|
||||
case billingCircuitOpen:
|
||||
return
|
||||
case billingCircuitHalfOpen:
|
||||
b.state = billingCircuitOpen
|
||||
b.openedAt = time.Now()
|
||||
b.halfOpenRemaining = 0
|
||||
log.Printf("ALERT: billing circuit breaker opened after half-open failure: %v", err)
|
||||
return
|
||||
default:
|
||||
b.failures++
|
||||
if b.failures >= b.failureThreshold {
|
||||
b.state = billingCircuitOpen
|
||||
b.openedAt = time.Now()
|
||||
b.halfOpenRemaining = 0
|
||||
log.Printf("ALERT: billing circuit breaker opened after %d failures: %v", b.failures, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (b *billingCircuitBreaker) OnSuccess() {
|
||||
if b == nil {
|
||||
return
|
||||
}
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
previousState := b.state
|
||||
previousFailures := b.failures
|
||||
|
||||
b.state = billingCircuitClosed
|
||||
b.failures = 0
|
||||
b.halfOpenRemaining = 0
|
||||
|
||||
// 只有状态真正发生变化时才记录日志
|
||||
if previousState != billingCircuitClosed {
|
||||
log.Printf("ALERT: billing circuit breaker closed (was %s)", circuitStateString(previousState))
|
||||
} else if previousFailures > 0 {
|
||||
log.Printf("INFO: billing circuit breaker failures reset from %d", previousFailures)
|
||||
}
|
||||
}
|
||||
|
||||
func circuitStateString(state billingCircuitBreakerState) string {
|
||||
switch state {
|
||||
case billingCircuitClosed:
|
||||
return "closed"
|
||||
case billingCircuitOpen:
|
||||
return "open"
|
||||
case billingCircuitHalfOpen:
|
||||
return "half-open"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,12 +8,13 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
|
||||
)
|
||||
|
||||
type CRSSyncService struct {
|
||||
@@ -22,6 +23,7 @@ type CRSSyncService struct {
|
||||
oauthService *OAuthService
|
||||
openaiOAuthService *OpenAIOAuthService
|
||||
geminiOAuthService *GeminiOAuthService
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
func NewCRSSyncService(
|
||||
@@ -30,6 +32,7 @@ func NewCRSSyncService(
|
||||
oauthService *OAuthService,
|
||||
openaiOAuthService *OpenAIOAuthService,
|
||||
geminiOAuthService *GeminiOAuthService,
|
||||
cfg *config.Config,
|
||||
) *CRSSyncService {
|
||||
return &CRSSyncService{
|
||||
accountRepo: accountRepo,
|
||||
@@ -37,6 +40,7 @@ func NewCRSSyncService(
|
||||
oauthService: oauthService,
|
||||
openaiOAuthService: openaiOAuthService,
|
||||
geminiOAuthService: geminiOAuthService,
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -187,16 +191,31 @@ type crsGeminiAPIKeyAccount struct {
|
||||
}
|
||||
|
||||
func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput) (*SyncFromCRSResult, error) {
|
||||
baseURL, err := normalizeBaseURL(input.BaseURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
if s.cfg == nil {
|
||||
return nil, errors.New("config is not available")
|
||||
}
|
||||
baseURL := strings.TrimSpace(input.BaseURL)
|
||||
if s.cfg.Security.URLAllowlist.Enabled {
|
||||
normalized, err := normalizeBaseURL(baseURL, s.cfg.Security.URLAllowlist.CRSHosts, s.cfg.Security.URLAllowlist.AllowPrivateHosts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
baseURL = normalized
|
||||
} else {
|
||||
normalized, err := urlvalidator.ValidateURLFormat(baseURL, s.cfg.Security.URLAllowlist.AllowInsecureHTTP)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid base_url: %w", err)
|
||||
}
|
||||
baseURL = normalized
|
||||
}
|
||||
if strings.TrimSpace(input.Username) == "" || strings.TrimSpace(input.Password) == "" {
|
||||
return nil, errors.New("username and password are required")
|
||||
}
|
||||
|
||||
client, err := httpclient.GetClient(httpclient.Options{
|
||||
Timeout: 20 * time.Second,
|
||||
Timeout: 20 * time.Second,
|
||||
ValidateResolvedIP: s.cfg.Security.URLAllowlist.Enabled,
|
||||
AllowPrivateHosts: s.cfg.Security.URLAllowlist.AllowPrivateHosts,
|
||||
})
|
||||
if err != nil {
|
||||
client = &http.Client{Timeout: 20 * time.Second}
|
||||
@@ -1055,17 +1074,18 @@ func mapCRSStatus(isActive bool, status string) string {
|
||||
return "active"
|
||||
}
|
||||
|
||||
func normalizeBaseURL(raw string) (string, error) {
|
||||
trimmed := strings.TrimSpace(raw)
|
||||
if trimmed == "" {
|
||||
return "", errors.New("base_url is required")
|
||||
func normalizeBaseURL(raw string, allowlist []string, allowPrivate bool) (string, error) {
|
||||
// 当 allowlist 为空时,不强制要求白名单(只进行基本的 URL 和 SSRF 验证)
|
||||
requireAllowlist := len(allowlist) > 0
|
||||
normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{
|
||||
AllowedHosts: allowlist,
|
||||
RequireAllowlist: requireAllowlist,
|
||||
AllowPrivate: allowPrivate,
|
||||
})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid base_url: %w", err)
|
||||
}
|
||||
u, err := url.Parse(trimmed)
|
||||
if err != nil || u.Scheme == "" || u.Host == "" {
|
||||
return "", fmt.Errorf("invalid base_url: %s", trimmed)
|
||||
}
|
||||
u.Path = strings.TrimRight(u.Path, "/")
|
||||
return strings.TrimRight(u.String(), "/"), nil
|
||||
return normalized, nil
|
||||
}
|
||||
|
||||
// cleanBaseURL removes trailing suffix from base_url in credentials
|
||||
|
||||
@@ -101,6 +101,10 @@ const (
|
||||
SettingKeyFallbackModelOpenAI = "fallback_model_openai"
|
||||
SettingKeyFallbackModelGemini = "fallback_model_gemini"
|
||||
SettingKeyFallbackModelAntigravity = "fallback_model_antigravity"
|
||||
|
||||
// Request identity patch (Claude -> Gemini systemInstruction injection)
|
||||
SettingKeyEnableIdentityPatch = "enable_identity_patch"
|
||||
SettingKeyIdentityPatchPrompt = "identity_patch_prompt"
|
||||
)
|
||||
|
||||
// AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys).
|
||||
|
||||
@@ -84,25 +84,37 @@ func FilterThinkingBlocks(body []byte) []byte {
|
||||
return filterThinkingBlocksInternal(body, false)
|
||||
}
|
||||
|
||||
// FilterThinkingBlocksForRetry removes thinking blocks from HISTORICAL messages for retry scenarios.
|
||||
// This is used when upstream returns signature-related 400 errors.
|
||||
// FilterThinkingBlocksForRetry strips thinking-related constructs for retry scenarios.
|
||||
//
|
||||
// Key insight:
|
||||
// - User's thinking.type = "enabled" should be PRESERVED (user's intent)
|
||||
// - Only HISTORICAL assistant messages have thinking blocks with signatures
|
||||
// - These signatures may be invalid when switching accounts/platforms
|
||||
// - New responses will generate fresh thinking blocks without signature issues
|
||||
// Why:
|
||||
// - Upstreams may reject historical `thinking`/`redacted_thinking` blocks due to invalid/missing signatures.
|
||||
// - Anthropic extended thinking has a structural constraint: when top-level `thinking` is enabled and the
|
||||
// final message is an assistant prefill, the assistant content must start with a thinking block.
|
||||
// - If we remove thinking blocks but keep top-level `thinking` enabled, we can trigger:
|
||||
// "Expected `thinking` or `redacted_thinking`, but found `text`"
|
||||
//
|
||||
// Strategy:
|
||||
// - Keep thinking.type = "enabled" (preserve user intent)
|
||||
// - Remove thinking/redacted_thinking blocks from historical assistant messages
|
||||
// - Ensure no message has empty content after filtering
|
||||
// Strategy (B: preserve content as text):
|
||||
// - Disable top-level `thinking` (remove `thinking` field).
|
||||
// - Convert `thinking` blocks to `text` blocks (preserve the thinking content).
|
||||
// - Remove `redacted_thinking` blocks (cannot be converted to text).
|
||||
// - Ensure no message ends up with empty content.
|
||||
func FilterThinkingBlocksForRetry(body []byte) []byte {
|
||||
// Fast path: check for presence of thinking-related keys in messages
|
||||
if !bytes.Contains(body, []byte(`"type":"thinking"`)) &&
|
||||
!bytes.Contains(body, []byte(`"type": "thinking"`)) &&
|
||||
!bytes.Contains(body, []byte(`"type":"redacted_thinking"`)) &&
|
||||
!bytes.Contains(body, []byte(`"type": "redacted_thinking"`)) {
|
||||
hasThinkingContent := bytes.Contains(body, []byte(`"type":"thinking"`)) ||
|
||||
bytes.Contains(body, []byte(`"type": "thinking"`)) ||
|
||||
bytes.Contains(body, []byte(`"type":"redacted_thinking"`)) ||
|
||||
bytes.Contains(body, []byte(`"type": "redacted_thinking"`)) ||
|
||||
bytes.Contains(body, []byte(`"thinking":`)) ||
|
||||
bytes.Contains(body, []byte(`"thinking" :`))
|
||||
|
||||
// Also check for empty content arrays that need fixing.
|
||||
// Note: This is a heuristic check; the actual empty content handling is done below.
|
||||
hasEmptyContent := bytes.Contains(body, []byte(`"content":[]`)) ||
|
||||
bytes.Contains(body, []byte(`"content": []`)) ||
|
||||
bytes.Contains(body, []byte(`"content" : []`)) ||
|
||||
bytes.Contains(body, []byte(`"content" :[]`))
|
||||
|
||||
// Fast path: nothing to process
|
||||
if !hasThinkingContent && !hasEmptyContent {
|
||||
return body
|
||||
}
|
||||
|
||||
@@ -111,15 +123,19 @@ func FilterThinkingBlocksForRetry(body []byte) []byte {
|
||||
return body
|
||||
}
|
||||
|
||||
// DO NOT modify thinking.type - preserve user's intent to use thinking mode
|
||||
// The issue is with historical message signatures, not the thinking mode itself
|
||||
modified := false
|
||||
|
||||
messages, ok := req["messages"].([]any)
|
||||
if !ok {
|
||||
return body
|
||||
}
|
||||
|
||||
modified := false
|
||||
// Disable top-level thinking mode for retry to avoid structural/signature constraints upstream.
|
||||
if _, exists := req["thinking"]; exists {
|
||||
delete(req, "thinking")
|
||||
modified = true
|
||||
}
|
||||
|
||||
newMessages := make([]any, 0, len(messages))
|
||||
|
||||
for _, msg := range messages {
|
||||
@@ -149,33 +165,59 @@ func FilterThinkingBlocksForRetry(body []byte) []byte {
|
||||
|
||||
blockType, _ := blockMap["type"].(string)
|
||||
|
||||
// Remove thinking/redacted_thinking blocks from historical messages
|
||||
// These have signatures that may be invalid across different accounts
|
||||
if blockType == "thinking" || blockType == "redacted_thinking" {
|
||||
// Convert thinking blocks to text (preserve content) and drop redacted_thinking.
|
||||
switch blockType {
|
||||
case "thinking":
|
||||
modifiedThisMsg = true
|
||||
thinkingText, _ := blockMap["thinking"].(string)
|
||||
if thinkingText == "" {
|
||||
continue
|
||||
}
|
||||
newContent = append(newContent, map[string]any{
|
||||
"type": "text",
|
||||
"text": thinkingText,
|
||||
})
|
||||
continue
|
||||
case "redacted_thinking":
|
||||
modifiedThisMsg = true
|
||||
continue
|
||||
}
|
||||
|
||||
// Handle blocks without type discriminator but with a "thinking" field.
|
||||
if blockType == "" {
|
||||
if rawThinking, hasThinking := blockMap["thinking"]; hasThinking {
|
||||
modifiedThisMsg = true
|
||||
switch v := rawThinking.(type) {
|
||||
case string:
|
||||
if v != "" {
|
||||
newContent = append(newContent, map[string]any{"type": "text", "text": v})
|
||||
}
|
||||
default:
|
||||
if b, err := json.Marshal(v); err == nil && len(b) > 0 {
|
||||
newContent = append(newContent, map[string]any{"type": "text", "text": string(b)})
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
newContent = append(newContent, block)
|
||||
}
|
||||
|
||||
if modifiedThisMsg {
|
||||
// Handle empty content: either from filtering or originally empty
|
||||
if len(newContent) == 0 {
|
||||
modified = true
|
||||
// Handle empty content after filtering
|
||||
if len(newContent) == 0 {
|
||||
// For assistant messages, skip entirely (remove from conversation)
|
||||
// For user messages, add placeholder to avoid empty content error
|
||||
if role == "user" {
|
||||
newContent = append(newContent, map[string]any{
|
||||
"type": "text",
|
||||
"text": "(content removed)",
|
||||
})
|
||||
msgMap["content"] = newContent
|
||||
newMessages = append(newMessages, msgMap)
|
||||
}
|
||||
// Skip assistant messages with empty content (don't append)
|
||||
continue
|
||||
placeholder := "(content removed)"
|
||||
if role == "assistant" {
|
||||
placeholder = "(assistant content removed)"
|
||||
}
|
||||
newContent = append(newContent, map[string]any{
|
||||
"type": "text",
|
||||
"text": placeholder,
|
||||
})
|
||||
msgMap["content"] = newContent
|
||||
} else if modifiedThisMsg {
|
||||
modified = true
|
||||
msgMap["content"] = newContent
|
||||
}
|
||||
newMessages = append(newMessages, msgMap)
|
||||
@@ -183,6 +225,9 @@ func FilterThinkingBlocksForRetry(body []byte) []byte {
|
||||
|
||||
if modified {
|
||||
req["messages"] = newMessages
|
||||
} else {
|
||||
// Avoid rewriting JSON when no changes are needed.
|
||||
return body
|
||||
}
|
||||
|
||||
newBody, err := json.Marshal(req)
|
||||
@@ -192,6 +237,172 @@ func FilterThinkingBlocksForRetry(body []byte) []byte {
|
||||
return newBody
|
||||
}
|
||||
|
||||
// FilterSignatureSensitiveBlocksForRetry is a stronger retry filter for cases where upstream errors
// indicate signature/thought_signature validation issues involving tool blocks.
//
// It rewrites the JSON request in body as follows:
//   - the top-level "thinking" field is removed;
//   - `thinking` blocks become plain text blocks (empty ones are dropped);
//   - `redacted_thinking` blocks are dropped;
//   - `tool_use` blocks become text blocks carrying name/id/input;
//   - `tool_result` blocks become text blocks carrying tool_use_id/is_error/content;
//   - untyped blocks that carry a "thinking" key become text blocks (empty or null values are dropped);
//   - a message whose content ends up empty receives a single placeholder text block.
//
// The original body is returned untouched when it contains none of these constructs, when it is not
// valid JSON, when "messages" is not an array, when nothing had to change, or when re-encoding fails.
//
// Use this only when needed: converting tool blocks to text changes model behaviour and can increase the
// risk of prompt injection (tool output becomes plain conversation text).
func FilterSignatureSensitiveBlocksForRetry(body []byte) []byte {
	// Fast path: skip JSON decoding unless the raw bytes mention a construct we rewrite.
	relevant := false
	for _, marker := range []string{
		`"type":"thinking"`, `"type": "thinking"`,
		`"type":"redacted_thinking"`, `"type": "redacted_thinking"`,
		`"type":"tool_use"`, `"type": "tool_use"`,
		`"type":"tool_result"`, `"type": "tool_result"`,
		`"thinking":`, `"thinking" :`,
	} {
		if bytes.Contains(body, []byte(marker)) {
			relevant = true
			break
		}
	}
	if !relevant {
		return body
	}

	var req map[string]any
	if err := json.Unmarshal(body, &req); err != nil {
		return body
	}

	modified := false

	// Disable top-level thinking for retry to avoid structural/signature constraints upstream.
	if _, exists := req["thinking"]; exists {
		delete(req, "thinking")
		modified = true
	}

	messages, ok := req["messages"].([]any)
	if !ok {
		return body
	}

	textBlock := func(text string) map[string]any {
		return map[string]any{"type": "text", "text": text}
	}

	newMessages := make([]any, 0, len(messages))

	for _, msg := range messages {
		msgMap, ok := msg.(map[string]any)
		if !ok {
			newMessages = append(newMessages, msg)
			continue
		}

		role, _ := msgMap["role"].(string)
		content, ok := msgMap["content"].([]any)
		if !ok {
			// String (or otherwise non-array) content has no blocks to rewrite.
			newMessages = append(newMessages, msg)
			continue
		}

		newContent := make([]any, 0, len(content))
		modifiedThisMsg := false

		for _, block := range content {
			blockMap, ok := block.(map[string]any)
			if !ok {
				newContent = append(newContent, block)
				continue
			}

			blockType, _ := blockMap["type"].(string)
			switch blockType {
			case "thinking":
				modifiedThisMsg = true
				if thinkingText, _ := blockMap["thinking"].(string); thinkingText != "" {
					newContent = append(newContent, textBlock(thinkingText))
				}
				continue

			case "redacted_thinking":
				modifiedThisMsg = true
				continue

			case "tool_use":
				modifiedThisMsg = true
				name, _ := blockMap["name"].(string)
				id, _ := blockMap["id"].(string)
				// The value came out of json.Unmarshal, so re-encoding it cannot fail.
				inputJSON, _ := json.Marshal(blockMap["input"])
				text := "(tool_use)"
				if name != "" {
					text += " name=" + name
				}
				if id != "" {
					text += " id=" + id
				}
				if len(inputJSON) > 0 && string(inputJSON) != "null" {
					text += " input=" + string(inputJSON)
				}
				newContent = append(newContent, textBlock(text))
				continue

			case "tool_result":
				modifiedThisMsg = true
				toolUseID, _ := blockMap["tool_use_id"].(string)
				isError, _ := blockMap["is_error"].(bool)
				// The value came out of json.Unmarshal, so re-encoding it cannot fail.
				resultJSON, _ := json.Marshal(blockMap["content"])
				text := "(tool_result)"
				if toolUseID != "" {
					text += " tool_use_id=" + toolUseID
				}
				if isError {
					text += " is_error=true"
				}
				if len(resultJSON) > 0 && string(resultJSON) != "null" {
					text += "\n" + string(resultJSON)
				}
				newContent = append(newContent, textBlock(text))
				continue
			}

			// Untyped block that still carries a "thinking" payload: downgrade it as well.
			if blockType == "" {
				if rawThinking, hasThinking := blockMap["thinking"]; hasThinking {
					modifiedThisMsg = true
					switch v := rawThinking.(type) {
					case string:
						if v != "" {
							newContent = append(newContent, textBlock(v))
						}
					default:
						// A null payload carries no text; emitting the literal "null" would only add noise.
						if b, err := json.Marshal(v); err == nil && len(b) > 0 && string(b) != "null" {
							newContent = append(newContent, textBlock(string(b)))
						}
					}
					continue
				}
			}

			newContent = append(newContent, block)
		}

		if modifiedThisMsg {
			modified = true
			if len(newContent) == 0 {
				// Keep the message non-empty so the conversation structure is preserved.
				placeholder := "(content removed)"
				if role == "assistant" {
					placeholder = "(assistant content removed)"
				}
				newContent = append(newContent, textBlock(placeholder))
			}
			msgMap["content"] = newContent
		}

		newMessages = append(newMessages, msgMap)
	}

	if !modified {
		return body
	}

	req["messages"] = newMessages
	newBody, err := json.Marshal(req)
	if err != nil {
		return body
	}
	return newBody
}
|
||||
|
||||
// filterThinkingBlocksInternal removes invalid thinking blocks from request
|
||||
// Strategy:
|
||||
// - When thinking.type != "enabled": Remove all thinking blocks
|
||||
|
||||
@@ -151,3 +151,148 @@ func TestFilterThinkingBlocks(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilterThinkingBlocksForRetry_DisablesThinkingAndPreservesAsText(t *testing.T) {
|
||||
input := []byte(`{
|
||||
"model":"claude-3-5-sonnet-20241022",
|
||||
"thinking":{"type":"enabled","budget_tokens":1024},
|
||||
"messages":[
|
||||
{"role":"user","content":[{"type":"text","text":"Hi"}]},
|
||||
{"role":"assistant","content":[
|
||||
{"type":"thinking","thinking":"Let me think...","signature":"bad_sig"},
|
||||
{"type":"text","text":"Answer"}
|
||||
]}
|
||||
]
|
||||
}`)
|
||||
|
||||
out := FilterThinkingBlocksForRetry(input)
|
||||
|
||||
var req map[string]any
|
||||
require.NoError(t, json.Unmarshal(out, &req))
|
||||
_, hasThinking := req["thinking"]
|
||||
require.False(t, hasThinking)
|
||||
|
||||
msgs, ok := req["messages"].([]any)
|
||||
require.True(t, ok)
|
||||
require.Len(t, msgs, 2)
|
||||
|
||||
assistant, ok := msgs[1].(map[string]any)
|
||||
require.True(t, ok)
|
||||
content, ok := assistant["content"].([]any)
|
||||
require.True(t, ok)
|
||||
require.Len(t, content, 2)
|
||||
|
||||
first, ok := content[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "text", first["type"])
|
||||
require.Equal(t, "Let me think...", first["text"])
|
||||
}
|
||||
|
||||
func TestFilterThinkingBlocksForRetry_DisablesThinkingEvenWithoutThinkingBlocks(t *testing.T) {
|
||||
input := []byte(`{
|
||||
"model":"claude-3-5-sonnet-20241022",
|
||||
"thinking":{"type":"enabled","budget_tokens":1024},
|
||||
"messages":[
|
||||
{"role":"user","content":[{"type":"text","text":"Hi"}]},
|
||||
{"role":"assistant","content":[{"type":"text","text":"Prefill"}]}
|
||||
]
|
||||
}`)
|
||||
|
||||
out := FilterThinkingBlocksForRetry(input)
|
||||
|
||||
var req map[string]any
|
||||
require.NoError(t, json.Unmarshal(out, &req))
|
||||
_, hasThinking := req["thinking"]
|
||||
require.False(t, hasThinking)
|
||||
}
|
||||
|
||||
func TestFilterThinkingBlocksForRetry_RemovesRedactedThinkingAndKeepsValidContent(t *testing.T) {
|
||||
input := []byte(`{
|
||||
"thinking":{"type":"enabled","budget_tokens":1024},
|
||||
"messages":[
|
||||
{"role":"assistant","content":[
|
||||
{"type":"redacted_thinking","data":"..."},
|
||||
{"type":"text","text":"Visible"}
|
||||
]}
|
||||
]
|
||||
}`)
|
||||
|
||||
out := FilterThinkingBlocksForRetry(input)
|
||||
|
||||
var req map[string]any
|
||||
require.NoError(t, json.Unmarshal(out, &req))
|
||||
_, hasThinking := req["thinking"]
|
||||
require.False(t, hasThinking)
|
||||
|
||||
msgs, ok := req["messages"].([]any)
|
||||
require.True(t, ok)
|
||||
msg0, ok := msgs[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
content, ok := msg0["content"].([]any)
|
||||
require.True(t, ok)
|
||||
require.Len(t, content, 1)
|
||||
content0, ok := content[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "text", content0["type"])
|
||||
require.Equal(t, "Visible", content0["text"])
|
||||
}
|
||||
|
||||
func TestFilterThinkingBlocksForRetry_EmptyContentGetsPlaceholder(t *testing.T) {
|
||||
input := []byte(`{
|
||||
"thinking":{"type":"enabled"},
|
||||
"messages":[
|
||||
{"role":"assistant","content":[{"type":"redacted_thinking","data":"..."}]}
|
||||
]
|
||||
}`)
|
||||
|
||||
out := FilterThinkingBlocksForRetry(input)
|
||||
|
||||
var req map[string]any
|
||||
require.NoError(t, json.Unmarshal(out, &req))
|
||||
msgs, ok := req["messages"].([]any)
|
||||
require.True(t, ok)
|
||||
msg0, ok := msgs[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
content, ok := msg0["content"].([]any)
|
||||
require.True(t, ok)
|
||||
require.Len(t, content, 1)
|
||||
content0, ok := content[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "text", content0["type"])
|
||||
require.NotEmpty(t, content0["text"])
|
||||
}
|
||||
|
||||
func TestFilterSignatureSensitiveBlocksForRetry_DowngradesTools(t *testing.T) {
|
||||
input := []byte(`{
|
||||
"thinking":{"type":"enabled","budget_tokens":1024},
|
||||
"messages":[
|
||||
{"role":"assistant","content":[
|
||||
{"type":"tool_use","id":"t1","name":"Bash","input":{"command":"ls"}},
|
||||
{"type":"tool_result","tool_use_id":"t1","content":"ok","is_error":false}
|
||||
]}
|
||||
]
|
||||
}`)
|
||||
|
||||
out := FilterSignatureSensitiveBlocksForRetry(input)
|
||||
|
||||
var req map[string]any
|
||||
require.NoError(t, json.Unmarshal(out, &req))
|
||||
_, hasThinking := req["thinking"]
|
||||
require.False(t, hasThinking)
|
||||
|
||||
msgs, ok := req["messages"].([]any)
|
||||
require.True(t, ok)
|
||||
msg0, ok := msgs[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
content, ok := msg0["content"].([]any)
|
||||
require.True(t, ok)
|
||||
require.Len(t, content, 2)
|
||||
content0, ok := content[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
content1, ok := content[1].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "text", content0["type"])
|
||||
require.Equal(t, "text", content1["type"])
|
||||
require.Contains(t, content0["text"], "tool_use")
|
||||
require.Contains(t, content1["text"], "tool_result")
|
||||
}
|
||||
|
||||
@@ -15,11 +15,14 @@ import (
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
|
||||
@@ -30,6 +33,7 @@ const (
|
||||
claudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true"
|
||||
claudeAPICountTokensURL = "https://api.anthropic.com/v1/messages/count_tokens?beta=true"
|
||||
stickySessionTTL = time.Hour // 粘性会话TTL
|
||||
defaultMaxLineSize = 10 * 1024 * 1024
|
||||
claudeCodeSystemPrompt = "You are Claude Code, Anthropic's official CLI for Claude."
|
||||
)
|
||||
|
||||
@@ -933,8 +937,16 @@ func (s *GatewayService) getOAuthToken(ctx context.Context, account *Account) (s
|
||||
|
||||
// 重试相关常量
|
||||
const (
|
||||
maxRetries = 10 // 最大重试次数
|
||||
retryDelay = 3 * time.Second // 重试等待时间
|
||||
// 最大尝试次数(包含首次请求)。过多重试会导致请求堆积与资源耗尽。
|
||||
maxRetryAttempts = 5
|
||||
|
||||
// 指数退避:第 N 次失败后的等待 = retryBaseDelay * 2^(N-1),并且上限为 retryMaxDelay。
|
||||
retryBaseDelay = 300 * time.Millisecond
|
||||
retryMaxDelay = 3 * time.Second
|
||||
|
||||
// 最大重试耗时(包含请求本身耗时 + 退避等待时间)。
|
||||
// 用于防止极端情况下 goroutine 长时间堆积导致资源耗尽。
|
||||
maxRetryElapsed = 10 * time.Second
|
||||
)
|
||||
|
||||
func (s *GatewayService) shouldRetryUpstreamError(account *Account, statusCode int) bool {
|
||||
@@ -957,6 +969,40 @@ func (s *GatewayService) shouldFailoverUpstreamError(statusCode int) bool {
|
||||
}
|
||||
}
|
||||
|
||||
func retryBackoffDelay(attempt int) time.Duration {
|
||||
// attempt 从 1 开始,表示第 attempt 次请求刚失败,需要等待后进行第 attempt+1 次请求。
|
||||
if attempt <= 0 {
|
||||
return retryBaseDelay
|
||||
}
|
||||
delay := retryBaseDelay * time.Duration(1<<(attempt-1))
|
||||
if delay > retryMaxDelay {
|
||||
return retryMaxDelay
|
||||
}
|
||||
return delay
|
||||
}
|
||||
|
||||
// sleepWithContext blocks for d or until ctx is done, whichever happens first.
//
// It returns nil when the full duration elapsed (or when d is not positive, in which case it
// returns immediately without consulting ctx) and ctx.Err() when ctx finished first.
func sleepWithContext(ctx context.Context, d time.Duration) error {
	if d <= 0 {
		return nil
	}

	wait := time.NewTimer(d)
	release := func() {
		if wait.Stop() {
			return
		}
		// The timer already fired: drain the channel if its value has not been consumed yet.
		select {
		case <-wait.C:
		default:
		}
	}
	defer release()

	select {
	case <-wait.C:
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}
|
||||
|
||||
// isClaudeCodeClient 判断请求是否来自 Claude Code 客户端
|
||||
// 简化判断:User-Agent 匹配 + metadata.user_id 存在
|
||||
func isClaudeCodeClient(userAgent string, metadataUserID string) bool {
|
||||
@@ -1073,7 +1119,8 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
|
||||
// 重试循环
|
||||
var resp *http.Response
|
||||
for attempt := 1; attempt <= maxRetries; attempt++ {
|
||||
retryStart := time.Now()
|
||||
for attempt := 1; attempt <= maxRetryAttempts; attempt++ {
|
||||
// 构建上游请求(每次重试需要重新构建,因为请求体需要重新读取)
|
||||
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType, reqModel)
|
||||
if err != nil {
|
||||
@@ -1083,6 +1130,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
// 发送请求
|
||||
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
||||
if err != nil {
|
||||
if resp != nil && resp.Body != nil {
|
||||
_ = resp.Body.Close()
|
||||
}
|
||||
return nil, fmt.Errorf("upstream request failed: %w", err)
|
||||
}
|
||||
|
||||
@@ -1093,28 +1143,80 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
_ = resp.Body.Close()
|
||||
|
||||
if s.isThinkingBlockSignatureError(respBody) {
|
||||
looksLikeToolSignatureError := func(msg string) bool {
|
||||
m := strings.ToLower(msg)
|
||||
return strings.Contains(m, "tool_use") ||
|
||||
strings.Contains(m, "tool_result") ||
|
||||
strings.Contains(m, "functioncall") ||
|
||||
strings.Contains(m, "function_call") ||
|
||||
strings.Contains(m, "functionresponse") ||
|
||||
strings.Contains(m, "function_response")
|
||||
}
|
||||
|
||||
// 避免在重试预算已耗尽时再发起额外请求
|
||||
if time.Since(retryStart) >= maxRetryElapsed {
|
||||
resp.Body = io.NopCloser(bytes.NewReader(respBody))
|
||||
break
|
||||
}
|
||||
log.Printf("Account %d: detected thinking block signature error, retrying with filtered thinking blocks", account.ID)
|
||||
|
||||
// 过滤thinking blocks并重试(使用更激进的过滤)
|
||||
// Conservative two-stage fallback:
|
||||
// 1) Disable thinking + thinking->text (preserve content)
|
||||
// 2) Only if upstream still errors AND error message points to tool/function signature issues:
|
||||
// also downgrade tool_use/tool_result blocks to text.
|
||||
|
||||
filteredBody := FilterThinkingBlocksForRetry(body)
|
||||
retryReq, buildErr := s.buildUpstreamRequest(ctx, c, account, filteredBody, token, tokenType, reqModel)
|
||||
if buildErr == nil {
|
||||
retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency)
|
||||
if retryErr == nil {
|
||||
// 使用重试后的响应,继续后续处理
|
||||
if retryResp.StatusCode < 400 {
|
||||
log.Printf("Account %d: signature error retry succeeded", account.ID)
|
||||
} else {
|
||||
log.Printf("Account %d: signature error retry returned status %d", account.ID, retryResp.StatusCode)
|
||||
log.Printf("Account %d: signature error retry succeeded (thinking downgraded)", account.ID)
|
||||
resp = retryResp
|
||||
break
|
||||
}
|
||||
|
||||
retryRespBody, retryReadErr := io.ReadAll(io.LimitReader(retryResp.Body, 2<<20))
|
||||
_ = retryResp.Body.Close()
|
||||
if retryReadErr == nil && retryResp.StatusCode == 400 && s.isThinkingBlockSignatureError(retryRespBody) {
|
||||
msg2 := extractUpstreamErrorMessage(retryRespBody)
|
||||
if looksLikeToolSignatureError(msg2) && time.Since(retryStart) < maxRetryElapsed {
|
||||
log.Printf("Account %d: signature retry still failing and looks tool-related, retrying with tool blocks downgraded", account.ID)
|
||||
filteredBody2 := FilterSignatureSensitiveBlocksForRetry(body)
|
||||
retryReq2, buildErr2 := s.buildUpstreamRequest(ctx, c, account, filteredBody2, token, tokenType, reqModel)
|
||||
if buildErr2 == nil {
|
||||
retryResp2, retryErr2 := s.httpUpstream.Do(retryReq2, proxyURL, account.ID, account.Concurrency)
|
||||
if retryErr2 == nil {
|
||||
resp = retryResp2
|
||||
break
|
||||
}
|
||||
if retryResp2 != nil && retryResp2.Body != nil {
|
||||
_ = retryResp2.Body.Close()
|
||||
}
|
||||
log.Printf("Account %d: tool-downgrade signature retry failed: %v", account.ID, retryErr2)
|
||||
} else {
|
||||
log.Printf("Account %d: tool-downgrade signature retry build failed: %v", account.ID, buildErr2)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to the original retry response context.
|
||||
resp = &http.Response{
|
||||
StatusCode: retryResp.StatusCode,
|
||||
Header: retryResp.Header.Clone(),
|
||||
Body: io.NopCloser(bytes.NewReader(retryRespBody)),
|
||||
}
|
||||
resp = retryResp
|
||||
break
|
||||
}
|
||||
if retryResp != nil && retryResp.Body != nil {
|
||||
_ = retryResp.Body.Close()
|
||||
}
|
||||
log.Printf("Account %d: signature error retry failed: %v", account.ID, retryErr)
|
||||
} else {
|
||||
log.Printf("Account %d: signature error retry build request failed: %v", account.ID, buildErr)
|
||||
}
|
||||
// 重试失败,恢复原始响应体继续处理
|
||||
|
||||
// Retry failed: restore original response body and continue handling.
|
||||
resp.Body = io.NopCloser(bytes.NewReader(respBody))
|
||||
break
|
||||
}
|
||||
@@ -1125,11 +1227,27 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
|
||||
// 检查是否需要通用重试(排除400,因为400已经在上面特殊处理过了)
|
||||
if resp.StatusCode >= 400 && resp.StatusCode != 400 && s.shouldRetryUpstreamError(account, resp.StatusCode) {
|
||||
if attempt < maxRetries {
|
||||
log.Printf("Account %d: upstream error %d, retry %d/%d after %v",
|
||||
account.ID, resp.StatusCode, attempt, maxRetries, retryDelay)
|
||||
if attempt < maxRetryAttempts {
|
||||
elapsed := time.Since(retryStart)
|
||||
if elapsed >= maxRetryElapsed {
|
||||
break
|
||||
}
|
||||
|
||||
delay := retryBackoffDelay(attempt)
|
||||
remaining := maxRetryElapsed - elapsed
|
||||
if delay > remaining {
|
||||
delay = remaining
|
||||
}
|
||||
if delay <= 0 {
|
||||
break
|
||||
}
|
||||
|
||||
log.Printf("Account %d: upstream error %d, retry %d/%d after %v (elapsed=%v/%v)",
|
||||
account.ID, resp.StatusCode, attempt, maxRetryAttempts, delay, elapsed, maxRetryElapsed)
|
||||
_ = resp.Body.Close()
|
||||
time.Sleep(retryDelay)
|
||||
if err := sleepWithContext(ctx, delay); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
continue
|
||||
}
|
||||
// 最后一次尝试也失败,跳出循环处理重试耗尽
|
||||
@@ -1146,6 +1264,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
}
|
||||
break
|
||||
}
|
||||
if resp == nil || resp.Body == nil {
|
||||
return nil, errors.New("upstream request failed: empty response")
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
// 处理重试耗尽的情况
|
||||
@@ -1229,7 +1350,13 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
|
||||
targetURL := claudeAPIURL
|
||||
if account.Type == AccountTypeAPIKey {
|
||||
baseURL := account.GetBaseURL()
|
||||
targetURL = baseURL + "/v1/messages"
|
||||
if baseURL != "" {
|
||||
validatedURL, err := s.validateUpstreamBaseURL(baseURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
targetURL = validatedURL + "/v1/messages"
|
||||
}
|
||||
}
|
||||
|
||||
// OAuth账号:应用统一指纹
|
||||
@@ -1537,10 +1664,10 @@ func (s *GatewayService) handleRetryExhaustedSideEffects(ctx context.Context, re
|
||||
// OAuth/Setup Token 账号的 403:标记账号异常
|
||||
if account.IsOAuth() && statusCode == 403 {
|
||||
s.rateLimitService.HandleUpstreamError(ctx, account, statusCode, resp.Header, body)
|
||||
log.Printf("Account %d: marked as error after %d retries for status %d", account.ID, maxRetries, statusCode)
|
||||
log.Printf("Account %d: marked as error after %d retries for status %d", account.ID, maxRetryAttempts, statusCode)
|
||||
} else {
|
||||
// API Key 未配置错误码:不标记账号状态
|
||||
log.Printf("Account %d: upstream error %d after %d retries (not marking account)", account.ID, statusCode, maxRetries)
|
||||
log.Printf("Account %d: upstream error %d after %d retries (not marking account)", account.ID, statusCode, maxRetryAttempts)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1577,6 +1704,10 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
||||
// 更新5h窗口状态
|
||||
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
|
||||
|
||||
if s.cfg != nil {
|
||||
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
|
||||
}
|
||||
|
||||
// 设置SSE响应头
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
@@ -1598,51 +1729,133 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
||||
var firstTokenMs *int
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
// 设置更大的buffer以处理长行
|
||||
scanner.Buffer(make([]byte, 64*1024), 1024*1024)
|
||||
maxLineSize := defaultMaxLineSize
|
||||
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
|
||||
maxLineSize = s.cfg.Gateway.MaxLineSize
|
||||
}
|
||||
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
|
||||
|
||||
type scanEvent struct {
|
||||
line string
|
||||
err error
|
||||
}
|
||||
// 独立 goroutine 读取上游,避免读取阻塞导致超时/keepalive无法处理
|
||||
events := make(chan scanEvent, 16)
|
||||
done := make(chan struct{})
|
||||
sendEvent := func(ev scanEvent) bool {
|
||||
select {
|
||||
case events <- ev:
|
||||
return true
|
||||
case <-done:
|
||||
return false
|
||||
}
|
||||
}
|
||||
var lastReadAt int64
|
||||
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
|
||||
go func() {
|
||||
defer close(events)
|
||||
for scanner.Scan() {
|
||||
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
|
||||
if !sendEvent(scanEvent{line: scanner.Text()}) {
|
||||
return
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
_ = sendEvent(scanEvent{err: err})
|
||||
}
|
||||
}()
|
||||
defer close(done)
|
||||
|
||||
streamInterval := time.Duration(0)
|
||||
if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 {
|
||||
streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
|
||||
}
|
||||
// 仅监控上游数据间隔超时,避免下游写入阻塞导致误判
|
||||
var intervalTicker *time.Ticker
|
||||
if streamInterval > 0 {
|
||||
intervalTicker = time.NewTicker(streamInterval)
|
||||
defer intervalTicker.Stop()
|
||||
}
|
||||
var intervalCh <-chan time.Time
|
||||
if intervalTicker != nil {
|
||||
intervalCh = intervalTicker.C
|
||||
}
|
||||
|
||||
// 仅发送一次错误事件,避免多次写入导致协议混乱(写失败时尽力通知客户端)
|
||||
errorEventSent := false
|
||||
sendErrorEvent := func(reason string) {
|
||||
if errorEventSent {
|
||||
return
|
||||
}
|
||||
errorEventSent = true
|
||||
_, _ = fmt.Fprintf(w, "event: error\ndata: {\"error\":\"%s\"}\n\n", reason)
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
needModelReplace := originalModel != mappedModel
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if line == "event: error" {
|
||||
return nil, errors.New("have error in stream")
|
||||
}
|
||||
|
||||
// Extract data from SSE line (supports both "data: " and "data:" formats)
|
||||
if sseDataRe.MatchString(line) {
|
||||
data := sseDataRe.ReplaceAllString(line, "")
|
||||
|
||||
// 如果有模型映射,替换响应中的model字段
|
||||
if needModelReplace {
|
||||
line = s.replaceModelInSSELine(line, mappedModel, originalModel)
|
||||
for {
|
||||
select {
|
||||
case ev, ok := <-events:
|
||||
if !ok {
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
|
||||
}
|
||||
if ev.err != nil {
|
||||
if errors.Is(ev.err, bufio.ErrTooLong) {
|
||||
log.Printf("SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err)
|
||||
sendErrorEvent("response_too_large")
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err
|
||||
}
|
||||
sendErrorEvent("stream_read_error")
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", ev.err)
|
||||
}
|
||||
line := ev.line
|
||||
if line == "event: error" {
|
||||
return nil, errors.New("have error in stream")
|
||||
}
|
||||
|
||||
// 转发行
|
||||
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
|
||||
}
|
||||
flusher.Flush()
|
||||
// Extract data from SSE line (supports both "data: " and "data:" formats)
|
||||
if sseDataRe.MatchString(line) {
|
||||
data := sseDataRe.ReplaceAllString(line, "")
|
||||
|
||||
// 记录首字时间:第一个有效的 content_block_delta 或 message_start
|
||||
if firstTokenMs == nil && data != "" && data != "[DONE]" {
|
||||
ms := int(time.Since(startTime).Milliseconds())
|
||||
firstTokenMs = &ms
|
||||
// 如果有模型映射,替换响应中的model字段
|
||||
if needModelReplace {
|
||||
line = s.replaceModelInSSELine(line, mappedModel, originalModel)
|
||||
}
|
||||
|
||||
// 转发行
|
||||
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
|
||||
sendErrorEvent("write_failed")
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
|
||||
}
|
||||
flusher.Flush()
|
||||
|
||||
// 记录首字时间:第一个有效的 content_block_delta 或 message_start
|
||||
if firstTokenMs == nil && data != "" && data != "[DONE]" {
|
||||
ms := int(time.Since(startTime).Milliseconds())
|
||||
firstTokenMs = &ms
|
||||
}
|
||||
s.parseSSEUsage(data, usage)
|
||||
} else {
|
||||
// 非 data 行直接转发
|
||||
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
|
||||
sendErrorEvent("write_failed")
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
s.parseSSEUsage(data, usage)
|
||||
} else {
|
||||
// 非 data 行直接转发
|
||||
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
|
||||
|
||||
case <-intervalCh:
|
||||
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
|
||||
if time.Since(lastRead) < streamInterval {
|
||||
continue
|
||||
}
|
||||
flusher.Flush()
|
||||
log.Printf("Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval)
|
||||
sendErrorEvent("stream_timeout")
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", err)
|
||||
}
|
||||
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
|
||||
}
|
||||
|
||||
// replaceModelInSSELine 替换SSE数据行中的model字段
|
||||
@@ -1747,15 +1960,17 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
|
||||
body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
|
||||
}
|
||||
|
||||
// 透传响应头
|
||||
for key, values := range resp.Header {
|
||||
for _, value := range values {
|
||||
c.Header(key, value)
|
||||
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
|
||||
|
||||
contentType := "application/json"
|
||||
if s.cfg != nil && !s.cfg.Security.ResponseHeaders.Enabled {
|
||||
if upstreamType := resp.Header.Get("Content-Type"); upstreamType != "" {
|
||||
contentType = upstreamType
|
||||
}
|
||||
}
|
||||
|
||||
// 写入响应
|
||||
c.Data(resp.StatusCode, "application/json", body)
|
||||
c.Data(resp.StatusCode, contentType, body)
|
||||
|
||||
return &response.Usage, nil
|
||||
}
|
||||
@@ -1989,7 +2204,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
||||
if resp.StatusCode == 400 && s.isThinkingBlockSignatureError(respBody) {
|
||||
log.Printf("Account %d: detected thinking block signature error on count_tokens, retrying with filtered thinking blocks", account.ID)
|
||||
|
||||
filteredBody := FilterThinkingBlocks(body)
|
||||
filteredBody := FilterThinkingBlocksForRetry(body)
|
||||
retryReq, buildErr := s.buildCountTokensRequest(ctx, c, account, filteredBody, token, tokenType, reqModel)
|
||||
if buildErr == nil {
|
||||
retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency)
|
||||
@@ -2045,7 +2260,13 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
|
||||
targetURL := claudeAPICountTokensURL
|
||||
if account.Type == AccountTypeAPIKey {
|
||||
baseURL := account.GetBaseURL()
|
||||
targetURL = baseURL + "/v1/messages/count_tokens"
|
||||
if baseURL != "" {
|
||||
validatedURL, err := s.validateUpstreamBaseURL(baseURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
targetURL = validatedURL + "/v1/messages/count_tokens"
|
||||
}
|
||||
}
|
||||
|
||||
// OAuth 账号:应用统一指纹和重写 userID
|
||||
@@ -2125,6 +2346,25 @@ func (s *GatewayService) countTokensError(c *gin.Context, status int, errType, m
|
||||
})
|
||||
}
|
||||
|
||||
func (s *GatewayService) validateUpstreamBaseURL(raw string) (string, error) {
|
||||
if s.cfg != nil && !s.cfg.Security.URLAllowlist.Enabled {
|
||||
normalized, err := urlvalidator.ValidateURLFormat(raw, s.cfg.Security.URLAllowlist.AllowInsecureHTTP)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid base_url: %w", err)
|
||||
}
|
||||
return normalized, nil
|
||||
}
|
||||
normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{
|
||||
AllowedHosts: s.cfg.Security.URLAllowlist.UpstreamHosts,
|
||||
RequireAllowlist: true,
|
||||
AllowPrivate: s.cfg.Security.URLAllowlist.AllowPrivateHosts,
|
||||
})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid base_url: %w", err)
|
||||
}
|
||||
return normalized, nil
|
||||
}
|
||||
|
||||
// GetAvailableModels returns the list of models available for a group
|
||||
// It aggregates model_mapping keys from all schedulable accounts in the group
|
||||
func (s *GatewayService) GetAvailableModels(ctx context.Context, groupID *int64, platform string) []string {
|
||||
|
||||
@@ -18,9 +18,12 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -41,6 +44,7 @@ type GeminiMessagesCompatService struct {
|
||||
rateLimitService *RateLimitService
|
||||
httpUpstream HTTPUpstream
|
||||
antigravityGatewayService *AntigravityGatewayService
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
func NewGeminiMessagesCompatService(
|
||||
@@ -51,6 +55,7 @@ func NewGeminiMessagesCompatService(
|
||||
rateLimitService *RateLimitService,
|
||||
httpUpstream HTTPUpstream,
|
||||
antigravityGatewayService *AntigravityGatewayService,
|
||||
cfg *config.Config,
|
||||
) *GeminiMessagesCompatService {
|
||||
return &GeminiMessagesCompatService{
|
||||
accountRepo: accountRepo,
|
||||
@@ -60,6 +65,7 @@ func NewGeminiMessagesCompatService(
|
||||
rateLimitService: rateLimitService,
|
||||
httpUpstream: httpUpstream,
|
||||
antigravityGatewayService: antigravityGatewayService,
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -230,6 +236,25 @@ func (s *GeminiMessagesCompatService) GetAntigravityGatewayService() *Antigravit
|
||||
return s.antigravityGatewayService
|
||||
}
|
||||
|
||||
func (s *GeminiMessagesCompatService) validateUpstreamBaseURL(raw string) (string, error) {
|
||||
if s.cfg != nil && !s.cfg.Security.URLAllowlist.Enabled {
|
||||
normalized, err := urlvalidator.ValidateURLFormat(raw, s.cfg.Security.URLAllowlist.AllowInsecureHTTP)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid base_url: %w", err)
|
||||
}
|
||||
return normalized, nil
|
||||
}
|
||||
normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{
|
||||
AllowedHosts: s.cfg.Security.URLAllowlist.UpstreamHosts,
|
||||
RequireAllowlist: true,
|
||||
AllowPrivate: s.cfg.Security.URLAllowlist.AllowPrivateHosts,
|
||||
})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid base_url: %w", err)
|
||||
}
|
||||
return normalized, nil
|
||||
}
|
||||
|
||||
// HasAntigravityAccounts 检查是否有可用的 antigravity 账户
|
||||
func (s *GeminiMessagesCompatService) HasAntigravityAccounts(ctx context.Context, groupID *int64) (bool, error) {
|
||||
var accounts []Account
|
||||
@@ -359,6 +384,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
||||
if err != nil {
|
||||
return nil, s.writeClaudeError(c, http.StatusBadRequest, "invalid_request_error", err.Error())
|
||||
}
|
||||
originalClaudeBody := body
|
||||
|
||||
proxyURL := ""
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
@@ -381,16 +407,20 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
||||
return nil, "", errors.New("gemini api_key not configured")
|
||||
}
|
||||
|
||||
baseURL := strings.TrimRight(account.GetCredential("base_url"), "/")
|
||||
baseURL := strings.TrimSpace(account.GetCredential("base_url"))
|
||||
if baseURL == "" {
|
||||
baseURL = geminicli.AIStudioBaseURL
|
||||
}
|
||||
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
action := "generateContent"
|
||||
if req.Stream {
|
||||
action = "streamGenerateContent"
|
||||
}
|
||||
fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(baseURL, "/"), mappedModel, action)
|
||||
fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(normalizedBaseURL, "/"), mappedModel, action)
|
||||
if req.Stream {
|
||||
fullURL += "?alt=sse"
|
||||
}
|
||||
@@ -427,7 +457,11 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
||||
// 2. Without project_id -> AI Studio API (direct OAuth, like API key but with Bearer token)
|
||||
if projectID != "" {
|
||||
// Mode 1: Code Assist API
|
||||
fullURL := fmt.Sprintf("%s/v1internal:%s", geminicli.GeminiCliBaseURL, action)
|
||||
baseURL, err := s.validateUpstreamBaseURL(geminicli.GeminiCliBaseURL)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
fullURL := fmt.Sprintf("%s/v1internal:%s", strings.TrimRight(baseURL, "/"), action)
|
||||
if useUpstreamStream {
|
||||
fullURL += "?alt=sse"
|
||||
}
|
||||
@@ -453,12 +487,16 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
||||
return upstreamReq, "x-request-id", nil
|
||||
} else {
|
||||
// Mode 2: AI Studio API with OAuth (like API key mode, but using Bearer token)
|
||||
baseURL := strings.TrimRight(account.GetCredential("base_url"), "/")
|
||||
baseURL := strings.TrimSpace(account.GetCredential("base_url"))
|
||||
if baseURL == "" {
|
||||
baseURL = geminicli.AIStudioBaseURL
|
||||
}
|
||||
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", baseURL, mappedModel, action)
|
||||
fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(normalizedBaseURL, "/"), mappedModel, action)
|
||||
if useUpstreamStream {
|
||||
fullURL += "?alt=sse"
|
||||
}
|
||||
@@ -479,6 +517,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
||||
}
|
||||
|
||||
var resp *http.Response
|
||||
signatureRetryStage := 0
|
||||
for attempt := 1; attempt <= geminiMaxRetries; attempt++ {
|
||||
upstreamReq, idHeader, err := buildReq(ctx)
|
||||
if err != nil {
|
||||
@@ -503,6 +542,46 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
||||
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries: "+sanitizeUpstreamErrorMessage(err.Error()))
|
||||
}
|
||||
|
||||
// Special-case: signature/thought_signature validation errors are not transient, but may be fixed by
|
||||
// downgrading Claude thinking/tool history to plain text (conservative two-stage retry).
|
||||
if resp.StatusCode == http.StatusBadRequest && signatureRetryStage < 2 {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
_ = resp.Body.Close()
|
||||
|
||||
if isGeminiSignatureRelatedError(respBody) {
|
||||
var strippedClaudeBody []byte
|
||||
stageName := ""
|
||||
switch signatureRetryStage {
|
||||
case 0:
|
||||
// Stage 1: disable thinking + thinking->text
|
||||
strippedClaudeBody = FilterThinkingBlocksForRetry(originalClaudeBody)
|
||||
stageName = "thinking-only"
|
||||
signatureRetryStage = 1
|
||||
default:
|
||||
// Stage 2: additionally downgrade tool_use/tool_result blocks to text
|
||||
strippedClaudeBody = FilterSignatureSensitiveBlocksForRetry(originalClaudeBody)
|
||||
stageName = "thinking+tools"
|
||||
signatureRetryStage = 2
|
||||
}
|
||||
retryGeminiReq, txErr := convertClaudeMessagesToGeminiGenerateContent(strippedClaudeBody)
|
||||
if txErr == nil {
|
||||
log.Printf("Gemini account %d: detected signature-related 400, retrying with downgraded Claude blocks (%s)", account.ID, stageName)
|
||||
geminiReq = retryGeminiReq
|
||||
// Consume one retry budget attempt and continue with the updated request payload.
|
||||
sleepGeminiBackoff(1)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Restore body for downstream error handling.
|
||||
resp = &http.Response{
|
||||
StatusCode: http.StatusBadRequest,
|
||||
Header: resp.Header.Clone(),
|
||||
Body: io.NopCloser(bytes.NewReader(respBody)),
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
if resp.StatusCode >= 400 && s.shouldRetryGeminiUpstreamError(account, resp.StatusCode) {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
_ = resp.Body.Close()
|
||||
@@ -600,6 +679,14 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
||||
}, nil
|
||||
}
|
||||
|
||||
func isGeminiSignatureRelatedError(respBody []byte) bool {
|
||||
msg := strings.ToLower(strings.TrimSpace(extractAntigravityErrorMessage(respBody)))
|
||||
if msg == "" {
|
||||
msg = strings.ToLower(string(respBody))
|
||||
}
|
||||
return strings.Contains(msg, "thought_signature") || strings.Contains(msg, "signature")
|
||||
}
|
||||
|
||||
func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.Context, account *Account, originalModel string, action string, stream bool, body []byte) (*ForwardResult, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
@@ -650,12 +737,16 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
||||
return nil, "", errors.New("gemini api_key not configured")
|
||||
}
|
||||
|
||||
baseURL := strings.TrimRight(account.GetCredential("base_url"), "/")
|
||||
baseURL := strings.TrimSpace(account.GetCredential("base_url"))
|
||||
if baseURL == "" {
|
||||
baseURL = geminicli.AIStudioBaseURL
|
||||
}
|
||||
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(baseURL, "/"), mappedModel, upstreamAction)
|
||||
fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(normalizedBaseURL, "/"), mappedModel, upstreamAction)
|
||||
if useUpstreamStream {
|
||||
fullURL += "?alt=sse"
|
||||
}
|
||||
@@ -687,7 +778,11 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
||||
// 2. Without project_id -> AI Studio API (direct OAuth, like API key but with Bearer token)
|
||||
if projectID != "" && !forceAIStudio {
|
||||
// Mode 1: Code Assist API
|
||||
fullURL := fmt.Sprintf("%s/v1internal:%s", geminicli.GeminiCliBaseURL, upstreamAction)
|
||||
baseURL, err := s.validateUpstreamBaseURL(geminicli.GeminiCliBaseURL)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
fullURL := fmt.Sprintf("%s/v1internal:%s", strings.TrimRight(baseURL, "/"), upstreamAction)
|
||||
if useUpstreamStream {
|
||||
fullURL += "?alt=sse"
|
||||
}
|
||||
@@ -713,12 +808,16 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
||||
return upstreamReq, "x-request-id", nil
|
||||
} else {
|
||||
// Mode 2: AI Studio API with OAuth (like API key mode, but using Bearer token)
|
||||
baseURL := strings.TrimRight(account.GetCredential("base_url"), "/")
|
||||
baseURL := strings.TrimSpace(account.GetCredential("base_url"))
|
||||
if baseURL == "" {
|
||||
baseURL = geminicli.AIStudioBaseURL
|
||||
}
|
||||
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", baseURL, mappedModel, upstreamAction)
|
||||
fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(normalizedBaseURL, "/"), mappedModel, upstreamAction)
|
||||
if useUpstreamStream {
|
||||
fullURL += "?alt=sse"
|
||||
}
|
||||
@@ -1652,6 +1751,8 @@ func (s *GeminiMessagesCompatService) handleNativeNonStreamingResponse(c *gin.Co
|
||||
_ = json.Unmarshal(respBody, &parsed)
|
||||
}
|
||||
|
||||
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
|
||||
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
if contentType == "" {
|
||||
contentType = "application/json"
|
||||
@@ -1676,6 +1777,10 @@ func (s *GeminiMessagesCompatService) handleNativeStreamingResponse(c *gin.Conte
|
||||
}
|
||||
log.Printf("[GeminiAPI] ====================================================")
|
||||
|
||||
if s.cfg != nil {
|
||||
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
|
||||
}
|
||||
|
||||
c.Status(resp.StatusCode)
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
@@ -1773,11 +1878,15 @@ func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, ac
|
||||
return nil, errors.New("invalid path")
|
||||
}
|
||||
|
||||
baseURL := strings.TrimRight(account.GetCredential("base_url"), "/")
|
||||
baseURL := strings.TrimSpace(account.GetCredential("base_url"))
|
||||
if baseURL == "" {
|
||||
baseURL = geminicli.AIStudioBaseURL
|
||||
}
|
||||
fullURL := strings.TrimRight(baseURL, "/") + path
|
||||
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fullURL := strings.TrimRight(normalizedBaseURL, "/") + path
|
||||
|
||||
var proxyURL string
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
@@ -1816,9 +1925,14 @@ func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, ac
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 8<<20))
|
||||
wwwAuthenticate := resp.Header.Get("Www-Authenticate")
|
||||
filteredHeaders := responseheaders.FilterHeaders(resp.Header, s.cfg.Security.ResponseHeaders)
|
||||
if wwwAuthenticate != "" {
|
||||
filteredHeaders.Set("Www-Authenticate", wwwAuthenticate)
|
||||
}
|
||||
return &UpstreamHTTPResult{
|
||||
StatusCode: resp.StatusCode,
|
||||
Headers: resp.Header.Clone(),
|
||||
Headers: filteredHeaders,
|
||||
Body: body,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -1000,8 +1000,9 @@ func fetchProjectIDFromResourceManager(ctx context.Context, accessToken, proxyUR
|
||||
req.Header.Set("User-Agent", geminicli.GeminiCLIUserAgent)
|
||||
|
||||
client, err := httpclient.GetClient(httpclient.Options{
|
||||
ProxyURL: strings.TrimSpace(proxyURL),
|
||||
Timeout: 30 * time.Second,
|
||||
ProxyURL: strings.TrimSpace(proxyURL),
|
||||
Timeout: 30 * time.Second,
|
||||
ValidateResolvedIP: true,
|
||||
})
|
||||
if err != nil {
|
||||
client = &http.Client{Timeout: 30 * time.Second}
|
||||
|
||||
@@ -16,9 +16,12 @@ import (
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
@@ -630,10 +633,14 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
|
||||
case AccountTypeAPIKey:
|
||||
// API Key accounts use Platform API or custom base URL
|
||||
baseURL := account.GetOpenAIBaseURL()
|
||||
if baseURL != "" {
|
||||
targetURL = baseURL + "/responses"
|
||||
} else {
|
||||
if baseURL == "" {
|
||||
targetURL = openaiPlatformAPIURL
|
||||
} else {
|
||||
validatedURL, err := s.validateUpstreamBaseURL(baseURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
targetURL = validatedURL + "/responses"
|
||||
}
|
||||
default:
|
||||
targetURL = openaiPlatformAPIURL
|
||||
@@ -755,6 +762,10 @@ type openaiStreamingResult struct {
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string) (*openaiStreamingResult, error) {
|
||||
if s.cfg != nil {
|
||||
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
|
||||
}
|
||||
|
||||
// Set SSE response headers
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
@@ -775,48 +786,158 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
||||
usage := &OpenAIUsage{}
|
||||
var firstTokenMs *int
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Buffer(make([]byte, 64*1024), 1024*1024)
|
||||
maxLineSize := defaultMaxLineSize
|
||||
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
|
||||
maxLineSize = s.cfg.Gateway.MaxLineSize
|
||||
}
|
||||
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
|
||||
|
||||
type scanEvent struct {
|
||||
line string
|
||||
err error
|
||||
}
|
||||
// 独立 goroutine 读取上游,避免读取阻塞影响 keepalive/超时处理
|
||||
events := make(chan scanEvent, 16)
|
||||
done := make(chan struct{})
|
||||
sendEvent := func(ev scanEvent) bool {
|
||||
select {
|
||||
case events <- ev:
|
||||
return true
|
||||
case <-done:
|
||||
return false
|
||||
}
|
||||
}
|
||||
var lastReadAt int64
|
||||
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
|
||||
go func() {
|
||||
defer close(events)
|
||||
for scanner.Scan() {
|
||||
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
|
||||
if !sendEvent(scanEvent{line: scanner.Text()}) {
|
||||
return
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
_ = sendEvent(scanEvent{err: err})
|
||||
}
|
||||
}()
|
||||
defer close(done)
|
||||
|
||||
streamInterval := time.Duration(0)
|
||||
if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 {
|
||||
streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
|
||||
}
|
||||
// 仅监控上游数据间隔超时,不被下游写入阻塞影响
|
||||
var intervalTicker *time.Ticker
|
||||
if streamInterval > 0 {
|
||||
intervalTicker = time.NewTicker(streamInterval)
|
||||
defer intervalTicker.Stop()
|
||||
}
|
||||
var intervalCh <-chan time.Time
|
||||
if intervalTicker != nil {
|
||||
intervalCh = intervalTicker.C
|
||||
}
|
||||
|
||||
keepaliveInterval := time.Duration(0)
|
||||
if s.cfg != nil && s.cfg.Gateway.StreamKeepaliveInterval > 0 {
|
||||
keepaliveInterval = time.Duration(s.cfg.Gateway.StreamKeepaliveInterval) * time.Second
|
||||
}
|
||||
// 下游 keepalive 仅用于防止代理空闲断开
|
||||
var keepaliveTicker *time.Ticker
|
||||
if keepaliveInterval > 0 {
|
||||
keepaliveTicker = time.NewTicker(keepaliveInterval)
|
||||
defer keepaliveTicker.Stop()
|
||||
}
|
||||
var keepaliveCh <-chan time.Time
|
||||
if keepaliveTicker != nil {
|
||||
keepaliveCh = keepaliveTicker.C
|
||||
}
|
||||
// 记录上次收到上游数据的时间,用于控制 keepalive 发送频率
|
||||
lastDataAt := time.Now()
|
||||
|
||||
// 仅发送一次错误事件,避免多次写入导致协议混乱(写失败时尽力通知客户端)
|
||||
errorEventSent := false
|
||||
sendErrorEvent := func(reason string) {
|
||||
if errorEventSent {
|
||||
return
|
||||
}
|
||||
errorEventSent = true
|
||||
_, _ = fmt.Fprintf(w, "event: error\ndata: {\"error\":\"%s\"}\n\n", reason)
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
needModelReplace := originalModel != mappedModel
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
|
||||
// Extract data from SSE line (supports both "data: " and "data:" formats)
|
||||
if openaiSSEDataRe.MatchString(line) {
|
||||
data := openaiSSEDataRe.ReplaceAllString(line, "")
|
||||
|
||||
// Replace model in response if needed
|
||||
if needModelReplace {
|
||||
line = s.replaceModelInSSELine(line, mappedModel, originalModel)
|
||||
for {
|
||||
select {
|
||||
case ev, ok := <-events:
|
||||
if !ok {
|
||||
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
|
||||
}
|
||||
if ev.err != nil {
|
||||
if errors.Is(ev.err, bufio.ErrTooLong) {
|
||||
log.Printf("SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err)
|
||||
sendErrorEvent("response_too_large")
|
||||
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err
|
||||
}
|
||||
sendErrorEvent("stream_read_error")
|
||||
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", ev.err)
|
||||
}
|
||||
|
||||
// Forward line
|
||||
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
|
||||
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
|
||||
}
|
||||
flusher.Flush()
|
||||
line := ev.line
|
||||
lastDataAt = time.Now()
|
||||
|
||||
// Record first token time
|
||||
if firstTokenMs == nil && data != "" && data != "[DONE]" {
|
||||
ms := int(time.Since(startTime).Milliseconds())
|
||||
firstTokenMs = &ms
|
||||
// Extract data from SSE line (supports both "data: " and "data:" formats)
|
||||
if openaiSSEDataRe.MatchString(line) {
|
||||
data := openaiSSEDataRe.ReplaceAllString(line, "")
|
||||
|
||||
// Replace model in response if needed
|
||||
if needModelReplace {
|
||||
line = s.replaceModelInSSELine(line, mappedModel, originalModel)
|
||||
}
|
||||
|
||||
// Forward line
|
||||
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
|
||||
sendErrorEvent("write_failed")
|
||||
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
|
||||
}
|
||||
flusher.Flush()
|
||||
|
||||
// Record first token time
|
||||
if firstTokenMs == nil && data != "" && data != "[DONE]" {
|
||||
ms := int(time.Since(startTime).Milliseconds())
|
||||
firstTokenMs = &ms
|
||||
}
|
||||
s.parseSSEUsage(data, usage)
|
||||
} else {
|
||||
// Forward non-data lines as-is
|
||||
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
|
||||
sendErrorEvent("write_failed")
|
||||
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
s.parseSSEUsage(data, usage)
|
||||
} else {
|
||||
// Forward non-data lines as-is
|
||||
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
|
||||
|
||||
case <-intervalCh:
|
||||
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
|
||||
if time.Since(lastRead) < streamInterval {
|
||||
continue
|
||||
}
|
||||
log.Printf("Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval)
|
||||
sendErrorEvent("stream_timeout")
|
||||
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
|
||||
|
||||
case <-keepaliveCh:
|
||||
if time.Since(lastDataAt) < keepaliveInterval {
|
||||
continue
|
||||
}
|
||||
if _, err := fmt.Fprint(w, ":\n\n"); err != nil {
|
||||
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", err)
|
||||
}
|
||||
|
||||
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) replaceModelInSSELine(line, fromModel, toModel string) string {
|
||||
@@ -911,18 +1032,39 @@ func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, r
|
||||
body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
|
||||
}
|
||||
|
||||
// Pass through headers
|
||||
for key, values := range resp.Header {
|
||||
for _, value := range values {
|
||||
c.Header(key, value)
|
||||
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
|
||||
|
||||
contentType := "application/json"
|
||||
if s.cfg != nil && !s.cfg.Security.ResponseHeaders.Enabled {
|
||||
if upstreamType := resp.Header.Get("Content-Type"); upstreamType != "" {
|
||||
contentType = upstreamType
|
||||
}
|
||||
}
|
||||
|
||||
c.Data(resp.StatusCode, "application/json", body)
|
||||
c.Data(resp.StatusCode, contentType, body)
|
||||
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) validateUpstreamBaseURL(raw string) (string, error) {
|
||||
if s.cfg != nil && !s.cfg.Security.URLAllowlist.Enabled {
|
||||
normalized, err := urlvalidator.ValidateURLFormat(raw, s.cfg.Security.URLAllowlist.AllowInsecureHTTP)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid base_url: %w", err)
|
||||
}
|
||||
return normalized, nil
|
||||
}
|
||||
normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{
|
||||
AllowedHosts: s.cfg.Security.URLAllowlist.UpstreamHosts,
|
||||
RequireAllowlist: true,
|
||||
AllowPrivate: s.cfg.Security.URLAllowlist.AllowPrivateHosts,
|
||||
})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid base_url: %w", err)
|
||||
}
|
||||
return normalized, nil
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel, toModel string) []byte {
|
||||
var resp map[string]any
|
||||
if err := json.Unmarshal(body, &resp); err != nil {
|
||||
|
||||
286
backend/internal/service/openai_gateway_service_test.go
Normal file
286
backend/internal/service/openai_gateway_service_test.go
Normal file
@@ -0,0 +1,286 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func TestOpenAIStreamingTimeout(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
cfg := &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
StreamDataIntervalTimeout: 1,
|
||||
StreamKeepaliveInterval: 0,
|
||||
MaxLineSize: defaultMaxLineSize,
|
||||
},
|
||||
}
|
||||
svc := &OpenAIGatewayService{cfg: cfg}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: pr,
|
||||
Header: http.Header{},
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, start, "model", "model")
|
||||
_ = pw.Close()
|
||||
_ = pr.Close()
|
||||
|
||||
if err == nil || !strings.Contains(err.Error(), "stream data interval timeout") {
|
||||
t.Fatalf("expected stream timeout error, got %v", err)
|
||||
}
|
||||
if !strings.Contains(rec.Body.String(), "stream_timeout") {
|
||||
t.Fatalf("expected stream_timeout SSE error, got %q", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIStreamingTooLong(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
cfg := &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
StreamDataIntervalTimeout: 0,
|
||||
StreamKeepaliveInterval: 0,
|
||||
MaxLineSize: 64 * 1024,
|
||||
},
|
||||
}
|
||||
svc := &OpenAIGatewayService{cfg: cfg}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: pr,
|
||||
Header: http.Header{},
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer func() { _ = pw.Close() }()
|
||||
// 写入超过 MaxLineSize 的单行数据,触发 ErrTooLong
|
||||
payload := "data: " + strings.Repeat("a", 128*1024) + "\n"
|
||||
_, _ = pw.Write([]byte(payload))
|
||||
}()
|
||||
|
||||
_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 2}, time.Now(), "model", "model")
|
||||
_ = pr.Close()
|
||||
|
||||
if !errors.Is(err, bufio.ErrTooLong) {
|
||||
t.Fatalf("expected ErrTooLong, got %v", err)
|
||||
}
|
||||
if !strings.Contains(rec.Body.String(), "response_too_large") {
|
||||
t.Fatalf("expected response_too_large SSE error, got %q", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAINonStreamingContentTypePassThrough(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
cfg := &config.Config{
|
||||
Security: config.SecurityConfig{
|
||||
ResponseHeaders: config.ResponseHeaderConfig{Enabled: false},
|
||||
},
|
||||
}
|
||||
svc := &OpenAIGatewayService{cfg: cfg}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
|
||||
body := []byte(`{"usage":{"input_tokens":1,"output_tokens":2,"input_tokens_details":{"cached_tokens":0}}}`)
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(bytes.NewReader(body)),
|
||||
Header: http.Header{"Content-Type": []string{"application/vnd.test+json"}},
|
||||
}
|
||||
|
||||
_, err := svc.handleNonStreamingResponse(c.Request.Context(), resp, c, &Account{}, "model", "model")
|
||||
if err != nil {
|
||||
t.Fatalf("handleNonStreamingResponse error: %v", err)
|
||||
}
|
||||
|
||||
if !strings.Contains(rec.Header().Get("Content-Type"), "application/vnd.test+json") {
|
||||
t.Fatalf("expected Content-Type passthrough, got %q", rec.Header().Get("Content-Type"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAINonStreamingContentTypeDefault(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
cfg := &config.Config{
|
||||
Security: config.SecurityConfig{
|
||||
ResponseHeaders: config.ResponseHeaderConfig{Enabled: false},
|
||||
},
|
||||
}
|
||||
svc := &OpenAIGatewayService{cfg: cfg}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
|
||||
body := []byte(`{"usage":{"input_tokens":1,"output_tokens":2,"input_tokens_details":{"cached_tokens":0}}}`)
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(bytes.NewReader(body)),
|
||||
Header: http.Header{},
|
||||
}
|
||||
|
||||
_, err := svc.handleNonStreamingResponse(c.Request.Context(), resp, c, &Account{}, "model", "model")
|
||||
if err != nil {
|
||||
t.Fatalf("handleNonStreamingResponse error: %v", err)
|
||||
}
|
||||
|
||||
if !strings.Contains(rec.Header().Get("Content-Type"), "application/json") {
|
||||
t.Fatalf("expected default Content-Type, got %q", rec.Header().Get("Content-Type"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIStreamingHeadersOverride(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
cfg := &config.Config{
|
||||
Security: config.SecurityConfig{
|
||||
ResponseHeaders: config.ResponseHeaderConfig{Enabled: false},
|
||||
},
|
||||
Gateway: config.GatewayConfig{
|
||||
StreamDataIntervalTimeout: 0,
|
||||
StreamKeepaliveInterval: 0,
|
||||
MaxLineSize: defaultMaxLineSize,
|
||||
},
|
||||
}
|
||||
svc := &OpenAIGatewayService{cfg: cfg}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: pr,
|
||||
Header: http.Header{
|
||||
"Cache-Control": []string{"upstream"},
|
||||
"X-Request-Id": []string{"req-123"},
|
||||
"Content-Type": []string{"application/custom"},
|
||||
},
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer func() { _ = pw.Close() }()
|
||||
_, _ = pw.Write([]byte("data: {}\n\n"))
|
||||
}()
|
||||
|
||||
_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model")
|
||||
_ = pr.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("handleStreamingResponse error: %v", err)
|
||||
}
|
||||
|
||||
if rec.Header().Get("Cache-Control") != "no-cache" {
|
||||
t.Fatalf("expected Cache-Control override, got %q", rec.Header().Get("Cache-Control"))
|
||||
}
|
||||
if rec.Header().Get("Content-Type") != "text/event-stream" {
|
||||
t.Fatalf("expected Content-Type override, got %q", rec.Header().Get("Content-Type"))
|
||||
}
|
||||
if rec.Header().Get("X-Request-Id") != "req-123" {
|
||||
t.Fatalf("expected X-Request-Id passthrough, got %q", rec.Header().Get("X-Request-Id"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIInvalidBaseURLWhenAllowlistDisabled(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
cfg := &config.Config{
|
||||
Security: config.SecurityConfig{
|
||||
URLAllowlist: config.URLAllowlistConfig{Enabled: false},
|
||||
},
|
||||
}
|
||||
svc := &OpenAIGatewayService{cfg: cfg}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
|
||||
account := &Account{
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Credentials: map[string]any{"base_url": "://invalid-url"},
|
||||
}
|
||||
|
||||
_, err := svc.buildUpstreamRequest(c.Request.Context(), c, account, []byte("{}"), "token", false)
|
||||
if err == nil {
|
||||
t.Fatalf("expected error for invalid base_url when allowlist disabled")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIValidateUpstreamBaseURLDisabledRequiresHTTPS(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Security: config.SecurityConfig{
|
||||
URLAllowlist: config.URLAllowlistConfig{Enabled: false},
|
||||
},
|
||||
}
|
||||
svc := &OpenAIGatewayService{cfg: cfg}
|
||||
|
||||
if _, err := svc.validateUpstreamBaseURL("http://not-https.example.com"); err == nil {
|
||||
t.Fatalf("expected http to be rejected when allow_insecure_http is false")
|
||||
}
|
||||
normalized, err := svc.validateUpstreamBaseURL("https://example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("expected https to be allowed when allowlist disabled, got %v", err)
|
||||
}
|
||||
if normalized != "https://example.com" {
|
||||
t.Fatalf("expected raw url passthrough, got %q", normalized)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIValidateUpstreamBaseURLDisabledAllowsHTTP(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Security: config.SecurityConfig{
|
||||
URLAllowlist: config.URLAllowlistConfig{
|
||||
Enabled: false,
|
||||
AllowInsecureHTTP: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
svc := &OpenAIGatewayService{cfg: cfg}
|
||||
|
||||
normalized, err := svc.validateUpstreamBaseURL("http://not-https.example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("expected http allowed when allow_insecure_http is true, got %v", err)
|
||||
}
|
||||
if normalized != "http://not-https.example.com" {
|
||||
t.Fatalf("expected raw url passthrough, got %q", normalized)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIValidateUpstreamBaseURLEnabledEnforcesAllowlist(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Security: config.SecurityConfig{
|
||||
URLAllowlist: config.URLAllowlistConfig{
|
||||
Enabled: true,
|
||||
UpstreamHosts: []string{"example.com"},
|
||||
},
|
||||
},
|
||||
}
|
||||
svc := &OpenAIGatewayService{cfg: cfg}
|
||||
|
||||
if _, err := svc.validateUpstreamBaseURL("https://example.com"); err != nil {
|
||||
t.Fatalf("expected allowlisted host to pass, got %v", err)
|
||||
}
|
||||
if _, err := svc.validateUpstreamBaseURL("https://evil.com"); err == nil {
|
||||
t.Fatalf("expected non-allowlisted host to fail")
|
||||
}
|
||||
}
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -213,16 +214,35 @@ func (s *PricingService) syncWithRemote() error {
|
||||
|
||||
// downloadPricingData 从远程下载价格数据
|
||||
func (s *PricingService) downloadPricingData() error {
|
||||
log.Printf("[Pricing] Downloading from %s", s.cfg.Pricing.RemoteURL)
|
||||
remoteURL, err := s.validatePricingURL(s.cfg.Pricing.RemoteURL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
log.Printf("[Pricing] Downloading from %s", remoteURL)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
body, err := s.remoteClient.FetchPricingJSON(ctx, s.cfg.Pricing.RemoteURL)
|
||||
var expectedHash string
|
||||
if strings.TrimSpace(s.cfg.Pricing.HashURL) != "" {
|
||||
expectedHash, err = s.fetchRemoteHash()
|
||||
if err != nil {
|
||||
return fmt.Errorf("fetch remote hash: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
body, err := s.remoteClient.FetchPricingJSON(ctx, remoteURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("download failed: %w", err)
|
||||
}
|
||||
|
||||
if expectedHash != "" {
|
||||
actualHash := sha256.Sum256(body)
|
||||
if !strings.EqualFold(expectedHash, hex.EncodeToString(actualHash[:])) {
|
||||
return fmt.Errorf("pricing hash mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
// 解析JSON数据(使用灵活的解析方式)
|
||||
data, err := s.parsePricingData(body)
|
||||
if err != nil {
|
||||
@@ -378,10 +398,38 @@ func (s *PricingService) useFallbackPricing() error {
|
||||
|
||||
// fetchRemoteHash 从远程获取哈希值
|
||||
func (s *PricingService) fetchRemoteHash() (string, error) {
|
||||
hashURL, err := s.validatePricingURL(s.cfg.Pricing.HashURL)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
return s.remoteClient.FetchHashText(ctx, s.cfg.Pricing.HashURL)
|
||||
hash, err := s.remoteClient.FetchHashText(ctx, hashURL)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return strings.TrimSpace(hash), nil
|
||||
}
|
||||
|
||||
func (s *PricingService) validatePricingURL(raw string) (string, error) {
|
||||
if s.cfg != nil && !s.cfg.Security.URLAllowlist.Enabled {
|
||||
normalized, err := urlvalidator.ValidateURLFormat(raw, s.cfg.Security.URLAllowlist.AllowInsecureHTTP)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid pricing url: %w", err)
|
||||
}
|
||||
return normalized, nil
|
||||
}
|
||||
normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{
|
||||
AllowedHosts: s.cfg.Security.URLAllowlist.PricingHosts,
|
||||
RequireAllowlist: true,
|
||||
AllowPrivate: s.cfg.Security.URLAllowlist.AllowPrivateHosts,
|
||||
})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid pricing url: %w", err)
|
||||
}
|
||||
return normalized, nil
|
||||
}
|
||||
|
||||
// computeFileHash 计算文件哈希
|
||||
|
||||
@@ -130,6 +130,10 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
|
||||
updates[SettingKeyFallbackModelGemini] = settings.FallbackModelGemini
|
||||
updates[SettingKeyFallbackModelAntigravity] = settings.FallbackModelAntigravity
|
||||
|
||||
// Identity patch configuration (Claude -> Gemini)
|
||||
updates[SettingKeyEnableIdentityPatch] = strconv.FormatBool(settings.EnableIdentityPatch)
|
||||
updates[SettingKeyIdentityPatchPrompt] = settings.IdentityPatchPrompt
|
||||
|
||||
return s.settingRepo.SetMultiple(ctx, updates)
|
||||
}
|
||||
|
||||
@@ -213,6 +217,9 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
|
||||
SettingKeyFallbackModelOpenAI: "gpt-4o",
|
||||
SettingKeyFallbackModelGemini: "gemini-2.5-pro",
|
||||
SettingKeyFallbackModelAntigravity: "gemini-2.5-pro",
|
||||
// Identity patch defaults
|
||||
SettingKeyEnableIdentityPatch: "true",
|
||||
SettingKeyIdentityPatchPrompt: "",
|
||||
}
|
||||
|
||||
return s.settingRepo.SetMultiple(ctx, defaults)
|
||||
@@ -221,21 +228,23 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
|
||||
// parseSettings 解析设置到结构体
|
||||
func (s *SettingService) parseSettings(settings map[string]string) *SystemSettings {
|
||||
result := &SystemSettings{
|
||||
RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true",
|
||||
EmailVerifyEnabled: settings[SettingKeyEmailVerifyEnabled] == "true",
|
||||
SMTPHost: settings[SettingKeySMTPHost],
|
||||
SMTPUsername: settings[SettingKeySMTPUsername],
|
||||
SMTPFrom: settings[SettingKeySMTPFrom],
|
||||
SMTPFromName: settings[SettingKeySMTPFromName],
|
||||
SMTPUseTLS: settings[SettingKeySMTPUseTLS] == "true",
|
||||
TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true",
|
||||
TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey],
|
||||
SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"),
|
||||
SiteLogo: settings[SettingKeySiteLogo],
|
||||
SiteSubtitle: s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"),
|
||||
APIBaseURL: settings[SettingKeyAPIBaseURL],
|
||||
ContactInfo: settings[SettingKeyContactInfo],
|
||||
DocURL: settings[SettingKeyDocURL],
|
||||
RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true",
|
||||
EmailVerifyEnabled: settings[SettingKeyEmailVerifyEnabled] == "true",
|
||||
SMTPHost: settings[SettingKeySMTPHost],
|
||||
SMTPUsername: settings[SettingKeySMTPUsername],
|
||||
SMTPFrom: settings[SettingKeySMTPFrom],
|
||||
SMTPFromName: settings[SettingKeySMTPFromName],
|
||||
SMTPUseTLS: settings[SettingKeySMTPUseTLS] == "true",
|
||||
SMTPPasswordConfigured: settings[SettingKeySMTPPassword] != "",
|
||||
TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true",
|
||||
TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey],
|
||||
TurnstileSecretKeyConfigured: settings[SettingKeyTurnstileSecretKey] != "",
|
||||
SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"),
|
||||
SiteLogo: settings[SettingKeySiteLogo],
|
||||
SiteSubtitle: s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"),
|
||||
APIBaseURL: settings[SettingKeyAPIBaseURL],
|
||||
ContactInfo: settings[SettingKeyContactInfo],
|
||||
DocURL: settings[SettingKeyDocURL],
|
||||
}
|
||||
|
||||
// 解析整数类型
|
||||
@@ -269,6 +278,14 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
|
||||
result.FallbackModelGemini = s.getStringOrDefault(settings, SettingKeyFallbackModelGemini, "gemini-2.5-pro")
|
||||
result.FallbackModelAntigravity = s.getStringOrDefault(settings, SettingKeyFallbackModelAntigravity, "gemini-2.5-pro")
|
||||
|
||||
// Identity patch settings (default: enabled, to preserve existing behavior)
|
||||
if v, ok := settings[SettingKeyEnableIdentityPatch]; ok && v != "" {
|
||||
result.EnableIdentityPatch = v == "true"
|
||||
} else {
|
||||
result.EnableIdentityPatch = true
|
||||
}
|
||||
result.IdentityPatchPrompt = settings[SettingKeyIdentityPatchPrompt]
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
@@ -298,6 +315,25 @@ func (s *SettingService) GetTurnstileSecretKey(ctx context.Context) string {
|
||||
return value
|
||||
}
|
||||
|
||||
// IsIdentityPatchEnabled 检查是否启用身份补丁(Claude -> Gemini systemInstruction 注入)
|
||||
func (s *SettingService) IsIdentityPatchEnabled(ctx context.Context) bool {
|
||||
value, err := s.settingRepo.GetValue(ctx, SettingKeyEnableIdentityPatch)
|
||||
if err != nil {
|
||||
// 默认开启,保持兼容
|
||||
return true
|
||||
}
|
||||
return value == "true"
|
||||
}
|
||||
|
||||
// GetIdentityPatchPrompt 获取自定义身份补丁提示词(为空表示使用内置默认模板)
|
||||
func (s *SettingService) GetIdentityPatchPrompt(ctx context.Context) string {
|
||||
value, err := s.settingRepo.GetValue(ctx, SettingKeyIdentityPatchPrompt)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
// GenerateAdminAPIKey 生成新的管理员 API Key
|
||||
func (s *SettingService) GenerateAdminAPIKey(ctx context.Context) (string, error) {
|
||||
// 生成 32 字节随机数 = 64 位十六进制字符
|
||||
|
||||
@@ -4,17 +4,19 @@ type SystemSettings struct {
|
||||
RegistrationEnabled bool
|
||||
EmailVerifyEnabled bool
|
||||
|
||||
SMTPHost string
|
||||
SMTPPort int
|
||||
SMTPUsername string
|
||||
SMTPPassword string
|
||||
SMTPFrom string
|
||||
SMTPFromName string
|
||||
SMTPUseTLS bool
|
||||
SMTPHost string
|
||||
SMTPPort int
|
||||
SMTPUsername string
|
||||
SMTPPassword string
|
||||
SMTPPasswordConfigured bool
|
||||
SMTPFrom string
|
||||
SMTPFromName string
|
||||
SMTPUseTLS bool
|
||||
|
||||
TurnstileEnabled bool
|
||||
TurnstileSiteKey string
|
||||
TurnstileSecretKey string
|
||||
TurnstileEnabled bool
|
||||
TurnstileSiteKey string
|
||||
TurnstileSecretKey string
|
||||
TurnstileSecretKeyConfigured bool
|
||||
|
||||
SiteName string
|
||||
SiteLogo string
|
||||
@@ -32,6 +34,10 @@ type SystemSettings struct {
|
||||
FallbackModelOpenAI string `json:"fallback_model_openai"`
|
||||
FallbackModelGemini string `json:"fallback_model_gemini"`
|
||||
FallbackModelAntigravity string `json:"fallback_model_antigravity"`
|
||||
|
||||
// Identity patch configuration (Claude -> Gemini)
|
||||
EnableIdentityPatch bool `json:"enable_identity_patch"`
|
||||
IdentityPatchPrompt string `json:"identity_patch_prompt"`
|
||||
}
|
||||
|
||||
type PublicSettings struct {
|
||||
|
||||
Reference in New Issue
Block a user