feat(安全): 强化安全策略与配置校验

- 增加 CORS/CSP/安全响应头与代理信任配置

- 引入 URL 白名单与私网开关,校验上游与价格源

- 改善 API Key 处理与网关错误返回

- 管理端设置隐藏敏感字段并优化前端提示

- 增加计费熔断与相关配置示例

测试: go test ./...
This commit is contained in:
yangjianbo
2026-01-02 17:40:57 +08:00
parent 3fd9bd4a80
commit bd4bf00856
46 changed files with 1572 additions and 220 deletions

View File

@@ -7,6 +7,7 @@ import (
"crypto/rand"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"log"
@@ -15,9 +16,11 @@ import (
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
@@ -49,6 +52,7 @@ type AccountTestService struct {
geminiTokenProvider *GeminiTokenProvider
antigravityGatewayService *AntigravityGatewayService
httpUpstream HTTPUpstream
cfg *config.Config
}
// NewAccountTestService creates a new AccountTestService
@@ -59,6 +63,7 @@ func NewAccountTestService(
geminiTokenProvider *GeminiTokenProvider,
antigravityGatewayService *AntigravityGatewayService,
httpUpstream HTTPUpstream,
cfg *config.Config,
) *AccountTestService {
return &AccountTestService{
accountRepo: accountRepo,
@@ -67,9 +72,25 @@ func NewAccountTestService(
geminiTokenProvider: geminiTokenProvider,
antigravityGatewayService: antigravityGatewayService,
httpUpstream: httpUpstream,
cfg: cfg,
}
}
func (s *AccountTestService) validateUpstreamBaseURL(raw string) (string, error) {
if s.cfg == nil {
return "", errors.New("config is not available")
}
normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{
AllowedHosts: s.cfg.Security.URLAllowlist.UpstreamHosts,
RequireAllowlist: true,
AllowPrivate: s.cfg.Security.URLAllowlist.AllowPrivateHosts,
})
if err != nil {
return "", err
}
return normalized, nil
}
// generateSessionString generates a Claude Code style session string
func generateSessionString() (string, error) {
bytes := make([]byte, 32)
@@ -207,11 +228,15 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
return s.sendErrorAndEnd(c, "No API key available")
}
apiURL = account.GetBaseURL()
if apiURL == "" {
apiURL = "https://api.anthropic.com"
baseURL := account.GetBaseURL()
if baseURL == "" {
baseURL = "https://api.anthropic.com"
}
apiURL = strings.TrimSuffix(apiURL, "/") + "/v1/messages"
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Invalid base URL: %s", err.Error()))
}
apiURL = strings.TrimSuffix(normalizedBaseURL, "/") + "/v1/messages"
} else {
return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type))
}
@@ -333,7 +358,11 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
if baseURL == "" {
baseURL = "https://api.openai.com"
}
apiURL = strings.TrimSuffix(baseURL, "/") + "/responses"
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Invalid base URL: %s", err.Error()))
}
apiURL = strings.TrimSuffix(normalizedBaseURL, "/") + "/responses"
} else {
return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type))
}
@@ -513,10 +542,14 @@ func (s *AccountTestService) buildGeminiAPIKeyRequest(ctx context.Context, accou
if baseURL == "" {
baseURL = geminicli.AIStudioBaseURL
}
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, err
}
// Use streamGenerateContent for real-time feedback
fullURL := fmt.Sprintf("%s/v1beta/models/%s:streamGenerateContent?alt=sse",
strings.TrimRight(baseURL, "/"), modelID)
strings.TrimRight(normalizedBaseURL, "/"), modelID)
req, err := http.NewRequestWithContext(ctx, "POST", fullURL, bytes.NewReader(payload))
if err != nil {
@@ -548,7 +581,11 @@ func (s *AccountTestService) buildGeminiOAuthRequest(ctx context.Context, accoun
if strings.TrimSpace(baseURL) == "" {
baseURL = geminicli.AIStudioBaseURL
}
fullURL := fmt.Sprintf("%s/v1beta/models/%s:streamGenerateContent?alt=sse", strings.TrimRight(baseURL, "/"), modelID)
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, err
}
fullURL := fmt.Sprintf("%s/v1beta/models/%s:streamGenerateContent?alt=sse", strings.TrimRight(normalizedBaseURL, "/"), modelID)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(payload))
if err != nil {
@@ -577,7 +614,11 @@ func (s *AccountTestService) buildCodeAssistRequest(ctx context.Context, accessT
}
wrappedBytes, _ := json.Marshal(wrapped)
fullURL := fmt.Sprintf("%s/v1internal:streamGenerateContent?alt=sse", geminicli.GeminiCliBaseURL)
normalizedBaseURL, err := s.validateUpstreamBaseURL(geminicli.GeminiCliBaseURL)
if err != nil {
return nil, err
}
fullURL := fmt.Sprintf("%s/v1internal:streamGenerateContent?alt=sse", normalizedBaseURL)
req, err := http.NewRequestWithContext(ctx, "POST", fullURL, bytes.NewReader(wrappedBytes))
if err != nil {

View File

@@ -221,9 +221,33 @@ func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*S
// VerifyTurnstile 验证Turnstile token
func (s *AuthService) VerifyTurnstile(ctx context.Context, token string, remoteIP string) error {
required := s.cfg != nil && s.cfg.Server.Mode == "release" && s.cfg.Turnstile.Required
if required {
if s.settingService == nil {
log.Println("[Auth] Turnstile required but settings service is not configured")
return ErrTurnstileNotConfigured
}
enabled := s.settingService.IsTurnstileEnabled(ctx)
secretConfigured := s.settingService.GetTurnstileSecretKey(ctx) != ""
if !enabled || !secretConfigured {
log.Printf("[Auth] Turnstile required but not configured (enabled=%v, secret_configured=%v)", enabled, secretConfigured)
return ErrTurnstileNotConfigured
}
}
if s.turnstileService == nil {
if required {
log.Println("[Auth] Turnstile required but service not configured")
return ErrTurnstileNotConfigured
}
return nil // 服务未配置则跳过验证
}
if !required && s.settingService != nil && s.settingService.IsTurnstileEnabled(ctx) && s.settingService.GetTurnstileSecretKey(ctx) == "" {
log.Println("[Auth] Turnstile enabled but secret key not configured")
}
return s.turnstileService.VerifyToken(ctx, token, remoteIP)
}

View File

@@ -17,6 +17,7 @@ import (
// 注ErrDailyLimitExceeded/ErrWeeklyLimitExceeded/ErrMonthlyLimitExceeded在subscription_service.go中定义
var (
ErrSubscriptionInvalid = infraerrors.Forbidden("SUBSCRIPTION_INVALID", "subscription is invalid or expired")
ErrBillingServiceUnavailable = infraerrors.ServiceUnavailable("BILLING_SERVICE_ERROR", "Billing service temporarily unavailable. Please retry later.")
)
// subscriptionCacheData 订阅缓存数据结构(内部使用)
@@ -76,6 +77,7 @@ type BillingCacheService struct {
userRepo UserRepository
subRepo UserSubscriptionRepository
cfg *config.Config
circuitBreaker *billingCircuitBreaker
cacheWriteChan chan cacheWriteTask
cacheWriteWg sync.WaitGroup
@@ -95,6 +97,7 @@ func NewBillingCacheService(cache BillingCache, userRepo UserRepository, subRepo
subRepo: subRepo,
cfg: cfg,
}
svc.circuitBreaker = newBillingCircuitBreaker(cfg.Billing.CircuitBreaker)
svc.startCacheWriteWorkers()
return svc
}
@@ -450,6 +453,9 @@ func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user
if s.cfg.RunMode == config.RunModeSimple {
return nil
}
if s.circuitBreaker != nil && !s.circuitBreaker.Allow() {
return ErrBillingServiceUnavailable
}
// 判断计费模式
isSubscriptionMode := group != nil && group.IsSubscriptionType() && subscription != nil
@@ -465,9 +471,14 @@ func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user
func (s *BillingCacheService) checkBalanceEligibility(ctx context.Context, userID int64) error {
balance, err := s.GetUserBalance(ctx, userID)
if err != nil {
// 缓存/数据库错误,允许通过(降级处理)
log.Printf("Warning: get user balance failed, allowing request: %v", err)
return nil
if s.circuitBreaker != nil {
s.circuitBreaker.OnFailure(err)
}
log.Printf("ALERT: billing balance check failed for user %d: %v", userID, err)
return ErrBillingServiceUnavailable.WithCause(err)
}
if s.circuitBreaker != nil {
s.circuitBreaker.OnSuccess()
}
if balance <= 0 {
@@ -482,9 +493,14 @@ func (s *BillingCacheService) checkSubscriptionEligibility(ctx context.Context,
// 获取订阅缓存数据
subData, err := s.GetSubscriptionStatus(ctx, userID, group.ID)
if err != nil {
// 缓存/数据库错误降级使用传入的subscription进行检查
log.Printf("Warning: get subscription cache failed, using fallback: %v", err)
return s.checkSubscriptionLimitsFallback(subscription, group)
if s.circuitBreaker != nil {
s.circuitBreaker.OnFailure(err)
}
log.Printf("ALERT: billing subscription check failed for user %d group %d: %v", userID, group.ID, err)
return ErrBillingServiceUnavailable.WithCause(err)
}
if s.circuitBreaker != nil {
s.circuitBreaker.OnSuccess()
}
// 检查订阅状态
@@ -513,6 +529,137 @@ func (s *BillingCacheService) checkSubscriptionEligibility(ctx context.Context,
return nil
}
type billingCircuitBreakerState int
const (
billingCircuitClosed billingCircuitBreakerState = iota
billingCircuitOpen
billingCircuitHalfOpen
)
type billingCircuitBreaker struct {
mu sync.Mutex
state billingCircuitBreakerState
failures int
openedAt time.Time
failureThreshold int
resetTimeout time.Duration
halfOpenRequests int
halfOpenRemaining int
}
func newBillingCircuitBreaker(cfg config.CircuitBreakerConfig) *billingCircuitBreaker {
if !cfg.Enabled {
return nil
}
resetTimeout := time.Duration(cfg.ResetTimeoutSeconds) * time.Second
if resetTimeout <= 0 {
resetTimeout = 30 * time.Second
}
halfOpen := cfg.HalfOpenRequests
if halfOpen <= 0 {
halfOpen = 1
}
threshold := cfg.FailureThreshold
if threshold <= 0 {
threshold = 5
}
return &billingCircuitBreaker{
state: billingCircuitClosed,
failureThreshold: threshold,
resetTimeout: resetTimeout,
halfOpenRequests: halfOpen,
}
}
func (b *billingCircuitBreaker) Allow() bool {
b.mu.Lock()
defer b.mu.Unlock()
switch b.state {
case billingCircuitClosed:
return true
case billingCircuitOpen:
if time.Since(b.openedAt) < b.resetTimeout {
return false
}
b.state = billingCircuitHalfOpen
b.halfOpenRemaining = b.halfOpenRequests
log.Printf("ALERT: billing circuit breaker entering half-open state")
fallthrough
case billingCircuitHalfOpen:
if b.halfOpenRemaining <= 0 {
return false
}
b.halfOpenRemaining--
return true
default:
return false
}
}
func (b *billingCircuitBreaker) OnFailure(err error) {
if b == nil {
return
}
b.mu.Lock()
defer b.mu.Unlock()
switch b.state {
case billingCircuitOpen:
return
case billingCircuitHalfOpen:
b.state = billingCircuitOpen
b.openedAt = time.Now()
b.halfOpenRemaining = 0
log.Printf("ALERT: billing circuit breaker opened after half-open failure: %v", err)
return
default:
b.failures++
if b.failures >= b.failureThreshold {
b.state = billingCircuitOpen
b.openedAt = time.Now()
b.halfOpenRemaining = 0
log.Printf("ALERT: billing circuit breaker opened after %d failures: %v", b.failures, err)
}
}
}
func (b *billingCircuitBreaker) OnSuccess() {
if b == nil {
return
}
b.mu.Lock()
defer b.mu.Unlock()
previousState := b.state
previousFailures := b.failures
b.state = billingCircuitClosed
b.failures = 0
b.halfOpenRemaining = 0
// 只有状态真正发生变化时才记录日志
if previousState != billingCircuitClosed {
log.Printf("ALERT: billing circuit breaker closed (was %s)", circuitStateString(previousState))
} else if previousFailures > 0 {
log.Printf("INFO: billing circuit breaker failures reset from %d", previousFailures)
}
}
func circuitStateString(state billingCircuitBreakerState) string {
switch state {
case billingCircuitClosed:
return "closed"
case billingCircuitOpen:
return "open"
case billingCircuitHalfOpen:
return "half-open"
default:
return "unknown"
}
}
// checkSubscriptionLimitsFallback 降级检查订阅限额
func (s *BillingCacheService) checkSubscriptionLimitsFallback(subscription *UserSubscription, group *Group) error {
if subscription == nil {

View File

@@ -8,12 +8,13 @@ import (
"fmt"
"io"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
)
type CRSSyncService struct {
@@ -22,6 +23,7 @@ type CRSSyncService struct {
oauthService *OAuthService
openaiOAuthService *OpenAIOAuthService
geminiOAuthService *GeminiOAuthService
cfg *config.Config
}
func NewCRSSyncService(
@@ -30,6 +32,7 @@ func NewCRSSyncService(
oauthService *OAuthService,
openaiOAuthService *OpenAIOAuthService,
geminiOAuthService *GeminiOAuthService,
cfg *config.Config,
) *CRSSyncService {
return &CRSSyncService{
accountRepo: accountRepo,
@@ -37,6 +40,7 @@ func NewCRSSyncService(
oauthService: oauthService,
openaiOAuthService: openaiOAuthService,
geminiOAuthService: geminiOAuthService,
cfg: cfg,
}
}
@@ -187,7 +191,10 @@ type crsGeminiAPIKeyAccount struct {
}
func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput) (*SyncFromCRSResult, error) {
baseURL, err := normalizeBaseURL(input.BaseURL)
if s.cfg == nil {
return nil, errors.New("config is not available")
}
baseURL, err := normalizeBaseURL(input.BaseURL, s.cfg.Security.URLAllowlist.CRSHosts, s.cfg.Security.URLAllowlist.AllowPrivateHosts)
if err != nil {
return nil, err
}
@@ -1055,17 +1062,18 @@ func mapCRSStatus(isActive bool, status string) string {
return "active"
}
func normalizeBaseURL(raw string) (string, error) {
trimmed := strings.TrimSpace(raw)
if trimmed == "" {
return "", errors.New("base_url is required")
func normalizeBaseURL(raw string, allowlist []string, allowPrivate bool) (string, error) {
// 当 allowlist 为空时,不强制要求白名单(只进行基本的 URL 和 SSRF 验证)
requireAllowlist := len(allowlist) > 0
normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{
AllowedHosts: allowlist,
RequireAllowlist: requireAllowlist,
AllowPrivate: allowPrivate,
})
if err != nil {
return "", fmt.Errorf("invalid base_url: %w", err)
}
u, err := url.Parse(trimmed)
if err != nil || u.Scheme == "" || u.Host == "" {
return "", fmt.Errorf("invalid base_url: %s", trimmed)
}
u.Path = strings.TrimRight(u.Path, "/")
return strings.TrimRight(u.String(), "/"), nil
return normalized, nil
}
// cleanBaseURL removes trailing suffix from base_url in credentials

View File

@@ -19,6 +19,8 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/tidwall/sjson"
"github.com/gin-gonic/gin"
@@ -724,7 +726,13 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
targetURL := claudeAPIURL
if account.Type == AccountTypeApiKey {
baseURL := account.GetBaseURL()
targetURL = baseURL + "/v1/messages"
if baseURL != "" {
validatedURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, err
}
targetURL = validatedURL + "/v1/messages"
}
}
// OAuth账号应用统一指纹
@@ -1107,12 +1115,7 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
}
// 透传响应头
for key, values := range resp.Header {
for _, value := range values {
c.Header(key, value)
}
}
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
// 写入响应
c.Data(resp.StatusCode, "application/json", body)
@@ -1352,7 +1355,13 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
targetURL := claudeAPICountTokensURL
if account.Type == AccountTypeApiKey {
baseURL := account.GetBaseURL()
targetURL = baseURL + "/v1/messages/count_tokens"
if baseURL != "" {
validatedURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, err
}
targetURL = validatedURL + "/v1/messages/count_tokens"
}
}
// OAuth 账号:应用统一指纹和重写 userID
@@ -1424,3 +1433,15 @@ func (s *GatewayService) countTokensError(c *gin.Context, status int, errType, m
},
})
}
func (s *GatewayService) validateUpstreamBaseURL(raw string) (string, error) {
normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{
AllowedHosts: s.cfg.Security.URLAllowlist.UpstreamHosts,
RequireAllowlist: true,
AllowPrivate: s.cfg.Security.URLAllowlist.AllowPrivateHosts,
})
if err != nil {
return "", fmt.Errorf("invalid base_url: %w", err)
}
return normalized, nil
}

View File

@@ -18,9 +18,12 @@ import (
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/gin-gonic/gin"
)
@@ -41,6 +44,7 @@ type GeminiMessagesCompatService struct {
rateLimitService *RateLimitService
httpUpstream HTTPUpstream
antigravityGatewayService *AntigravityGatewayService
cfg *config.Config
}
func NewGeminiMessagesCompatService(
@@ -51,6 +55,7 @@ func NewGeminiMessagesCompatService(
rateLimitService *RateLimitService,
httpUpstream HTTPUpstream,
antigravityGatewayService *AntigravityGatewayService,
cfg *config.Config,
) *GeminiMessagesCompatService {
return &GeminiMessagesCompatService{
accountRepo: accountRepo,
@@ -60,6 +65,7 @@ func NewGeminiMessagesCompatService(
rateLimitService: rateLimitService,
httpUpstream: httpUpstream,
antigravityGatewayService: antigravityGatewayService,
cfg: cfg,
}
}
@@ -209,6 +215,18 @@ func (s *GeminiMessagesCompatService) GetAntigravityGatewayService() *Antigravit
return s.antigravityGatewayService
}
func (s *GeminiMessagesCompatService) validateUpstreamBaseURL(raw string) (string, error) {
normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{
AllowedHosts: s.cfg.Security.URLAllowlist.UpstreamHosts,
RequireAllowlist: true,
AllowPrivate: s.cfg.Security.URLAllowlist.AllowPrivateHosts,
})
if err != nil {
return "", fmt.Errorf("invalid base_url: %w", err)
}
return normalized, nil
}
// HasAntigravityAccounts 检查是否有可用的 antigravity 账户
func (s *GeminiMessagesCompatService) HasAntigravityAccounts(ctx context.Context, groupID *int64) (bool, error) {
var accounts []Account
@@ -360,16 +378,20 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
return nil, "", errors.New("gemini api_key not configured")
}
baseURL := strings.TrimRight(account.GetCredential("base_url"), "/")
baseURL := strings.TrimSpace(account.GetCredential("base_url"))
if baseURL == "" {
baseURL = geminicli.AIStudioBaseURL
}
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, "", err
}
action := "generateContent"
if req.Stream {
action = "streamGenerateContent"
}
fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(baseURL, "/"), mappedModel, action)
fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(normalizedBaseURL, "/"), mappedModel, action)
if req.Stream {
fullURL += "?alt=sse"
}
@@ -406,7 +428,11 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
// 2. Without project_id -> AI Studio API (direct OAuth, like API key but with Bearer token)
if projectID != "" {
// Mode 1: Code Assist API
fullURL := fmt.Sprintf("%s/v1internal:%s", geminicli.GeminiCliBaseURL, action)
baseURL, err := s.validateUpstreamBaseURL(geminicli.GeminiCliBaseURL)
if err != nil {
return nil, "", err
}
fullURL := fmt.Sprintf("%s/v1internal:%s", strings.TrimRight(baseURL, "/"), action)
if useUpstreamStream {
fullURL += "?alt=sse"
}
@@ -432,12 +458,16 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
return upstreamReq, "x-request-id", nil
} else {
// Mode 2: AI Studio API with OAuth (like API key mode, but using Bearer token)
baseURL := strings.TrimRight(account.GetCredential("base_url"), "/")
baseURL := strings.TrimSpace(account.GetCredential("base_url"))
if baseURL == "" {
baseURL = geminicli.AIStudioBaseURL
}
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, "", err
}
fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", baseURL, mappedModel, action)
fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(normalizedBaseURL, "/"), mappedModel, action)
if useUpstreamStream {
fullURL += "?alt=sse"
}
@@ -622,12 +652,16 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
return nil, "", errors.New("gemini api_key not configured")
}
baseURL := strings.TrimRight(account.GetCredential("base_url"), "/")
baseURL := strings.TrimSpace(account.GetCredential("base_url"))
if baseURL == "" {
baseURL = geminicli.AIStudioBaseURL
}
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, "", err
}
fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(baseURL, "/"), mappedModel, upstreamAction)
fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(normalizedBaseURL, "/"), mappedModel, upstreamAction)
if useUpstreamStream {
fullURL += "?alt=sse"
}
@@ -659,7 +693,11 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
// 2. Without project_id -> AI Studio API (direct OAuth, like API key but with Bearer token)
if projectID != "" && !forceAIStudio {
// Mode 1: Code Assist API
fullURL := fmt.Sprintf("%s/v1internal:%s", geminicli.GeminiCliBaseURL, upstreamAction)
baseURL, err := s.validateUpstreamBaseURL(geminicli.GeminiCliBaseURL)
if err != nil {
return nil, "", err
}
fullURL := fmt.Sprintf("%s/v1internal:%s", strings.TrimRight(baseURL, "/"), upstreamAction)
if useUpstreamStream {
fullURL += "?alt=sse"
}
@@ -685,12 +723,16 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
return upstreamReq, "x-request-id", nil
} else {
// Mode 2: AI Studio API with OAuth (like API key mode, but using Bearer token)
baseURL := strings.TrimRight(account.GetCredential("base_url"), "/")
baseURL := strings.TrimSpace(account.GetCredential("base_url"))
if baseURL == "" {
baseURL = geminicli.AIStudioBaseURL
}
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, "", err
}
fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", baseURL, mappedModel, upstreamAction)
fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(normalizedBaseURL, "/"), mappedModel, upstreamAction)
if useUpstreamStream {
fullURL += "?alt=sse"
}
@@ -1608,6 +1650,8 @@ func (s *GeminiMessagesCompatService) handleNativeNonStreamingResponse(c *gin.Co
_ = json.Unmarshal(respBody, &parsed)
}
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
contentType := resp.Header.Get("Content-Type")
if contentType == "" {
contentType = "application/json"
@@ -1720,11 +1764,15 @@ func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, ac
return nil, errors.New("invalid path")
}
baseURL := strings.TrimRight(account.GetCredential("base_url"), "/")
baseURL := strings.TrimSpace(account.GetCredential("base_url"))
if baseURL == "" {
baseURL = geminicli.AIStudioBaseURL
}
fullURL := strings.TrimRight(baseURL, "/") + path
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, err
}
fullURL := strings.TrimRight(normalizedBaseURL, "/") + path
var proxyURL string
if account.ProxyID != nil && account.Proxy != nil {
@@ -1763,9 +1811,14 @@ func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, ac
defer func() { _ = resp.Body.Close() }()
body, _ := io.ReadAll(io.LimitReader(resp.Body, 8<<20))
wwwAuthenticate := resp.Header.Get("Www-Authenticate")
filteredHeaders := responseheaders.FilterHeaders(resp.Header, s.cfg.Security.ResponseHeaders)
if wwwAuthenticate != "" {
filteredHeaders.Set("Www-Authenticate", wwwAuthenticate)
}
return &UpstreamHTTPResult{
StatusCode: resp.StatusCode,
Headers: resp.Header.Clone(),
Headers: filteredHeaders,
Body: body,
}, nil
}

