feat(安全): 强化安全策略与配置校验
- 增加 CORS/CSP/安全响应头与代理信任配置 - 引入 URL 白名单与私网开关,校验上游与价格源 - 改善 API Key 处理与网关错误返回 - 管理端设置隐藏敏感字段并优化前端提示 - 增加计费熔断与相关配置示例 测试: go test ./...
This commit is contained in:
@@ -7,6 +7,7 @@ import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
@@ -15,9 +16,11 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
@@ -49,6 +52,7 @@ type AccountTestService struct {
|
||||
geminiTokenProvider *GeminiTokenProvider
|
||||
antigravityGatewayService *AntigravityGatewayService
|
||||
httpUpstream HTTPUpstream
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
// NewAccountTestService creates a new AccountTestService
|
||||
@@ -59,6 +63,7 @@ func NewAccountTestService(
|
||||
geminiTokenProvider *GeminiTokenProvider,
|
||||
antigravityGatewayService *AntigravityGatewayService,
|
||||
httpUpstream HTTPUpstream,
|
||||
cfg *config.Config,
|
||||
) *AccountTestService {
|
||||
return &AccountTestService{
|
||||
accountRepo: accountRepo,
|
||||
@@ -67,9 +72,25 @@ func NewAccountTestService(
|
||||
geminiTokenProvider: geminiTokenProvider,
|
||||
antigravityGatewayService: antigravityGatewayService,
|
||||
httpUpstream: httpUpstream,
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AccountTestService) validateUpstreamBaseURL(raw string) (string, error) {
|
||||
if s.cfg == nil {
|
||||
return "", errors.New("config is not available")
|
||||
}
|
||||
normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{
|
||||
AllowedHosts: s.cfg.Security.URLAllowlist.UpstreamHosts,
|
||||
RequireAllowlist: true,
|
||||
AllowPrivate: s.cfg.Security.URLAllowlist.AllowPrivateHosts,
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return normalized, nil
|
||||
}
|
||||
|
||||
// generateSessionString generates a Claude Code style session string
|
||||
func generateSessionString() (string, error) {
|
||||
bytes := make([]byte, 32)
|
||||
@@ -207,11 +228,15 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
|
||||
return s.sendErrorAndEnd(c, "No API key available")
|
||||
}
|
||||
|
||||
apiURL = account.GetBaseURL()
|
||||
if apiURL == "" {
|
||||
apiURL = "https://api.anthropic.com"
|
||||
baseURL := account.GetBaseURL()
|
||||
if baseURL == "" {
|
||||
baseURL = "https://api.anthropic.com"
|
||||
}
|
||||
apiURL = strings.TrimSuffix(apiURL, "/") + "/v1/messages"
|
||||
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Invalid base URL: %s", err.Error()))
|
||||
}
|
||||
apiURL = strings.TrimSuffix(normalizedBaseURL, "/") + "/v1/messages"
|
||||
} else {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type))
|
||||
}
|
||||
@@ -333,7 +358,11 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
|
||||
if baseURL == "" {
|
||||
baseURL = "https://api.openai.com"
|
||||
}
|
||||
apiURL = strings.TrimSuffix(baseURL, "/") + "/responses"
|
||||
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Invalid base URL: %s", err.Error()))
|
||||
}
|
||||
apiURL = strings.TrimSuffix(normalizedBaseURL, "/") + "/responses"
|
||||
} else {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type))
|
||||
}
|
||||
@@ -513,10 +542,14 @@ func (s *AccountTestService) buildGeminiAPIKeyRequest(ctx context.Context, accou
|
||||
if baseURL == "" {
|
||||
baseURL = geminicli.AIStudioBaseURL
|
||||
}
|
||||
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Use streamGenerateContent for real-time feedback
|
||||
fullURL := fmt.Sprintf("%s/v1beta/models/%s:streamGenerateContent?alt=sse",
|
||||
strings.TrimRight(baseURL, "/"), modelID)
|
||||
strings.TrimRight(normalizedBaseURL, "/"), modelID)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", fullURL, bytes.NewReader(payload))
|
||||
if err != nil {
|
||||
@@ -548,7 +581,11 @@ func (s *AccountTestService) buildGeminiOAuthRequest(ctx context.Context, accoun
|
||||
if strings.TrimSpace(baseURL) == "" {
|
||||
baseURL = geminicli.AIStudioBaseURL
|
||||
}
|
||||
fullURL := fmt.Sprintf("%s/v1beta/models/%s:streamGenerateContent?alt=sse", strings.TrimRight(baseURL, "/"), modelID)
|
||||
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fullURL := fmt.Sprintf("%s/v1beta/models/%s:streamGenerateContent?alt=sse", strings.TrimRight(normalizedBaseURL, "/"), modelID)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(payload))
|
||||
if err != nil {
|
||||
@@ -577,7 +614,11 @@ func (s *AccountTestService) buildCodeAssistRequest(ctx context.Context, accessT
|
||||
}
|
||||
wrappedBytes, _ := json.Marshal(wrapped)
|
||||
|
||||
fullURL := fmt.Sprintf("%s/v1internal:streamGenerateContent?alt=sse", geminicli.GeminiCliBaseURL)
|
||||
normalizedBaseURL, err := s.validateUpstreamBaseURL(geminicli.GeminiCliBaseURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fullURL := fmt.Sprintf("%s/v1internal:streamGenerateContent?alt=sse", normalizedBaseURL)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", fullURL, bytes.NewReader(wrappedBytes))
|
||||
if err != nil {
|
||||
|
||||
@@ -221,9 +221,33 @@ func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*S
|
||||
|
||||
// VerifyTurnstile 验证Turnstile token
|
||||
func (s *AuthService) VerifyTurnstile(ctx context.Context, token string, remoteIP string) error {
|
||||
required := s.cfg != nil && s.cfg.Server.Mode == "release" && s.cfg.Turnstile.Required
|
||||
|
||||
if required {
|
||||
if s.settingService == nil {
|
||||
log.Println("[Auth] Turnstile required but settings service is not configured")
|
||||
return ErrTurnstileNotConfigured
|
||||
}
|
||||
enabled := s.settingService.IsTurnstileEnabled(ctx)
|
||||
secretConfigured := s.settingService.GetTurnstileSecretKey(ctx) != ""
|
||||
if !enabled || !secretConfigured {
|
||||
log.Printf("[Auth] Turnstile required but not configured (enabled=%v, secret_configured=%v)", enabled, secretConfigured)
|
||||
return ErrTurnstileNotConfigured
|
||||
}
|
||||
}
|
||||
|
||||
if s.turnstileService == nil {
|
||||
if required {
|
||||
log.Println("[Auth] Turnstile required but service not configured")
|
||||
return ErrTurnstileNotConfigured
|
||||
}
|
||||
return nil // 服务未配置则跳过验证
|
||||
}
|
||||
|
||||
if !required && s.settingService != nil && s.settingService.IsTurnstileEnabled(ctx) && s.settingService.GetTurnstileSecretKey(ctx) == "" {
|
||||
log.Println("[Auth] Turnstile enabled but secret key not configured")
|
||||
}
|
||||
|
||||
return s.turnstileService.VerifyToken(ctx, token, remoteIP)
|
||||
}
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
// 注:ErrDailyLimitExceeded/ErrWeeklyLimitExceeded/ErrMonthlyLimitExceeded在subscription_service.go中定义
|
||||
var (
|
||||
ErrSubscriptionInvalid = infraerrors.Forbidden("SUBSCRIPTION_INVALID", "subscription is invalid or expired")
|
||||
ErrBillingServiceUnavailable = infraerrors.ServiceUnavailable("BILLING_SERVICE_ERROR", "Billing service temporarily unavailable. Please retry later.")
|
||||
)
|
||||
|
||||
// subscriptionCacheData 订阅缓存数据结构(内部使用)
|
||||
@@ -76,6 +77,7 @@ type BillingCacheService struct {
|
||||
userRepo UserRepository
|
||||
subRepo UserSubscriptionRepository
|
||||
cfg *config.Config
|
||||
circuitBreaker *billingCircuitBreaker
|
||||
|
||||
cacheWriteChan chan cacheWriteTask
|
||||
cacheWriteWg sync.WaitGroup
|
||||
@@ -95,6 +97,7 @@ func NewBillingCacheService(cache BillingCache, userRepo UserRepository, subRepo
|
||||
subRepo: subRepo,
|
||||
cfg: cfg,
|
||||
}
|
||||
svc.circuitBreaker = newBillingCircuitBreaker(cfg.Billing.CircuitBreaker)
|
||||
svc.startCacheWriteWorkers()
|
||||
return svc
|
||||
}
|
||||
@@ -450,6 +453,9 @@ func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user
|
||||
if s.cfg.RunMode == config.RunModeSimple {
|
||||
return nil
|
||||
}
|
||||
if s.circuitBreaker != nil && !s.circuitBreaker.Allow() {
|
||||
return ErrBillingServiceUnavailable
|
||||
}
|
||||
|
||||
// 判断计费模式
|
||||
isSubscriptionMode := group != nil && group.IsSubscriptionType() && subscription != nil
|
||||
@@ -465,9 +471,14 @@ func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user
|
||||
func (s *BillingCacheService) checkBalanceEligibility(ctx context.Context, userID int64) error {
|
||||
balance, err := s.GetUserBalance(ctx, userID)
|
||||
if err != nil {
|
||||
// 缓存/数据库错误,允许通过(降级处理)
|
||||
log.Printf("Warning: get user balance failed, allowing request: %v", err)
|
||||
return nil
|
||||
if s.circuitBreaker != nil {
|
||||
s.circuitBreaker.OnFailure(err)
|
||||
}
|
||||
log.Printf("ALERT: billing balance check failed for user %d: %v", userID, err)
|
||||
return ErrBillingServiceUnavailable.WithCause(err)
|
||||
}
|
||||
if s.circuitBreaker != nil {
|
||||
s.circuitBreaker.OnSuccess()
|
||||
}
|
||||
|
||||
if balance <= 0 {
|
||||
@@ -482,9 +493,14 @@ func (s *BillingCacheService) checkSubscriptionEligibility(ctx context.Context,
|
||||
// 获取订阅缓存数据
|
||||
subData, err := s.GetSubscriptionStatus(ctx, userID, group.ID)
|
||||
if err != nil {
|
||||
// 缓存/数据库错误,降级使用传入的subscription进行检查
|
||||
log.Printf("Warning: get subscription cache failed, using fallback: %v", err)
|
||||
return s.checkSubscriptionLimitsFallback(subscription, group)
|
||||
if s.circuitBreaker != nil {
|
||||
s.circuitBreaker.OnFailure(err)
|
||||
}
|
||||
log.Printf("ALERT: billing subscription check failed for user %d group %d: %v", userID, group.ID, err)
|
||||
return ErrBillingServiceUnavailable.WithCause(err)
|
||||
}
|
||||
if s.circuitBreaker != nil {
|
||||
s.circuitBreaker.OnSuccess()
|
||||
}
|
||||
|
||||
// 检查订阅状态
|
||||
@@ -513,6 +529,137 @@ func (s *BillingCacheService) checkSubscriptionEligibility(ctx context.Context,
|
||||
return nil
|
||||
}
|
||||
|
||||
type billingCircuitBreakerState int
|
||||
|
||||
const (
|
||||
billingCircuitClosed billingCircuitBreakerState = iota
|
||||
billingCircuitOpen
|
||||
billingCircuitHalfOpen
|
||||
)
|
||||
|
||||
type billingCircuitBreaker struct {
|
||||
mu sync.Mutex
|
||||
state billingCircuitBreakerState
|
||||
failures int
|
||||
openedAt time.Time
|
||||
failureThreshold int
|
||||
resetTimeout time.Duration
|
||||
halfOpenRequests int
|
||||
halfOpenRemaining int
|
||||
}
|
||||
|
||||
func newBillingCircuitBreaker(cfg config.CircuitBreakerConfig) *billingCircuitBreaker {
|
||||
if !cfg.Enabled {
|
||||
return nil
|
||||
}
|
||||
resetTimeout := time.Duration(cfg.ResetTimeoutSeconds) * time.Second
|
||||
if resetTimeout <= 0 {
|
||||
resetTimeout = 30 * time.Second
|
||||
}
|
||||
halfOpen := cfg.HalfOpenRequests
|
||||
if halfOpen <= 0 {
|
||||
halfOpen = 1
|
||||
}
|
||||
threshold := cfg.FailureThreshold
|
||||
if threshold <= 0 {
|
||||
threshold = 5
|
||||
}
|
||||
return &billingCircuitBreaker{
|
||||
state: billingCircuitClosed,
|
||||
failureThreshold: threshold,
|
||||
resetTimeout: resetTimeout,
|
||||
halfOpenRequests: halfOpen,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *billingCircuitBreaker) Allow() bool {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
switch b.state {
|
||||
case billingCircuitClosed:
|
||||
return true
|
||||
case billingCircuitOpen:
|
||||
if time.Since(b.openedAt) < b.resetTimeout {
|
||||
return false
|
||||
}
|
||||
b.state = billingCircuitHalfOpen
|
||||
b.halfOpenRemaining = b.halfOpenRequests
|
||||
log.Printf("ALERT: billing circuit breaker entering half-open state")
|
||||
fallthrough
|
||||
case billingCircuitHalfOpen:
|
||||
if b.halfOpenRemaining <= 0 {
|
||||
return false
|
||||
}
|
||||
b.halfOpenRemaining--
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (b *billingCircuitBreaker) OnFailure(err error) {
|
||||
if b == nil {
|
||||
return
|
||||
}
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
switch b.state {
|
||||
case billingCircuitOpen:
|
||||
return
|
||||
case billingCircuitHalfOpen:
|
||||
b.state = billingCircuitOpen
|
||||
b.openedAt = time.Now()
|
||||
b.halfOpenRemaining = 0
|
||||
log.Printf("ALERT: billing circuit breaker opened after half-open failure: %v", err)
|
||||
return
|
||||
default:
|
||||
b.failures++
|
||||
if b.failures >= b.failureThreshold {
|
||||
b.state = billingCircuitOpen
|
||||
b.openedAt = time.Now()
|
||||
b.halfOpenRemaining = 0
|
||||
log.Printf("ALERT: billing circuit breaker opened after %d failures: %v", b.failures, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (b *billingCircuitBreaker) OnSuccess() {
|
||||
if b == nil {
|
||||
return
|
||||
}
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
previousState := b.state
|
||||
previousFailures := b.failures
|
||||
|
||||
b.state = billingCircuitClosed
|
||||
b.failures = 0
|
||||
b.halfOpenRemaining = 0
|
||||
|
||||
// 只有状态真正发生变化时才记录日志
|
||||
if previousState != billingCircuitClosed {
|
||||
log.Printf("ALERT: billing circuit breaker closed (was %s)", circuitStateString(previousState))
|
||||
} else if previousFailures > 0 {
|
||||
log.Printf("INFO: billing circuit breaker failures reset from %d", previousFailures)
|
||||
}
|
||||
}
|
||||
|
||||
func circuitStateString(state billingCircuitBreakerState) string {
|
||||
switch state {
|
||||
case billingCircuitClosed:
|
||||
return "closed"
|
||||
case billingCircuitOpen:
|
||||
return "open"
|
||||
case billingCircuitHalfOpen:
|
||||
return "half-open"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
// checkSubscriptionLimitsFallback 降级检查订阅限额
|
||||
func (s *BillingCacheService) checkSubscriptionLimitsFallback(subscription *UserSubscription, group *Group) error {
|
||||
if subscription == nil {
|
||||
|
||||
@@ -8,12 +8,13 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
|
||||
)
|
||||
|
||||
type CRSSyncService struct {
|
||||
@@ -22,6 +23,7 @@ type CRSSyncService struct {
|
||||
oauthService *OAuthService
|
||||
openaiOAuthService *OpenAIOAuthService
|
||||
geminiOAuthService *GeminiOAuthService
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
func NewCRSSyncService(
|
||||
@@ -30,6 +32,7 @@ func NewCRSSyncService(
|
||||
oauthService *OAuthService,
|
||||
openaiOAuthService *OpenAIOAuthService,
|
||||
geminiOAuthService *GeminiOAuthService,
|
||||
cfg *config.Config,
|
||||
) *CRSSyncService {
|
||||
return &CRSSyncService{
|
||||
accountRepo: accountRepo,
|
||||
@@ -37,6 +40,7 @@ func NewCRSSyncService(
|
||||
oauthService: oauthService,
|
||||
openaiOAuthService: openaiOAuthService,
|
||||
geminiOAuthService: geminiOAuthService,
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -187,7 +191,10 @@ type crsGeminiAPIKeyAccount struct {
|
||||
}
|
||||
|
||||
func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput) (*SyncFromCRSResult, error) {
|
||||
baseURL, err := normalizeBaseURL(input.BaseURL)
|
||||
if s.cfg == nil {
|
||||
return nil, errors.New("config is not available")
|
||||
}
|
||||
baseURL, err := normalizeBaseURL(input.BaseURL, s.cfg.Security.URLAllowlist.CRSHosts, s.cfg.Security.URLAllowlist.AllowPrivateHosts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1055,17 +1062,18 @@ func mapCRSStatus(isActive bool, status string) string {
|
||||
return "active"
|
||||
}
|
||||
|
||||
func normalizeBaseURL(raw string) (string, error) {
|
||||
trimmed := strings.TrimSpace(raw)
|
||||
if trimmed == "" {
|
||||
return "", errors.New("base_url is required")
|
||||
func normalizeBaseURL(raw string, allowlist []string, allowPrivate bool) (string, error) {
|
||||
// 当 allowlist 为空时,不强制要求白名单(只进行基本的 URL 和 SSRF 验证)
|
||||
requireAllowlist := len(allowlist) > 0
|
||||
normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{
|
||||
AllowedHosts: allowlist,
|
||||
RequireAllowlist: requireAllowlist,
|
||||
AllowPrivate: allowPrivate,
|
||||
})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid base_url: %w", err)
|
||||
}
|
||||
u, err := url.Parse(trimmed)
|
||||
if err != nil || u.Scheme == "" || u.Host == "" {
|
||||
return "", fmt.Errorf("invalid base_url: %s", trimmed)
|
||||
}
|
||||
u.Path = strings.TrimRight(u.Path, "/")
|
||||
return strings.TrimRight(u.String(), "/"), nil
|
||||
return normalized, nil
|
||||
}
|
||||
|
||||
// cleanBaseURL removes trailing suffix from base_url in credentials
|
||||
|
||||
@@ -19,6 +19,8 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
|
||||
"github.com/tidwall/sjson"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -724,7 +726,13 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
|
||||
targetURL := claudeAPIURL
|
||||
if account.Type == AccountTypeApiKey {
|
||||
baseURL := account.GetBaseURL()
|
||||
targetURL = baseURL + "/v1/messages"
|
||||
if baseURL != "" {
|
||||
validatedURL, err := s.validateUpstreamBaseURL(baseURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
targetURL = validatedURL + "/v1/messages"
|
||||
}
|
||||
}
|
||||
|
||||
// OAuth账号:应用统一指纹
|
||||
@@ -1107,12 +1115,7 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
|
||||
body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
|
||||
}
|
||||
|
||||
// 透传响应头
|
||||
for key, values := range resp.Header {
|
||||
for _, value := range values {
|
||||
c.Header(key, value)
|
||||
}
|
||||
}
|
||||
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
|
||||
|
||||
// 写入响应
|
||||
c.Data(resp.StatusCode, "application/json", body)
|
||||
@@ -1352,7 +1355,13 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
|
||||
targetURL := claudeAPICountTokensURL
|
||||
if account.Type == AccountTypeApiKey {
|
||||
baseURL := account.GetBaseURL()
|
||||
targetURL = baseURL + "/v1/messages/count_tokens"
|
||||
if baseURL != "" {
|
||||
validatedURL, err := s.validateUpstreamBaseURL(baseURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
targetURL = validatedURL + "/v1/messages/count_tokens"
|
||||
}
|
||||
}
|
||||
|
||||
// OAuth 账号:应用统一指纹和重写 userID
|
||||
@@ -1424,3 +1433,15 @@ func (s *GatewayService) countTokensError(c *gin.Context, status int, errType, m
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (s *GatewayService) validateUpstreamBaseURL(raw string) (string, error) {
|
||||
normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{
|
||||
AllowedHosts: s.cfg.Security.URLAllowlist.UpstreamHosts,
|
||||
RequireAllowlist: true,
|
||||
AllowPrivate: s.cfg.Security.URLAllowlist.AllowPrivateHosts,
|
||||
})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid base_url: %w", err)
|
||||
}
|
||||
return normalized, nil
|
||||
}
|
||||
|
||||
@@ -18,9 +18,12 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -41,6 +44,7 @@ type GeminiMessagesCompatService struct {
|
||||
rateLimitService *RateLimitService
|
||||
httpUpstream HTTPUpstream
|
||||
antigravityGatewayService *AntigravityGatewayService
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
func NewGeminiMessagesCompatService(
|
||||
@@ -51,6 +55,7 @@ func NewGeminiMessagesCompatService(
|
||||
rateLimitService *RateLimitService,
|
||||
httpUpstream HTTPUpstream,
|
||||
antigravityGatewayService *AntigravityGatewayService,
|
||||
cfg *config.Config,
|
||||
) *GeminiMessagesCompatService {
|
||||
return &GeminiMessagesCompatService{
|
||||
accountRepo: accountRepo,
|
||||
@@ -60,6 +65,7 @@ func NewGeminiMessagesCompatService(
|
||||
rateLimitService: rateLimitService,
|
||||
httpUpstream: httpUpstream,
|
||||
antigravityGatewayService: antigravityGatewayService,
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -209,6 +215,18 @@ func (s *GeminiMessagesCompatService) GetAntigravityGatewayService() *Antigravit
|
||||
return s.antigravityGatewayService
|
||||
}
|
||||
|
||||
func (s *GeminiMessagesCompatService) validateUpstreamBaseURL(raw string) (string, error) {
|
||||
normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{
|
||||
AllowedHosts: s.cfg.Security.URLAllowlist.UpstreamHosts,
|
||||
RequireAllowlist: true,
|
||||
AllowPrivate: s.cfg.Security.URLAllowlist.AllowPrivateHosts,
|
||||
})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid base_url: %w", err)
|
||||
}
|
||||
return normalized, nil
|
||||
}
|
||||
|
||||
// HasAntigravityAccounts 检查是否有可用的 antigravity 账户
|
||||
func (s *GeminiMessagesCompatService) HasAntigravityAccounts(ctx context.Context, groupID *int64) (bool, error) {
|
||||
var accounts []Account
|
||||
@@ -360,16 +378,20 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
||||
return nil, "", errors.New("gemini api_key not configured")
|
||||
}
|
||||
|
||||
baseURL := strings.TrimRight(account.GetCredential("base_url"), "/")
|
||||
baseURL := strings.TrimSpace(account.GetCredential("base_url"))
|
||||
if baseURL == "" {
|
||||
baseURL = geminicli.AIStudioBaseURL
|
||||
}
|
||||
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
action := "generateContent"
|
||||
if req.Stream {
|
||||
action = "streamGenerateContent"
|
||||
}
|
||||
fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(baseURL, "/"), mappedModel, action)
|
||||
fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(normalizedBaseURL, "/"), mappedModel, action)
|
||||
if req.Stream {
|
||||
fullURL += "?alt=sse"
|
||||
}
|
||||
@@ -406,7 +428,11 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
||||
// 2. Without project_id -> AI Studio API (direct OAuth, like API key but with Bearer token)
|
||||
if projectID != "" {
|
||||
// Mode 1: Code Assist API
|
||||
fullURL := fmt.Sprintf("%s/v1internal:%s", geminicli.GeminiCliBaseURL, action)
|
||||
baseURL, err := s.validateUpstreamBaseURL(geminicli.GeminiCliBaseURL)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
fullURL := fmt.Sprintf("%s/v1internal:%s", strings.TrimRight(baseURL, "/"), action)
|
||||
if useUpstreamStream {
|
||||
fullURL += "?alt=sse"
|
||||
}
|
||||
@@ -432,12 +458,16 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
||||
return upstreamReq, "x-request-id", nil
|
||||
} else {
|
||||
// Mode 2: AI Studio API with OAuth (like API key mode, but using Bearer token)
|
||||
baseURL := strings.TrimRight(account.GetCredential("base_url"), "/")
|
||||
baseURL := strings.TrimSpace(account.GetCredential("base_url"))
|
||||
if baseURL == "" {
|
||||
baseURL = geminicli.AIStudioBaseURL
|
||||
}
|
||||
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", baseURL, mappedModel, action)
|
||||
fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(normalizedBaseURL, "/"), mappedModel, action)
|
||||
if useUpstreamStream {
|
||||
fullURL += "?alt=sse"
|
||||
}
|
||||
@@ -622,12 +652,16 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
||||
return nil, "", errors.New("gemini api_key not configured")
|
||||
}
|
||||
|
||||
baseURL := strings.TrimRight(account.GetCredential("base_url"), "/")
|
||||
baseURL := strings.TrimSpace(account.GetCredential("base_url"))
|
||||
if baseURL == "" {
|
||||
baseURL = geminicli.AIStudioBaseURL
|
||||
}
|
||||
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(baseURL, "/"), mappedModel, upstreamAction)
|
||||
fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(normalizedBaseURL, "/"), mappedModel, upstreamAction)
|
||||
if useUpstreamStream {
|
||||
fullURL += "?alt=sse"
|
||||
}
|
||||
@@ -659,7 +693,11 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
||||
// 2. Without project_id -> AI Studio API (direct OAuth, like API key but with Bearer token)
|
||||
if projectID != "" && !forceAIStudio {
|
||||
// Mode 1: Code Assist API
|
||||
fullURL := fmt.Sprintf("%s/v1internal:%s", geminicli.GeminiCliBaseURL, upstreamAction)
|
||||
baseURL, err := s.validateUpstreamBaseURL(geminicli.GeminiCliBaseURL)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
fullURL := fmt.Sprintf("%s/v1internal:%s", strings.TrimRight(baseURL, "/"), upstreamAction)
|
||||
if useUpstreamStream {
|
||||
fullURL += "?alt=sse"
|
||||
}
|
||||
@@ -685,12 +723,16 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
||||
return upstreamReq, "x-request-id", nil
|
||||
} else {
|
||||
// Mode 2: AI Studio API with OAuth (like API key mode, but using Bearer token)
|
||||
baseURL := strings.TrimRight(account.GetCredential("base_url"), "/")
|
||||
baseURL := strings.TrimSpace(account.GetCredential("base_url"))
|
||||
if baseURL == "" {
|
||||
baseURL = geminicli.AIStudioBaseURL
|
||||
}
|
||||
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", baseURL, mappedModel, upstreamAction)
|
||||
fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(normalizedBaseURL, "/"), mappedModel, upstreamAction)
|
||||
if useUpstreamStream {
|
||||
fullURL += "?alt=sse"
|
||||
}
|
||||
@@ -1608,6 +1650,8 @@ func (s *GeminiMessagesCompatService) handleNativeNonStreamingResponse(c *gin.Co
|
||||
_ = json.Unmarshal(respBody, &parsed)
|
||||
}
|
||||
|
||||
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
|
||||
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
if contentType == "" {
|
||||
contentType = "application/json"
|
||||
@@ -1720,11 +1764,15 @@ func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, ac
|
||||
return nil, errors.New("invalid path")
|
||||
}
|
||||
|
||||
baseURL := strings.TrimRight(account.GetCredential("base_url"), "/")
|
||||
baseURL := strings.TrimSpace(account.GetCredential("base_url"))
|
||||
if baseURL == "" {
|
||||
baseURL = geminicli.AIStudioBaseURL
|
||||
}
|
||||
fullURL := strings.TrimRight(baseURL, "/") + path
|
||||
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fullURL := strings.TrimRight(normalizedBaseURL, "/") + path
|
||||
|
||||
var proxyURL string
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
@@ -1763,9 +1811,14 @@ func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, ac
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 8<<20))
|
||||
wwwAuthenticate := resp.Header.Get("Www-Authenticate")
|
||||
filteredHeaders := responseheaders.FilterHeaders(resp.Header, s.cfg.Security.ResponseHeaders)
|
||||
if wwwAuthenticate != "" {
|
||||
filteredHeaders.Set("Www-Authenticate", wwwAuthenticate)
|
||||
}
|
||||
return &UpstreamHTTPResult{
|
||||
StatusCode: resp.StatusCode,
|
||||
Headers: resp.Header.Clone(),
|
||||
Headers: filteredHeaders,
|
||||
Body: body,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -18,6 +18,8 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
@@ -370,10 +372,14 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
|
||||
case AccountTypeApiKey:
|
||||
// API Key accounts use Platform API or custom base URL
|
||||
baseURL := account.GetOpenAIBaseURL()
|
||||
if baseURL != "" {
|
||||
targetURL = baseURL + "/responses"
|
||||
} else {
|
||||
if baseURL == "" {
|
||||
targetURL = openaiPlatformAPIURL
|
||||
} else {
|
||||
validatedURL, err := s.validateUpstreamBaseURL(baseURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
targetURL = validatedURL + "/responses"
|
||||
}
|
||||
default:
|
||||
targetURL = openaiPlatformAPIURL
|
||||
@@ -645,18 +651,25 @@ func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, r
|
||||
body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
|
||||
}
|
||||
|
||||
// Pass through headers
|
||||
for key, values := range resp.Header {
|
||||
for _, value := range values {
|
||||
c.Header(key, value)
|
||||
}
|
||||
}
|
||||
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
|
||||
|
||||
c.Data(resp.StatusCode, "application/json", body)
|
||||
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) validateUpstreamBaseURL(raw string) (string, error) {
|
||||
normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{
|
||||
AllowedHosts: s.cfg.Security.URLAllowlist.UpstreamHosts,
|
||||
RequireAllowlist: true,
|
||||
AllowPrivate: s.cfg.Security.URLAllowlist.AllowPrivateHosts,
|
||||
})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid base_url: %w", err)
|
||||
}
|
||||
return normalized, nil
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel, toModel string) []byte {
|
||||
var resp map[string]any
|
||||
if err := json.Unmarshal(body, &resp); err != nil {
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -211,16 +212,35 @@ func (s *PricingService) syncWithRemote() error {
|
||||
|
||||
// downloadPricingData 从远程下载价格数据
|
||||
func (s *PricingService) downloadPricingData() error {
|
||||
log.Printf("[Pricing] Downloading from %s", s.cfg.Pricing.RemoteURL)
|
||||
remoteURL, err := s.validatePricingURL(s.cfg.Pricing.RemoteURL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
log.Printf("[Pricing] Downloading from %s", remoteURL)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
body, err := s.remoteClient.FetchPricingJSON(ctx, s.cfg.Pricing.RemoteURL)
|
||||
var expectedHash string
|
||||
if strings.TrimSpace(s.cfg.Pricing.HashURL) != "" {
|
||||
expectedHash, err = s.fetchRemoteHash()
|
||||
if err != nil {
|
||||
return fmt.Errorf("fetch remote hash: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
body, err := s.remoteClient.FetchPricingJSON(ctx, remoteURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("download failed: %w", err)
|
||||
}
|
||||
|
||||
if expectedHash != "" {
|
||||
actualHash := sha256.Sum256(body)
|
||||
if !strings.EqualFold(expectedHash, hex.EncodeToString(actualHash[:])) {
|
||||
return fmt.Errorf("pricing hash mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
// 解析JSON数据(使用灵活的解析方式)
|
||||
data, err := s.parsePricingData(body)
|
||||
if err != nil {
|
||||
@@ -373,10 +393,31 @@ func (s *PricingService) useFallbackPricing() error {
|
||||
|
||||
// fetchRemoteHash 从远程获取哈希值
|
||||
func (s *PricingService) fetchRemoteHash() (string, error) {
|
||||
hashURL, err := s.validatePricingURL(s.cfg.Pricing.HashURL)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
return s.remoteClient.FetchHashText(ctx, s.cfg.Pricing.HashURL)
|
||||
hash, err := s.remoteClient.FetchHashText(ctx, hashURL)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return strings.TrimSpace(hash), nil
|
||||
}
|
||||
|
||||
func (s *PricingService) validatePricingURL(raw string) (string, error) {
|
||||
normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{
|
||||
AllowedHosts: s.cfg.Security.URLAllowlist.PricingHosts,
|
||||
RequireAllowlist: true,
|
||||
AllowPrivate: s.cfg.Security.URLAllowlist.AllowPrivateHosts,
|
||||
})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid pricing url: %w", err)
|
||||
}
|
||||
return normalized, nil
|
||||
}
|
||||
|
||||
// computeFileHash 计算文件哈希
|
||||
|
||||
@@ -215,8 +215,10 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
|
||||
SmtpFrom: settings[SettingKeySmtpFrom],
|
||||
SmtpFromName: settings[SettingKeySmtpFromName],
|
||||
SmtpUseTLS: settings[SettingKeySmtpUseTLS] == "true",
|
||||
SmtpPasswordConfigured: settings[SettingKeySmtpPassword] != "",
|
||||
TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true",
|
||||
TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey],
|
||||
TurnstileSecretKeyConfigured: settings[SettingKeyTurnstileSecretKey] != "",
|
||||
SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"),
|
||||
SiteLogo: settings[SettingKeySiteLogo],
|
||||
SiteSubtitle: s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"),
|
||||
@@ -245,10 +247,6 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
|
||||
result.DefaultBalance = s.cfg.Default.UserBalance
|
||||
}
|
||||
|
||||
// 敏感信息直接返回,方便测试连接时使用
|
||||
result.SmtpPassword = settings[SettingKeySmtpPassword]
|
||||
result.TurnstileSecretKey = settings[SettingKeyTurnstileSecretKey]
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ type SystemSettings struct {
|
||||
SmtpPort int
|
||||
SmtpUsername string
|
||||
SmtpPassword string
|
||||
SmtpPasswordConfigured bool
|
||||
SmtpFrom string
|
||||
SmtpFromName string
|
||||
SmtpUseTLS bool
|
||||
@@ -15,6 +16,7 @@ type SystemSettings struct {
|
||||
TurnstileEnabled bool
|
||||
TurnstileSiteKey string
|
||||
TurnstileSecretKey string
|
||||
TurnstileSecretKeyConfigured bool
|
||||
|
||||
SiteName string
|
||||
SiteLogo string
|
||||
|
||||
Reference in New Issue
Block a user