View File

@@ -18,6 +18,8 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/gin-gonic/gin"
)
@@ -370,10 +372,14 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
case AccountTypeApiKey:
// API Key accounts use Platform API or custom base URL
baseURL := account.GetOpenAIBaseURL()
if baseURL != "" {
targetURL = baseURL + "/responses"
} else {
if baseURL == "" {
targetURL = openaiPlatformAPIURL
} else {
validatedURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, err
}
targetURL = validatedURL + "/responses"
}
default:
targetURL = openaiPlatformAPIURL
@@ -645,18 +651,25 @@ func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, r
body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
}
// Pass through headers
for key, values := range resp.Header {
for _, value := range values {
c.Header(key, value)
}
}
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
c.Data(resp.StatusCode, "application/json", body)
return usage, nil
}
func (s *OpenAIGatewayService) validateUpstreamBaseURL(raw string) (string, error) {
normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{
AllowedHosts: s.cfg.Security.URLAllowlist.UpstreamHosts,
RequireAllowlist: true,
AllowPrivate: s.cfg.Security.URLAllowlist.AllowPrivateHosts,
})
if err != nil {
return "", fmt.Errorf("invalid base_url: %w", err)
}
return normalized, nil
}
func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel, toModel string) []byte {
var resp map[string]any
if err := json.Unmarshal(body, &resp); err != nil {

View File

@@ -16,6 +16,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
)
var (
@@ -211,16 +212,35 @@ func (s *PricingService) syncWithRemote() error {
// downloadPricingData 从远程下载价格数据
func (s *PricingService) downloadPricingData() error {
log.Printf("[Pricing] Downloading from %s", s.cfg.Pricing.RemoteURL)
remoteURL, err := s.validatePricingURL(s.cfg.Pricing.RemoteURL)
if err != nil {
return err
}
log.Printf("[Pricing] Downloading from %s", remoteURL)
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
body, err := s.remoteClient.FetchPricingJSON(ctx, s.cfg.Pricing.RemoteURL)
var expectedHash string
if strings.TrimSpace(s.cfg.Pricing.HashURL) != "" {
expectedHash, err = s.fetchRemoteHash()
if err != nil {
return fmt.Errorf("fetch remote hash: %w", err)
}
}
body, err := s.remoteClient.FetchPricingJSON(ctx, remoteURL)
if err != nil {
return fmt.Errorf("download failed: %w", err)
}
if expectedHash != "" {
actualHash := sha256.Sum256(body)
if !strings.EqualFold(expectedHash, hex.EncodeToString(actualHash[:])) {
return fmt.Errorf("pricing hash mismatch")
}
}
// 解析JSON数据使用灵活的解析方式
data, err := s.parsePricingData(body)
if err != nil {
@@ -373,10 +393,31 @@ func (s *PricingService) useFallbackPricing() error {
// fetchRemoteHash 从远程获取哈希值
func (s *PricingService) fetchRemoteHash() (string, error) {
hashURL, err := s.validatePricingURL(s.cfg.Pricing.HashURL)
if err != nil {
return "", err
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
return s.remoteClient.FetchHashText(ctx, s.cfg.Pricing.HashURL)
hash, err := s.remoteClient.FetchHashText(ctx, hashURL)
if err != nil {
return "", err
}
return strings.TrimSpace(hash), nil
}
func (s *PricingService) validatePricingURL(raw string) (string, error) {
normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{
AllowedHosts: s.cfg.Security.URLAllowlist.PricingHosts,
RequireAllowlist: true,
AllowPrivate: s.cfg.Security.URLAllowlist.AllowPrivateHosts,
})
if err != nil {
return "", fmt.Errorf("invalid pricing url: %w", err)
}
return normalized, nil
}
// computeFileHash 计算文件哈希

View File

@@ -215,8 +215,10 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
SmtpFrom: settings[SettingKeySmtpFrom],
SmtpFromName: settings[SettingKeySmtpFromName],
SmtpUseTLS: settings[SettingKeySmtpUseTLS] == "true",
SmtpPasswordConfigured: settings[SettingKeySmtpPassword] != "",
TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true",
TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey],
TurnstileSecretKeyConfigured: settings[SettingKeyTurnstileSecretKey] != "",
SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"),
SiteLogo: settings[SettingKeySiteLogo],
SiteSubtitle: s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"),
@@ -245,10 +247,6 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
result.DefaultBalance = s.cfg.Default.UserBalance
}
// 敏感信息直接返回,方便测试连接时使用
result.SmtpPassword = settings[SettingKeySmtpPassword]
result.TurnstileSecretKey = settings[SettingKeyTurnstileSecretKey]
return result
}

View File

@@ -8,6 +8,7 @@ type SystemSettings struct {
SmtpPort int
SmtpUsername string
SmtpPassword string
SmtpPasswordConfigured bool
SmtpFrom string
SmtpFromName string
SmtpUseTLS bool
@@ -15,6 +16,7 @@ type SystemSettings struct {
TurnstileEnabled bool
TurnstileSiteKey string
TurnstileSecretKey string
TurnstileSecretKeyConfigured bool
SiteName string
SiteLogo string