Merge branch 'main' into feat/api-key-ip-restriction
This commit is contained in:
@@ -1 +1 @@
|
||||
0.1.1
|
||||
0.1.46
|
||||
|
||||
@@ -53,7 +53,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
emailQueueService := service.ProvideEmailQueueService(emailService)
|
||||
authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService)
|
||||
userService := service.NewUserService(userRepository)
|
||||
authHandler := handler.NewAuthHandler(configConfig, authService, userService)
|
||||
authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService)
|
||||
userHandler := handler.NewUserHandler(userService)
|
||||
apiKeyRepository := repository.NewAPIKeyRepository(client)
|
||||
groupRepository := repository.NewGroupRepository(client, db)
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -35,24 +36,25 @@ const (
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Server ServerConfig `mapstructure:"server"`
|
||||
CORS CORSConfig `mapstructure:"cors"`
|
||||
Security SecurityConfig `mapstructure:"security"`
|
||||
Billing BillingConfig `mapstructure:"billing"`
|
||||
Turnstile TurnstileConfig `mapstructure:"turnstile"`
|
||||
Database DatabaseConfig `mapstructure:"database"`
|
||||
Redis RedisConfig `mapstructure:"redis"`
|
||||
JWT JWTConfig `mapstructure:"jwt"`
|
||||
Default DefaultConfig `mapstructure:"default"`
|
||||
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
|
||||
Pricing PricingConfig `mapstructure:"pricing"`
|
||||
Gateway GatewayConfig `mapstructure:"gateway"`
|
||||
Concurrency ConcurrencyConfig `mapstructure:"concurrency"`
|
||||
TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"`
|
||||
RunMode string `mapstructure:"run_mode" yaml:"run_mode"`
|
||||
Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC"
|
||||
Gemini GeminiConfig `mapstructure:"gemini"`
|
||||
Update UpdateConfig `mapstructure:"update"`
|
||||
Server ServerConfig `mapstructure:"server"`
|
||||
CORS CORSConfig `mapstructure:"cors"`
|
||||
Security SecurityConfig `mapstructure:"security"`
|
||||
Billing BillingConfig `mapstructure:"billing"`
|
||||
Turnstile TurnstileConfig `mapstructure:"turnstile"`
|
||||
Database DatabaseConfig `mapstructure:"database"`
|
||||
Redis RedisConfig `mapstructure:"redis"`
|
||||
JWT JWTConfig `mapstructure:"jwt"`
|
||||
LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"`
|
||||
Default DefaultConfig `mapstructure:"default"`
|
||||
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
|
||||
Pricing PricingConfig `mapstructure:"pricing"`
|
||||
Gateway GatewayConfig `mapstructure:"gateway"`
|
||||
Concurrency ConcurrencyConfig `mapstructure:"concurrency"`
|
||||
TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"`
|
||||
RunMode string `mapstructure:"run_mode" yaml:"run_mode"`
|
||||
Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC"
|
||||
Gemini GeminiConfig `mapstructure:"gemini"`
|
||||
Update UpdateConfig `mapstructure:"update"`
|
||||
}
|
||||
|
||||
// UpdateConfig 在线更新相关配置
|
||||
@@ -322,6 +324,30 @@ type TurnstileConfig struct {
|
||||
Required bool `mapstructure:"required"`
|
||||
}
|
||||
|
||||
// LinuxDoConnectConfig 用于 LinuxDo Connect OAuth 登录(终端用户 SSO)。
|
||||
//
|
||||
// 注意:这与上游账号的 OAuth(例如 OpenAI/Gemini 账号接入)不是一回事。
|
||||
// 这里是用于登录 Sub2API 本身的用户体系。
|
||||
type LinuxDoConnectConfig struct {
|
||||
Enabled bool `mapstructure:"enabled"`
|
||||
ClientID string `mapstructure:"client_id"`
|
||||
ClientSecret string `mapstructure:"client_secret"`
|
||||
AuthorizeURL string `mapstructure:"authorize_url"`
|
||||
TokenURL string `mapstructure:"token_url"`
|
||||
UserInfoURL string `mapstructure:"userinfo_url"`
|
||||
Scopes string `mapstructure:"scopes"`
|
||||
RedirectURL string `mapstructure:"redirect_url"` // 后端回调地址(需在提供方后台登记)
|
||||
FrontendRedirectURL string `mapstructure:"frontend_redirect_url"` // 前端接收 token 的路由(默认:/auth/linuxdo/callback)
|
||||
TokenAuthMethod string `mapstructure:"token_auth_method"` // client_secret_post / client_secret_basic / none
|
||||
UsePKCE bool `mapstructure:"use_pkce"`
|
||||
|
||||
// 可选:用于从 userinfo JSON 中提取字段的 gjson 路径。
|
||||
// 为空时,服务端会尝试一组常见字段名。
|
||||
UserInfoEmailPath string `mapstructure:"userinfo_email_path"`
|
||||
UserInfoIDPath string `mapstructure:"userinfo_id_path"`
|
||||
UserInfoUsernamePath string `mapstructure:"userinfo_username_path"`
|
||||
}
|
||||
|
||||
type DefaultConfig struct {
|
||||
AdminEmail string `mapstructure:"admin_email"`
|
||||
AdminPassword string `mapstructure:"admin_password"`
|
||||
@@ -388,6 +414,18 @@ func Load() (*Config, error) {
|
||||
cfg.Server.Mode = "debug"
|
||||
}
|
||||
cfg.JWT.Secret = strings.TrimSpace(cfg.JWT.Secret)
|
||||
cfg.LinuxDo.ClientID = strings.TrimSpace(cfg.LinuxDo.ClientID)
|
||||
cfg.LinuxDo.ClientSecret = strings.TrimSpace(cfg.LinuxDo.ClientSecret)
|
||||
cfg.LinuxDo.AuthorizeURL = strings.TrimSpace(cfg.LinuxDo.AuthorizeURL)
|
||||
cfg.LinuxDo.TokenURL = strings.TrimSpace(cfg.LinuxDo.TokenURL)
|
||||
cfg.LinuxDo.UserInfoURL = strings.TrimSpace(cfg.LinuxDo.UserInfoURL)
|
||||
cfg.LinuxDo.Scopes = strings.TrimSpace(cfg.LinuxDo.Scopes)
|
||||
cfg.LinuxDo.RedirectURL = strings.TrimSpace(cfg.LinuxDo.RedirectURL)
|
||||
cfg.LinuxDo.FrontendRedirectURL = strings.TrimSpace(cfg.LinuxDo.FrontendRedirectURL)
|
||||
cfg.LinuxDo.TokenAuthMethod = strings.ToLower(strings.TrimSpace(cfg.LinuxDo.TokenAuthMethod))
|
||||
cfg.LinuxDo.UserInfoEmailPath = strings.TrimSpace(cfg.LinuxDo.UserInfoEmailPath)
|
||||
cfg.LinuxDo.UserInfoIDPath = strings.TrimSpace(cfg.LinuxDo.UserInfoIDPath)
|
||||
cfg.LinuxDo.UserInfoUsernamePath = strings.TrimSpace(cfg.LinuxDo.UserInfoUsernamePath)
|
||||
cfg.CORS.AllowedOrigins = normalizeStringSlice(cfg.CORS.AllowedOrigins)
|
||||
cfg.Security.ResponseHeaders.AdditionalAllowed = normalizeStringSlice(cfg.Security.ResponseHeaders.AdditionalAllowed)
|
||||
cfg.Security.ResponseHeaders.ForceRemove = normalizeStringSlice(cfg.Security.ResponseHeaders.ForceRemove)
|
||||
@@ -426,6 +464,81 @@ func Load() (*Config, error) {
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
// ValidateAbsoluteHTTPURL 校验一个绝对 http(s) URL(禁止 fragment)。
|
||||
func ValidateAbsoluteHTTPURL(raw string) error {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return fmt.Errorf("empty url")
|
||||
}
|
||||
u, err := url.Parse(raw)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !u.IsAbs() {
|
||||
return fmt.Errorf("must be absolute")
|
||||
}
|
||||
if !isHTTPScheme(u.Scheme) {
|
||||
return fmt.Errorf("unsupported scheme: %s", u.Scheme)
|
||||
}
|
||||
if strings.TrimSpace(u.Host) == "" {
|
||||
return fmt.Errorf("missing host")
|
||||
}
|
||||
if u.Fragment != "" {
|
||||
return fmt.Errorf("must not include fragment")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateFrontendRedirectURL 校验前端回调地址:
|
||||
// - 允许同源相对路径(以 / 开头)
|
||||
// - 或绝对 http(s) URL(禁止 fragment)
|
||||
func ValidateFrontendRedirectURL(raw string) error {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return fmt.Errorf("empty url")
|
||||
}
|
||||
if strings.ContainsAny(raw, "\r\n") {
|
||||
return fmt.Errorf("contains invalid characters")
|
||||
}
|
||||
if strings.HasPrefix(raw, "/") {
|
||||
if strings.HasPrefix(raw, "//") {
|
||||
return fmt.Errorf("must not start with //")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
u, err := url.Parse(raw)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !u.IsAbs() {
|
||||
return fmt.Errorf("must be absolute http(s) url or relative path")
|
||||
}
|
||||
if !isHTTPScheme(u.Scheme) {
|
||||
return fmt.Errorf("unsupported scheme: %s", u.Scheme)
|
||||
}
|
||||
if strings.TrimSpace(u.Host) == "" {
|
||||
return fmt.Errorf("missing host")
|
||||
}
|
||||
if u.Fragment != "" {
|
||||
return fmt.Errorf("must not include fragment")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func isHTTPScheme(scheme string) bool {
|
||||
return strings.EqualFold(scheme, "http") || strings.EqualFold(scheme, "https")
|
||||
}
|
||||
|
||||
func warnIfInsecureURL(field, raw string) {
|
||||
u, err := url.Parse(strings.TrimSpace(raw))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if strings.EqualFold(u.Scheme, "http") {
|
||||
log.Printf("Warning: %s uses http scheme; use https in production to avoid token leakage.", field)
|
||||
}
|
||||
}
|
||||
|
||||
func setDefaults() {
|
||||
viper.SetDefault("run_mode", RunModeStandard)
|
||||
|
||||
@@ -475,6 +588,22 @@ func setDefaults() {
|
||||
// Turnstile
|
||||
viper.SetDefault("turnstile.required", false)
|
||||
|
||||
// LinuxDo Connect OAuth 登录(终端用户 SSO)
|
||||
viper.SetDefault("linuxdo_connect.enabled", false)
|
||||
viper.SetDefault("linuxdo_connect.client_id", "")
|
||||
viper.SetDefault("linuxdo_connect.client_secret", "")
|
||||
viper.SetDefault("linuxdo_connect.authorize_url", "https://connect.linux.do/oauth2/authorize")
|
||||
viper.SetDefault("linuxdo_connect.token_url", "https://connect.linux.do/oauth2/token")
|
||||
viper.SetDefault("linuxdo_connect.userinfo_url", "https://connect.linux.do/api/user")
|
||||
viper.SetDefault("linuxdo_connect.scopes", "user")
|
||||
viper.SetDefault("linuxdo_connect.redirect_url", "")
|
||||
viper.SetDefault("linuxdo_connect.frontend_redirect_url", "/auth/linuxdo/callback")
|
||||
viper.SetDefault("linuxdo_connect.token_auth_method", "client_secret_post")
|
||||
viper.SetDefault("linuxdo_connect.use_pkce", false)
|
||||
viper.SetDefault("linuxdo_connect.userinfo_email_path", "")
|
||||
viper.SetDefault("linuxdo_connect.userinfo_id_path", "")
|
||||
viper.SetDefault("linuxdo_connect.userinfo_username_path", "")
|
||||
|
||||
// Database
|
||||
viper.SetDefault("database.host", "localhost")
|
||||
viper.SetDefault("database.port", 5432)
|
||||
@@ -586,6 +715,60 @@ func (c *Config) Validate() error {
|
||||
if c.Security.CSP.Enabled && strings.TrimSpace(c.Security.CSP.Policy) == "" {
|
||||
return fmt.Errorf("security.csp.policy is required when CSP is enabled")
|
||||
}
|
||||
if c.LinuxDo.Enabled {
|
||||
if strings.TrimSpace(c.LinuxDo.ClientID) == "" {
|
||||
return fmt.Errorf("linuxdo_connect.client_id is required when linuxdo_connect.enabled=true")
|
||||
}
|
||||
if strings.TrimSpace(c.LinuxDo.AuthorizeURL) == "" {
|
||||
return fmt.Errorf("linuxdo_connect.authorize_url is required when linuxdo_connect.enabled=true")
|
||||
}
|
||||
if strings.TrimSpace(c.LinuxDo.TokenURL) == "" {
|
||||
return fmt.Errorf("linuxdo_connect.token_url is required when linuxdo_connect.enabled=true")
|
||||
}
|
||||
if strings.TrimSpace(c.LinuxDo.UserInfoURL) == "" {
|
||||
return fmt.Errorf("linuxdo_connect.userinfo_url is required when linuxdo_connect.enabled=true")
|
||||
}
|
||||
if strings.TrimSpace(c.LinuxDo.RedirectURL) == "" {
|
||||
return fmt.Errorf("linuxdo_connect.redirect_url is required when linuxdo_connect.enabled=true")
|
||||
}
|
||||
method := strings.ToLower(strings.TrimSpace(c.LinuxDo.TokenAuthMethod))
|
||||
switch method {
|
||||
case "", "client_secret_post", "client_secret_basic", "none":
|
||||
default:
|
||||
return fmt.Errorf("linuxdo_connect.token_auth_method must be one of: client_secret_post/client_secret_basic/none")
|
||||
}
|
||||
if method == "none" && !c.LinuxDo.UsePKCE {
|
||||
return fmt.Errorf("linuxdo_connect.use_pkce must be true when linuxdo_connect.token_auth_method=none")
|
||||
}
|
||||
if (method == "" || method == "client_secret_post" || method == "client_secret_basic") && strings.TrimSpace(c.LinuxDo.ClientSecret) == "" {
|
||||
return fmt.Errorf("linuxdo_connect.client_secret is required when linuxdo_connect.enabled=true and token_auth_method is client_secret_post/client_secret_basic")
|
||||
}
|
||||
if strings.TrimSpace(c.LinuxDo.FrontendRedirectURL) == "" {
|
||||
return fmt.Errorf("linuxdo_connect.frontend_redirect_url is required when linuxdo_connect.enabled=true")
|
||||
}
|
||||
|
||||
if err := ValidateAbsoluteHTTPURL(c.LinuxDo.AuthorizeURL); err != nil {
|
||||
return fmt.Errorf("linuxdo_connect.authorize_url invalid: %w", err)
|
||||
}
|
||||
if err := ValidateAbsoluteHTTPURL(c.LinuxDo.TokenURL); err != nil {
|
||||
return fmt.Errorf("linuxdo_connect.token_url invalid: %w", err)
|
||||
}
|
||||
if err := ValidateAbsoluteHTTPURL(c.LinuxDo.UserInfoURL); err != nil {
|
||||
return fmt.Errorf("linuxdo_connect.userinfo_url invalid: %w", err)
|
||||
}
|
||||
if err := ValidateAbsoluteHTTPURL(c.LinuxDo.RedirectURL); err != nil {
|
||||
return fmt.Errorf("linuxdo_connect.redirect_url invalid: %w", err)
|
||||
}
|
||||
if err := ValidateFrontendRedirectURL(c.LinuxDo.FrontendRedirectURL); err != nil {
|
||||
return fmt.Errorf("linuxdo_connect.frontend_redirect_url invalid: %w", err)
|
||||
}
|
||||
|
||||
warnIfInsecureURL("linuxdo_connect.authorize_url", c.LinuxDo.AuthorizeURL)
|
||||
warnIfInsecureURL("linuxdo_connect.token_url", c.LinuxDo.TokenURL)
|
||||
warnIfInsecureURL("linuxdo_connect.userinfo_url", c.LinuxDo.UserInfoURL)
|
||||
warnIfInsecureURL("linuxdo_connect.redirect_url", c.LinuxDo.RedirectURL)
|
||||
warnIfInsecureURL("linuxdo_connect.frontend_redirect_url", c.LinuxDo.FrontendRedirectURL)
|
||||
}
|
||||
if c.Billing.CircuitBreaker.Enabled {
|
||||
if c.Billing.CircuitBreaker.FailureThreshold <= 0 {
|
||||
return fmt.Errorf("billing.circuit_breaker.failure_threshold must be positive")
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -90,3 +91,53 @@ func TestLoadDefaultSecurityToggles(t *testing.T) {
|
||||
t.Fatalf("ResponseHeaders.Enabled = true, want false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateLinuxDoFrontendRedirectURL(t *testing.T) {
|
||||
viper.Reset()
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() error: %v", err)
|
||||
}
|
||||
|
||||
cfg.LinuxDo.Enabled = true
|
||||
cfg.LinuxDo.ClientID = "test-client"
|
||||
cfg.LinuxDo.ClientSecret = "test-secret"
|
||||
cfg.LinuxDo.RedirectURL = "https://example.com/api/v1/auth/oauth/linuxdo/callback"
|
||||
cfg.LinuxDo.TokenAuthMethod = "client_secret_post"
|
||||
cfg.LinuxDo.UsePKCE = false
|
||||
|
||||
cfg.LinuxDo.FrontendRedirectURL = "javascript:alert(1)"
|
||||
err = cfg.Validate()
|
||||
if err == nil {
|
||||
t.Fatalf("Validate() expected error for javascript scheme, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "linuxdo_connect.frontend_redirect_url") {
|
||||
t.Fatalf("Validate() expected frontend_redirect_url error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateLinuxDoPKCERequiredForPublicClient(t *testing.T) {
|
||||
viper.Reset()
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() error: %v", err)
|
||||
}
|
||||
|
||||
cfg.LinuxDo.Enabled = true
|
||||
cfg.LinuxDo.ClientID = "test-client"
|
||||
cfg.LinuxDo.ClientSecret = ""
|
||||
cfg.LinuxDo.RedirectURL = "https://example.com/api/v1/auth/oauth/linuxdo/callback"
|
||||
cfg.LinuxDo.FrontendRedirectURL = "/auth/linuxdo/callback"
|
||||
cfg.LinuxDo.TokenAuthMethod = "none"
|
||||
cfg.LinuxDo.UsePKCE = false
|
||||
|
||||
err = cfg.Validate()
|
||||
if err == nil {
|
||||
t.Fatalf("Validate() expected error when token_auth_method=none and use_pkce=false, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "linuxdo_connect.use_pkce") {
|
||||
t.Fatalf("Validate() expected use_pkce error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -116,6 +116,7 @@ type BulkUpdateAccountsRequest struct {
|
||||
Concurrency *int `json:"concurrency"`
|
||||
Priority *int `json:"priority"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active inactive error"`
|
||||
Schedulable *bool `json:"schedulable"`
|
||||
GroupIDs *[]int64 `json:"group_ids"`
|
||||
Credentials map[string]any `json:"credentials"`
|
||||
Extra map[string]any `json:"extra"`
|
||||
@@ -136,6 +137,11 @@ func (h *AccountHandler) List(c *gin.Context) {
|
||||
accountType := c.Query("type")
|
||||
status := c.Query("status")
|
||||
search := c.Query("search")
|
||||
// 标准化和验证 search 参数
|
||||
search = strings.TrimSpace(search)
|
||||
if len(search) > 100 {
|
||||
search = search[:100]
|
||||
}
|
||||
|
||||
accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search)
|
||||
if err != nil {
|
||||
@@ -655,6 +661,7 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
|
||||
req.Concurrency != nil ||
|
||||
req.Priority != nil ||
|
||||
req.Status != "" ||
|
||||
req.Schedulable != nil ||
|
||||
req.GroupIDs != nil ||
|
||||
len(req.Credentials) > 0 ||
|
||||
len(req.Extra) > 0
|
||||
@@ -671,6 +678,7 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
|
||||
Concurrency: req.Concurrency,
|
||||
Priority: req.Priority,
|
||||
Status: req.Status,
|
||||
Schedulable: req.Schedulable,
|
||||
GroupIDs: req.GroupIDs,
|
||||
Credentials: req.Credentials,
|
||||
Extra: req.Extra,
|
||||
|
||||
@@ -2,6 +2,7 @@ package admin
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
@@ -67,6 +68,12 @@ func (h *GroupHandler) List(c *gin.Context) {
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
platform := c.Query("platform")
|
||||
status := c.Query("status")
|
||||
search := c.Query("search")
|
||||
// 标准化和验证 search 参数
|
||||
search = strings.TrimSpace(search)
|
||||
if len(search) > 100 {
|
||||
search = search[:100]
|
||||
}
|
||||
isExclusiveStr := c.Query("is_exclusive")
|
||||
|
||||
var isExclusive *bool
|
||||
@@ -75,7 +82,7 @@ func (h *GroupHandler) List(c *gin.Context) {
|
||||
isExclusive = &val
|
||||
}
|
||||
|
||||
groups, total, err := h.adminService.ListGroups(c.Request.Context(), page, pageSize, platform, status, isExclusive)
|
||||
groups, total, err := h.adminService.ListGroups(c.Request.Context(), page, pageSize, platform, status, search, isExclusive)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
|
||||
@@ -51,6 +51,11 @@ func (h *ProxyHandler) List(c *gin.Context) {
|
||||
protocol := c.Query("protocol")
|
||||
status := c.Query("status")
|
||||
search := c.Query("search")
|
||||
// 标准化和验证 search 参数
|
||||
search = strings.TrimSpace(search)
|
||||
if len(search) > 100 {
|
||||
search = search[:100]
|
||||
}
|
||||
|
||||
proxies, total, err := h.adminService.ListProxiesWithAccountCount(c.Request.Context(), page, pageSize, protocol, status, search)
|
||||
if err != nil {
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"encoding/csv"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
@@ -41,6 +42,11 @@ func (h *RedeemHandler) List(c *gin.Context) {
|
||||
codeType := c.Query("type")
|
||||
status := c.Query("status")
|
||||
search := c.Query("search")
|
||||
// 标准化和验证 search 参数
|
||||
search = strings.TrimSpace(search)
|
||||
if len(search) > 100 {
|
||||
search = search[:100]
|
||||
}
|
||||
|
||||
codes, total, err := h.adminService.ListRedeemCodes(c.Request.Context(), page, pageSize, codeType, status, search)
|
||||
if err != nil {
|
||||
|
||||
@@ -2,8 +2,10 @@ package admin
|
||||
|
||||
import (
|
||||
"log"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
@@ -38,33 +40,37 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
}
|
||||
|
||||
response.Success(c, dto.SystemSettings{
|
||||
RegistrationEnabled: settings.RegistrationEnabled,
|
||||
EmailVerifyEnabled: settings.EmailVerifyEnabled,
|
||||
SMTPHost: settings.SMTPHost,
|
||||
SMTPPort: settings.SMTPPort,
|
||||
SMTPUsername: settings.SMTPUsername,
|
||||
SMTPPasswordConfigured: settings.SMTPPasswordConfigured,
|
||||
SMTPFrom: settings.SMTPFrom,
|
||||
SMTPFromName: settings.SMTPFromName,
|
||||
SMTPUseTLS: settings.SMTPUseTLS,
|
||||
TurnstileEnabled: settings.TurnstileEnabled,
|
||||
TurnstileSiteKey: settings.TurnstileSiteKey,
|
||||
TurnstileSecretKeyConfigured: settings.TurnstileSecretKeyConfigured,
|
||||
SiteName: settings.SiteName,
|
||||
SiteLogo: settings.SiteLogo,
|
||||
SiteSubtitle: settings.SiteSubtitle,
|
||||
APIBaseURL: settings.APIBaseURL,
|
||||
ContactInfo: settings.ContactInfo,
|
||||
DocURL: settings.DocURL,
|
||||
DefaultConcurrency: settings.DefaultConcurrency,
|
||||
DefaultBalance: settings.DefaultBalance,
|
||||
EnableModelFallback: settings.EnableModelFallback,
|
||||
FallbackModelAnthropic: settings.FallbackModelAnthropic,
|
||||
FallbackModelOpenAI: settings.FallbackModelOpenAI,
|
||||
FallbackModelGemini: settings.FallbackModelGemini,
|
||||
FallbackModelAntigravity: settings.FallbackModelAntigravity,
|
||||
EnableIdentityPatch: settings.EnableIdentityPatch,
|
||||
IdentityPatchPrompt: settings.IdentityPatchPrompt,
|
||||
RegistrationEnabled: settings.RegistrationEnabled,
|
||||
EmailVerifyEnabled: settings.EmailVerifyEnabled,
|
||||
SMTPHost: settings.SMTPHost,
|
||||
SMTPPort: settings.SMTPPort,
|
||||
SMTPUsername: settings.SMTPUsername,
|
||||
SMTPPasswordConfigured: settings.SMTPPasswordConfigured,
|
||||
SMTPFrom: settings.SMTPFrom,
|
||||
SMTPFromName: settings.SMTPFromName,
|
||||
SMTPUseTLS: settings.SMTPUseTLS,
|
||||
TurnstileEnabled: settings.TurnstileEnabled,
|
||||
TurnstileSiteKey: settings.TurnstileSiteKey,
|
||||
TurnstileSecretKeyConfigured: settings.TurnstileSecretKeyConfigured,
|
||||
LinuxDoConnectEnabled: settings.LinuxDoConnectEnabled,
|
||||
LinuxDoConnectClientID: settings.LinuxDoConnectClientID,
|
||||
LinuxDoConnectClientSecretConfigured: settings.LinuxDoConnectClientSecretConfigured,
|
||||
LinuxDoConnectRedirectURL: settings.LinuxDoConnectRedirectURL,
|
||||
SiteName: settings.SiteName,
|
||||
SiteLogo: settings.SiteLogo,
|
||||
SiteSubtitle: settings.SiteSubtitle,
|
||||
APIBaseURL: settings.APIBaseURL,
|
||||
ContactInfo: settings.ContactInfo,
|
||||
DocURL: settings.DocURL,
|
||||
DefaultConcurrency: settings.DefaultConcurrency,
|
||||
DefaultBalance: settings.DefaultBalance,
|
||||
EnableModelFallback: settings.EnableModelFallback,
|
||||
FallbackModelAnthropic: settings.FallbackModelAnthropic,
|
||||
FallbackModelOpenAI: settings.FallbackModelOpenAI,
|
||||
FallbackModelGemini: settings.FallbackModelGemini,
|
||||
FallbackModelAntigravity: settings.FallbackModelAntigravity,
|
||||
EnableIdentityPatch: settings.EnableIdentityPatch,
|
||||
IdentityPatchPrompt: settings.IdentityPatchPrompt,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -88,6 +94,12 @@ type UpdateSettingsRequest struct {
|
||||
TurnstileSiteKey string `json:"turnstile_site_key"`
|
||||
TurnstileSecretKey string `json:"turnstile_secret_key"`
|
||||
|
||||
// LinuxDo Connect OAuth 登录(终端用户 SSO)
|
||||
LinuxDoConnectEnabled bool `json:"linuxdo_connect_enabled"`
|
||||
LinuxDoConnectClientID string `json:"linuxdo_connect_client_id"`
|
||||
LinuxDoConnectClientSecret string `json:"linuxdo_connect_client_secret"`
|
||||
LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"`
|
||||
|
||||
// OEM设置
|
||||
SiteName string `json:"site_name"`
|
||||
SiteLogo string `json:"site_logo"`
|
||||
@@ -165,34 +177,67 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// LinuxDo Connect 参数验证
|
||||
if req.LinuxDoConnectEnabled {
|
||||
req.LinuxDoConnectClientID = strings.TrimSpace(req.LinuxDoConnectClientID)
|
||||
req.LinuxDoConnectClientSecret = strings.TrimSpace(req.LinuxDoConnectClientSecret)
|
||||
req.LinuxDoConnectRedirectURL = strings.TrimSpace(req.LinuxDoConnectRedirectURL)
|
||||
|
||||
if req.LinuxDoConnectClientID == "" {
|
||||
response.BadRequest(c, "LinuxDo Client ID is required when enabled")
|
||||
return
|
||||
}
|
||||
if req.LinuxDoConnectRedirectURL == "" {
|
||||
response.BadRequest(c, "LinuxDo Redirect URL is required when enabled")
|
||||
return
|
||||
}
|
||||
if err := config.ValidateAbsoluteHTTPURL(req.LinuxDoConnectRedirectURL); err != nil {
|
||||
response.BadRequest(c, "LinuxDo Redirect URL must be an absolute http(s) URL")
|
||||
return
|
||||
}
|
||||
|
||||
// 如果未提供 client_secret,则保留现有值(如有)。
|
||||
if req.LinuxDoConnectClientSecret == "" {
|
||||
if previousSettings.LinuxDoConnectClientSecret == "" {
|
||||
response.BadRequest(c, "LinuxDo Client Secret is required when enabled")
|
||||
return
|
||||
}
|
||||
req.LinuxDoConnectClientSecret = previousSettings.LinuxDoConnectClientSecret
|
||||
}
|
||||
}
|
||||
|
||||
settings := &service.SystemSettings{
|
||||
RegistrationEnabled: req.RegistrationEnabled,
|
||||
EmailVerifyEnabled: req.EmailVerifyEnabled,
|
||||
SMTPHost: req.SMTPHost,
|
||||
SMTPPort: req.SMTPPort,
|
||||
SMTPUsername: req.SMTPUsername,
|
||||
SMTPPassword: req.SMTPPassword,
|
||||
SMTPFrom: req.SMTPFrom,
|
||||
SMTPFromName: req.SMTPFromName,
|
||||
SMTPUseTLS: req.SMTPUseTLS,
|
||||
TurnstileEnabled: req.TurnstileEnabled,
|
||||
TurnstileSiteKey: req.TurnstileSiteKey,
|
||||
TurnstileSecretKey: req.TurnstileSecretKey,
|
||||
SiteName: req.SiteName,
|
||||
SiteLogo: req.SiteLogo,
|
||||
SiteSubtitle: req.SiteSubtitle,
|
||||
APIBaseURL: req.APIBaseURL,
|
||||
ContactInfo: req.ContactInfo,
|
||||
DocURL: req.DocURL,
|
||||
DefaultConcurrency: req.DefaultConcurrency,
|
||||
DefaultBalance: req.DefaultBalance,
|
||||
EnableModelFallback: req.EnableModelFallback,
|
||||
FallbackModelAnthropic: req.FallbackModelAnthropic,
|
||||
FallbackModelOpenAI: req.FallbackModelOpenAI,
|
||||
FallbackModelGemini: req.FallbackModelGemini,
|
||||
FallbackModelAntigravity: req.FallbackModelAntigravity,
|
||||
EnableIdentityPatch: req.EnableIdentityPatch,
|
||||
IdentityPatchPrompt: req.IdentityPatchPrompt,
|
||||
RegistrationEnabled: req.RegistrationEnabled,
|
||||
EmailVerifyEnabled: req.EmailVerifyEnabled,
|
||||
SMTPHost: req.SMTPHost,
|
||||
SMTPPort: req.SMTPPort,
|
||||
SMTPUsername: req.SMTPUsername,
|
||||
SMTPPassword: req.SMTPPassword,
|
||||
SMTPFrom: req.SMTPFrom,
|
||||
SMTPFromName: req.SMTPFromName,
|
||||
SMTPUseTLS: req.SMTPUseTLS,
|
||||
TurnstileEnabled: req.TurnstileEnabled,
|
||||
TurnstileSiteKey: req.TurnstileSiteKey,
|
||||
TurnstileSecretKey: req.TurnstileSecretKey,
|
||||
LinuxDoConnectEnabled: req.LinuxDoConnectEnabled,
|
||||
LinuxDoConnectClientID: req.LinuxDoConnectClientID,
|
||||
LinuxDoConnectClientSecret: req.LinuxDoConnectClientSecret,
|
||||
LinuxDoConnectRedirectURL: req.LinuxDoConnectRedirectURL,
|
||||
SiteName: req.SiteName,
|
||||
SiteLogo: req.SiteLogo,
|
||||
SiteSubtitle: req.SiteSubtitle,
|
||||
APIBaseURL: req.APIBaseURL,
|
||||
ContactInfo: req.ContactInfo,
|
||||
DocURL: req.DocURL,
|
||||
DefaultConcurrency: req.DefaultConcurrency,
|
||||
DefaultBalance: req.DefaultBalance,
|
||||
EnableModelFallback: req.EnableModelFallback,
|
||||
FallbackModelAnthropic: req.FallbackModelAnthropic,
|
||||
FallbackModelOpenAI: req.FallbackModelOpenAI,
|
||||
FallbackModelGemini: req.FallbackModelGemini,
|
||||
FallbackModelAntigravity: req.FallbackModelAntigravity,
|
||||
EnableIdentityPatch: req.EnableIdentityPatch,
|
||||
IdentityPatchPrompt: req.IdentityPatchPrompt,
|
||||
}
|
||||
|
||||
if err := h.settingService.UpdateSettings(c.Request.Context(), settings); err != nil {
|
||||
@@ -210,33 +255,37 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
}
|
||||
|
||||
response.Success(c, dto.SystemSettings{
|
||||
RegistrationEnabled: updatedSettings.RegistrationEnabled,
|
||||
EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled,
|
||||
SMTPHost: updatedSettings.SMTPHost,
|
||||
SMTPPort: updatedSettings.SMTPPort,
|
||||
SMTPUsername: updatedSettings.SMTPUsername,
|
||||
SMTPPasswordConfigured: updatedSettings.SMTPPasswordConfigured,
|
||||
SMTPFrom: updatedSettings.SMTPFrom,
|
||||
SMTPFromName: updatedSettings.SMTPFromName,
|
||||
SMTPUseTLS: updatedSettings.SMTPUseTLS,
|
||||
TurnstileEnabled: updatedSettings.TurnstileEnabled,
|
||||
TurnstileSiteKey: updatedSettings.TurnstileSiteKey,
|
||||
TurnstileSecretKeyConfigured: updatedSettings.TurnstileSecretKeyConfigured,
|
||||
SiteName: updatedSettings.SiteName,
|
||||
SiteLogo: updatedSettings.SiteLogo,
|
||||
SiteSubtitle: updatedSettings.SiteSubtitle,
|
||||
APIBaseURL: updatedSettings.APIBaseURL,
|
||||
ContactInfo: updatedSettings.ContactInfo,
|
||||
DocURL: updatedSettings.DocURL,
|
||||
DefaultConcurrency: updatedSettings.DefaultConcurrency,
|
||||
DefaultBalance: updatedSettings.DefaultBalance,
|
||||
EnableModelFallback: updatedSettings.EnableModelFallback,
|
||||
FallbackModelAnthropic: updatedSettings.FallbackModelAnthropic,
|
||||
FallbackModelOpenAI: updatedSettings.FallbackModelOpenAI,
|
||||
FallbackModelGemini: updatedSettings.FallbackModelGemini,
|
||||
FallbackModelAntigravity: updatedSettings.FallbackModelAntigravity,
|
||||
EnableIdentityPatch: updatedSettings.EnableIdentityPatch,
|
||||
IdentityPatchPrompt: updatedSettings.IdentityPatchPrompt,
|
||||
RegistrationEnabled: updatedSettings.RegistrationEnabled,
|
||||
EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled,
|
||||
SMTPHost: updatedSettings.SMTPHost,
|
||||
SMTPPort: updatedSettings.SMTPPort,
|
||||
SMTPUsername: updatedSettings.SMTPUsername,
|
||||
SMTPPasswordConfigured: updatedSettings.SMTPPasswordConfigured,
|
||||
SMTPFrom: updatedSettings.SMTPFrom,
|
||||
SMTPFromName: updatedSettings.SMTPFromName,
|
||||
SMTPUseTLS: updatedSettings.SMTPUseTLS,
|
||||
TurnstileEnabled: updatedSettings.TurnstileEnabled,
|
||||
TurnstileSiteKey: updatedSettings.TurnstileSiteKey,
|
||||
TurnstileSecretKeyConfigured: updatedSettings.TurnstileSecretKeyConfigured,
|
||||
LinuxDoConnectEnabled: updatedSettings.LinuxDoConnectEnabled,
|
||||
LinuxDoConnectClientID: updatedSettings.LinuxDoConnectClientID,
|
||||
LinuxDoConnectClientSecretConfigured: updatedSettings.LinuxDoConnectClientSecretConfigured,
|
||||
LinuxDoConnectRedirectURL: updatedSettings.LinuxDoConnectRedirectURL,
|
||||
SiteName: updatedSettings.SiteName,
|
||||
SiteLogo: updatedSettings.SiteLogo,
|
||||
SiteSubtitle: updatedSettings.SiteSubtitle,
|
||||
APIBaseURL: updatedSettings.APIBaseURL,
|
||||
ContactInfo: updatedSettings.ContactInfo,
|
||||
DocURL: updatedSettings.DocURL,
|
||||
DefaultConcurrency: updatedSettings.DefaultConcurrency,
|
||||
DefaultBalance: updatedSettings.DefaultBalance,
|
||||
EnableModelFallback: updatedSettings.EnableModelFallback,
|
||||
FallbackModelAnthropic: updatedSettings.FallbackModelAnthropic,
|
||||
FallbackModelOpenAI: updatedSettings.FallbackModelOpenAI,
|
||||
FallbackModelGemini: updatedSettings.FallbackModelGemini,
|
||||
FallbackModelAntigravity: updatedSettings.FallbackModelAntigravity,
|
||||
EnableIdentityPatch: updatedSettings.EnableIdentityPatch,
|
||||
IdentityPatchPrompt: updatedSettings.IdentityPatchPrompt,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -298,6 +347,18 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
||||
if req.TurnstileSecretKey != "" {
|
||||
changed = append(changed, "turnstile_secret_key")
|
||||
}
|
||||
if before.LinuxDoConnectEnabled != after.LinuxDoConnectEnabled {
|
||||
changed = append(changed, "linuxdo_connect_enabled")
|
||||
}
|
||||
if before.LinuxDoConnectClientID != after.LinuxDoConnectClientID {
|
||||
changed = append(changed, "linuxdo_connect_client_id")
|
||||
}
|
||||
if req.LinuxDoConnectClientSecret != "" {
|
||||
changed = append(changed, "linuxdo_connect_client_secret")
|
||||
}
|
||||
if before.LinuxDoConnectRedirectURL != after.LinuxDoConnectRedirectURL {
|
||||
changed = append(changed, "linuxdo_connect_redirect_url")
|
||||
}
|
||||
if before.SiteName != after.SiteName {
|
||||
changed = append(changed, "site_name")
|
||||
}
|
||||
@@ -337,6 +398,12 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
||||
if before.FallbackModelAntigravity != after.FallbackModelAntigravity {
|
||||
changed = append(changed, "fallback_model_antigravity")
|
||||
}
|
||||
if before.EnableIdentityPatch != after.EnableIdentityPatch {
|
||||
changed = append(changed, "enable_identity_patch")
|
||||
}
|
||||
if before.IdentityPatchPrompt != after.IdentityPatchPrompt {
|
||||
changed = append(changed, "identity_patch_prompt")
|
||||
}
|
||||
return changed
|
||||
}
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ package admin
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
@@ -63,10 +64,17 @@ type UpdateBalanceRequest struct {
|
||||
func (h *UserHandler) List(c *gin.Context) {
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
|
||||
search := c.Query("search")
|
||||
// 标准化和验证 search 参数
|
||||
search = strings.TrimSpace(search)
|
||||
if len(search) > 100 {
|
||||
search = search[:100]
|
||||
}
|
||||
|
||||
filters := service.UserListFilters{
|
||||
Status: c.Query("status"),
|
||||
Role: c.Query("role"),
|
||||
Search: c.Query("search"),
|
||||
Search: search,
|
||||
Attributes: parseAttributeFilters(c),
|
||||
}
|
||||
|
||||
|
||||
@@ -15,14 +15,16 @@ type AuthHandler struct {
|
||||
cfg *config.Config
|
||||
authService *service.AuthService
|
||||
userService *service.UserService
|
||||
settingSvc *service.SettingService
|
||||
}
|
||||
|
||||
// NewAuthHandler creates a new AuthHandler
|
||||
func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService) *AuthHandler {
|
||||
func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService, settingService *service.SettingService) *AuthHandler {
|
||||
return &AuthHandler{
|
||||
cfg: cfg,
|
||||
authService: authService,
|
||||
userService: userService,
|
||||
settingSvc: settingService,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
679
backend/internal/handler/auth_linuxdo_oauth.go
Normal file
679
backend/internal/handler/auth_linuxdo_oauth.go
Normal file
@@ -0,0 +1,679 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/imroc/req/v3"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
const (
|
||||
linuxDoOAuthCookiePath = "/api/v1/auth/oauth/linuxdo"
|
||||
linuxDoOAuthStateCookieName = "linuxdo_oauth_state"
|
||||
linuxDoOAuthVerifierCookie = "linuxdo_oauth_verifier"
|
||||
linuxDoOAuthRedirectCookie = "linuxdo_oauth_redirect"
|
||||
linuxDoOAuthCookieMaxAgeSec = 10 * 60 // 10 minutes
|
||||
linuxDoOAuthDefaultRedirectTo = "/dashboard"
|
||||
linuxDoOAuthDefaultFrontendCB = "/auth/linuxdo/callback"
|
||||
|
||||
linuxDoOAuthMaxRedirectLen = 2048
|
||||
linuxDoOAuthMaxFragmentValueLen = 512
|
||||
linuxDoOAuthMaxSubjectLen = 64 - len("linuxdo-")
|
||||
)
|
||||
|
||||
type linuxDoTokenResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
RefreshToken string `json:"refresh_token,omitempty"`
|
||||
Scope string `json:"scope,omitempty"`
|
||||
}
|
||||
|
||||
type linuxDoTokenExchangeError struct {
|
||||
StatusCode int
|
||||
ProviderError string
|
||||
ProviderDescription string
|
||||
Body string
|
||||
}
|
||||
|
||||
func (e *linuxDoTokenExchangeError) Error() string {
|
||||
if e == nil {
|
||||
return ""
|
||||
}
|
||||
parts := []string{fmt.Sprintf("token exchange status=%d", e.StatusCode)}
|
||||
if strings.TrimSpace(e.ProviderError) != "" {
|
||||
parts = append(parts, "error="+strings.TrimSpace(e.ProviderError))
|
||||
}
|
||||
if strings.TrimSpace(e.ProviderDescription) != "" {
|
||||
parts = append(parts, "error_description="+strings.TrimSpace(e.ProviderDescription))
|
||||
}
|
||||
return strings.Join(parts, " ")
|
||||
}
|
||||
|
||||
// LinuxDoOAuthStart 启动 LinuxDo Connect OAuth 登录流程。
|
||||
// GET /api/v1/auth/oauth/linuxdo/start?redirect=/dashboard
|
||||
func (h *AuthHandler) LinuxDoOAuthStart(c *gin.Context) {
|
||||
cfg, err := h.getLinuxDoOAuthConfig(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
state, err := oauth.GenerateState()
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_STATE_GEN_FAILED", "failed to generate oauth state").WithCause(err))
|
||||
return
|
||||
}
|
||||
|
||||
redirectTo := sanitizeFrontendRedirectPath(c.Query("redirect"))
|
||||
if redirectTo == "" {
|
||||
redirectTo = linuxDoOAuthDefaultRedirectTo
|
||||
}
|
||||
|
||||
secureCookie := isRequestHTTPS(c)
|
||||
setCookie(c, linuxDoOAuthStateCookieName, encodeCookieValue(state), linuxDoOAuthCookieMaxAgeSec, secureCookie)
|
||||
setCookie(c, linuxDoOAuthRedirectCookie, encodeCookieValue(redirectTo), linuxDoOAuthCookieMaxAgeSec, secureCookie)
|
||||
|
||||
codeChallenge := ""
|
||||
if cfg.UsePKCE {
|
||||
verifier, err := oauth.GenerateCodeVerifier()
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_PKCE_GEN_FAILED", "failed to generate pkce verifier").WithCause(err))
|
||||
return
|
||||
}
|
||||
codeChallenge = oauth.GenerateCodeChallenge(verifier)
|
||||
setCookie(c, linuxDoOAuthVerifierCookie, encodeCookieValue(verifier), linuxDoOAuthCookieMaxAgeSec, secureCookie)
|
||||
}
|
||||
|
||||
redirectURI := strings.TrimSpace(cfg.RedirectURL)
|
||||
if redirectURI == "" {
|
||||
response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth redirect url not configured"))
|
||||
return
|
||||
}
|
||||
|
||||
authURL, err := buildLinuxDoAuthorizeURL(cfg, state, codeChallenge, redirectURI)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_BUILD_URL_FAILED", "failed to build oauth authorization url").WithCause(err))
|
||||
return
|
||||
}
|
||||
|
||||
c.Redirect(http.StatusFound, authURL)
|
||||
}
|
||||
|
||||
// LinuxDoOAuthCallback 处理 OAuth 回调:创建/登录用户,然后重定向到前端。
|
||||
// GET /api/v1/auth/oauth/linuxdo/callback?code=...&state=...
|
||||
func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
|
||||
cfg, cfgErr := h.getLinuxDoOAuthConfig(c.Request.Context())
|
||||
if cfgErr != nil {
|
||||
response.ErrorFrom(c, cfgErr)
|
||||
return
|
||||
}
|
||||
|
||||
frontendCallback := strings.TrimSpace(cfg.FrontendRedirectURL)
|
||||
if frontendCallback == "" {
|
||||
frontendCallback = linuxDoOAuthDefaultFrontendCB
|
||||
}
|
||||
|
||||
if providerErr := strings.TrimSpace(c.Query("error")); providerErr != "" {
|
||||
redirectOAuthError(c, frontendCallback, "provider_error", providerErr, c.Query("error_description"))
|
||||
return
|
||||
}
|
||||
|
||||
code := strings.TrimSpace(c.Query("code"))
|
||||
state := strings.TrimSpace(c.Query("state"))
|
||||
if code == "" || state == "" {
|
||||
redirectOAuthError(c, frontendCallback, "missing_params", "missing code/state", "")
|
||||
return
|
||||
}
|
||||
|
||||
secureCookie := isRequestHTTPS(c)
|
||||
defer func() {
|
||||
clearCookie(c, linuxDoOAuthStateCookieName, secureCookie)
|
||||
clearCookie(c, linuxDoOAuthVerifierCookie, secureCookie)
|
||||
clearCookie(c, linuxDoOAuthRedirectCookie, secureCookie)
|
||||
}()
|
||||
|
||||
expectedState, err := readCookieDecoded(c, linuxDoOAuthStateCookieName)
|
||||
if err != nil || expectedState == "" || state != expectedState {
|
||||
redirectOAuthError(c, frontendCallback, "invalid_state", "invalid oauth state", "")
|
||||
return
|
||||
}
|
||||
|
||||
redirectTo, _ := readCookieDecoded(c, linuxDoOAuthRedirectCookie)
|
||||
redirectTo = sanitizeFrontendRedirectPath(redirectTo)
|
||||
if redirectTo == "" {
|
||||
redirectTo = linuxDoOAuthDefaultRedirectTo
|
||||
}
|
||||
|
||||
codeVerifier := ""
|
||||
if cfg.UsePKCE {
|
||||
codeVerifier, _ = readCookieDecoded(c, linuxDoOAuthVerifierCookie)
|
||||
if codeVerifier == "" {
|
||||
redirectOAuthError(c, frontendCallback, "missing_verifier", "missing pkce verifier", "")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
redirectURI := strings.TrimSpace(cfg.RedirectURL)
|
||||
if redirectURI == "" {
|
||||
redirectOAuthError(c, frontendCallback, "config_error", "oauth redirect url not configured", "")
|
||||
return
|
||||
}
|
||||
|
||||
tokenResp, err := linuxDoExchangeCode(c.Request.Context(), cfg, code, redirectURI, codeVerifier)
|
||||
if err != nil {
|
||||
description := ""
|
||||
var exchangeErr *linuxDoTokenExchangeError
|
||||
if errors.As(err, &exchangeErr) && exchangeErr != nil {
|
||||
log.Printf(
|
||||
"[LinuxDo OAuth] token exchange failed: status=%d provider_error=%q provider_description=%q body=%s",
|
||||
exchangeErr.StatusCode,
|
||||
exchangeErr.ProviderError,
|
||||
exchangeErr.ProviderDescription,
|
||||
truncateLogValue(exchangeErr.Body, 2048),
|
||||
)
|
||||
description = exchangeErr.Error()
|
||||
} else {
|
||||
log.Printf("[LinuxDo OAuth] token exchange failed: %v", err)
|
||||
description = err.Error()
|
||||
}
|
||||
redirectOAuthError(c, frontendCallback, "token_exchange_failed", "failed to exchange oauth code", singleLine(description))
|
||||
return
|
||||
}
|
||||
|
||||
email, username, subject, err := linuxDoFetchUserInfo(c.Request.Context(), cfg, tokenResp)
|
||||
if err != nil {
|
||||
log.Printf("[LinuxDo OAuth] userinfo fetch failed: %v", err)
|
||||
redirectOAuthError(c, frontendCallback, "userinfo_failed", "failed to fetch user info", "")
|
||||
return
|
||||
}
|
||||
|
||||
// 安全考虑:不要把第三方返回的 email 直接映射到本地账号(可能与本地邮箱用户冲突导致账号被接管)。
|
||||
// 统一使用基于 subject 的稳定合成邮箱来做账号绑定。
|
||||
if subject != "" {
|
||||
email = linuxDoSyntheticEmail(subject)
|
||||
}
|
||||
|
||||
jwtToken, _, err := h.authService.LoginOrRegisterOAuth(c.Request.Context(), email, username)
|
||||
if err != nil {
|
||||
// 避免把内部细节泄露给客户端;给前端保留结构化原因与提示信息即可。
|
||||
redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err))
|
||||
return
|
||||
}
|
||||
|
||||
fragment := url.Values{}
|
||||
fragment.Set("access_token", jwtToken)
|
||||
fragment.Set("token_type", "Bearer")
|
||||
fragment.Set("redirect", redirectTo)
|
||||
redirectWithFragment(c, frontendCallback, fragment)
|
||||
}
|
||||
|
||||
func (h *AuthHandler) getLinuxDoOAuthConfig(ctx context.Context) (config.LinuxDoConnectConfig, error) {
|
||||
if h != nil && h.settingSvc != nil {
|
||||
return h.settingSvc.GetLinuxDoConnectOAuthConfig(ctx)
|
||||
}
|
||||
if h == nil || h.cfg == nil {
|
||||
return config.LinuxDoConnectConfig{}, infraerrors.ServiceUnavailable("CONFIG_NOT_READY", "config not loaded")
|
||||
}
|
||||
if !h.cfg.LinuxDo.Enabled {
|
||||
return config.LinuxDoConnectConfig{}, infraerrors.NotFound("OAUTH_DISABLED", "oauth login is disabled")
|
||||
}
|
||||
return h.cfg.LinuxDo, nil
|
||||
}
|
||||
|
||||
func linuxDoExchangeCode(
|
||||
ctx context.Context,
|
||||
cfg config.LinuxDoConnectConfig,
|
||||
code string,
|
||||
redirectURI string,
|
||||
codeVerifier string,
|
||||
) (*linuxDoTokenResponse, error) {
|
||||
client := req.C().SetTimeout(30 * time.Second)
|
||||
|
||||
form := url.Values{}
|
||||
form.Set("grant_type", "authorization_code")
|
||||
form.Set("client_id", cfg.ClientID)
|
||||
form.Set("code", code)
|
||||
form.Set("redirect_uri", redirectURI)
|
||||
if cfg.UsePKCE {
|
||||
form.Set("code_verifier", codeVerifier)
|
||||
}
|
||||
|
||||
r := client.R().
|
||||
SetContext(ctx).
|
||||
SetHeader("Accept", "application/json")
|
||||
|
||||
switch strings.ToLower(strings.TrimSpace(cfg.TokenAuthMethod)) {
|
||||
case "", "client_secret_post":
|
||||
form.Set("client_secret", cfg.ClientSecret)
|
||||
case "client_secret_basic":
|
||||
r.SetBasicAuth(cfg.ClientID, cfg.ClientSecret)
|
||||
case "none":
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported token_auth_method: %s", cfg.TokenAuthMethod)
|
||||
}
|
||||
|
||||
resp, err := r.SetFormDataFromValues(form).Post(cfg.TokenURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request token: %w", err)
|
||||
}
|
||||
body := strings.TrimSpace(resp.String())
|
||||
if !resp.IsSuccessState() {
|
||||
providerErr, providerDesc := parseOAuthProviderError(body)
|
||||
return nil, &linuxDoTokenExchangeError{
|
||||
StatusCode: resp.StatusCode,
|
||||
ProviderError: providerErr,
|
||||
ProviderDescription: providerDesc,
|
||||
Body: body,
|
||||
}
|
||||
}
|
||||
|
||||
tokenResp, ok := parseLinuxDoTokenResponse(body)
|
||||
if !ok || strings.TrimSpace(tokenResp.AccessToken) == "" {
|
||||
return nil, &linuxDoTokenExchangeError{
|
||||
StatusCode: resp.StatusCode,
|
||||
Body: body,
|
||||
}
|
||||
}
|
||||
if strings.TrimSpace(tokenResp.TokenType) == "" {
|
||||
tokenResp.TokenType = "Bearer"
|
||||
}
|
||||
return tokenResp, nil
|
||||
}
|
||||
|
||||
func linuxDoFetchUserInfo(
|
||||
ctx context.Context,
|
||||
cfg config.LinuxDoConnectConfig,
|
||||
token *linuxDoTokenResponse,
|
||||
) (email string, username string, subject string, err error) {
|
||||
client := req.C().SetTimeout(30 * time.Second)
|
||||
authorization, err := buildBearerAuthorization(token.TokenType, token.AccessToken)
|
||||
if err != nil {
|
||||
return "", "", "", fmt.Errorf("invalid token for userinfo request: %w", err)
|
||||
}
|
||||
|
||||
resp, err := client.R().
|
||||
SetContext(ctx).
|
||||
SetHeader("Accept", "application/json").
|
||||
SetHeader("Authorization", authorization).
|
||||
Get(cfg.UserInfoURL)
|
||||
if err != nil {
|
||||
return "", "", "", fmt.Errorf("request userinfo: %w", err)
|
||||
}
|
||||
if !resp.IsSuccessState() {
|
||||
return "", "", "", fmt.Errorf("userinfo status=%d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return linuxDoParseUserInfo(resp.String(), cfg)
|
||||
}
|
||||
|
||||
func linuxDoParseUserInfo(body string, cfg config.LinuxDoConnectConfig) (email string, username string, subject string, err error) {
|
||||
email = firstNonEmpty(
|
||||
getGJSON(body, cfg.UserInfoEmailPath),
|
||||
getGJSON(body, "email"),
|
||||
getGJSON(body, "user.email"),
|
||||
getGJSON(body, "data.email"),
|
||||
getGJSON(body, "attributes.email"),
|
||||
)
|
||||
username = firstNonEmpty(
|
||||
getGJSON(body, cfg.UserInfoUsernamePath),
|
||||
getGJSON(body, "username"),
|
||||
getGJSON(body, "preferred_username"),
|
||||
getGJSON(body, "name"),
|
||||
getGJSON(body, "user.username"),
|
||||
getGJSON(body, "user.name"),
|
||||
)
|
||||
subject = firstNonEmpty(
|
||||
getGJSON(body, cfg.UserInfoIDPath),
|
||||
getGJSON(body, "sub"),
|
||||
getGJSON(body, "id"),
|
||||
getGJSON(body, "user_id"),
|
||||
getGJSON(body, "uid"),
|
||||
getGJSON(body, "user.id"),
|
||||
)
|
||||
|
||||
subject = strings.TrimSpace(subject)
|
||||
if subject == "" {
|
||||
return "", "", "", errors.New("userinfo missing id field")
|
||||
}
|
||||
if !isSafeLinuxDoSubject(subject) {
|
||||
return "", "", "", errors.New("userinfo returned invalid id field")
|
||||
}
|
||||
|
||||
email = strings.TrimSpace(email)
|
||||
if email == "" {
|
||||
// LinuxDo Connect 的 userinfo 可能不提供 email。为兼容现有用户模型(email 必填且唯一),使用稳定的合成邮箱。
|
||||
email = linuxDoSyntheticEmail(subject)
|
||||
}
|
||||
|
||||
username = strings.TrimSpace(username)
|
||||
if username == "" {
|
||||
username = "linuxdo_" + subject
|
||||
}
|
||||
|
||||
return email, username, subject, nil
|
||||
}
|
||||
|
||||
func buildLinuxDoAuthorizeURL(cfg config.LinuxDoConnectConfig, state string, codeChallenge string, redirectURI string) (string, error) {
|
||||
u, err := url.Parse(cfg.AuthorizeURL)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("parse authorize_url: %w", err)
|
||||
}
|
||||
|
||||
q := u.Query()
|
||||
q.Set("response_type", "code")
|
||||
q.Set("client_id", cfg.ClientID)
|
||||
q.Set("redirect_uri", redirectURI)
|
||||
if strings.TrimSpace(cfg.Scopes) != "" {
|
||||
q.Set("scope", cfg.Scopes)
|
||||
}
|
||||
q.Set("state", state)
|
||||
if cfg.UsePKCE {
|
||||
q.Set("code_challenge", codeChallenge)
|
||||
q.Set("code_challenge_method", "S256")
|
||||
}
|
||||
|
||||
u.RawQuery = q.Encode()
|
||||
return u.String(), nil
|
||||
}
|
||||
|
||||
func redirectOAuthError(c *gin.Context, frontendCallback string, code string, message string, description string) {
|
||||
fragment := url.Values{}
|
||||
fragment.Set("error", truncateFragmentValue(code))
|
||||
if strings.TrimSpace(message) != "" {
|
||||
fragment.Set("error_message", truncateFragmentValue(message))
|
||||
}
|
||||
if strings.TrimSpace(description) != "" {
|
||||
fragment.Set("error_description", truncateFragmentValue(description))
|
||||
}
|
||||
redirectWithFragment(c, frontendCallback, fragment)
|
||||
}
|
||||
|
||||
func redirectWithFragment(c *gin.Context, frontendCallback string, fragment url.Values) {
|
||||
u, err := url.Parse(frontendCallback)
|
||||
if err != nil {
|
||||
// 兜底:尽力跳转到默认页面,避免卡死在回调页。
|
||||
c.Redirect(http.StatusFound, linuxDoOAuthDefaultRedirectTo)
|
||||
return
|
||||
}
|
||||
if u.Scheme != "" && !strings.EqualFold(u.Scheme, "http") && !strings.EqualFold(u.Scheme, "https") {
|
||||
c.Redirect(http.StatusFound, linuxDoOAuthDefaultRedirectTo)
|
||||
return
|
||||
}
|
||||
u.Fragment = fragment.Encode()
|
||||
c.Header("Cache-Control", "no-store")
|
||||
c.Header("Pragma", "no-cache")
|
||||
c.Redirect(http.StatusFound, u.String())
|
||||
}
|
||||
|
||||
func firstNonEmpty(values ...string) string {
|
||||
for _, v := range values {
|
||||
v = strings.TrimSpace(v)
|
||||
if v != "" {
|
||||
return v
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func parseOAuthProviderError(body string) (providerErr string, providerDesc string) {
|
||||
body = strings.TrimSpace(body)
|
||||
if body == "" {
|
||||
return "", ""
|
||||
}
|
||||
|
||||
providerErr = firstNonEmpty(
|
||||
getGJSON(body, "error"),
|
||||
getGJSON(body, "code"),
|
||||
getGJSON(body, "error.code"),
|
||||
)
|
||||
providerDesc = firstNonEmpty(
|
||||
getGJSON(body, "error_description"),
|
||||
getGJSON(body, "error.message"),
|
||||
getGJSON(body, "message"),
|
||||
getGJSON(body, "detail"),
|
||||
)
|
||||
|
||||
if providerErr != "" || providerDesc != "" {
|
||||
return providerErr, providerDesc
|
||||
}
|
||||
|
||||
values, err := url.ParseQuery(body)
|
||||
if err != nil {
|
||||
return "", ""
|
||||
}
|
||||
providerErr = firstNonEmpty(values.Get("error"), values.Get("code"))
|
||||
providerDesc = firstNonEmpty(values.Get("error_description"), values.Get("error_message"), values.Get("message"))
|
||||
return providerErr, providerDesc
|
||||
}
|
||||
|
||||
func parseLinuxDoTokenResponse(body string) (*linuxDoTokenResponse, bool) {
|
||||
body = strings.TrimSpace(body)
|
||||
if body == "" {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
accessToken := strings.TrimSpace(getGJSON(body, "access_token"))
|
||||
if accessToken != "" {
|
||||
tokenType := strings.TrimSpace(getGJSON(body, "token_type"))
|
||||
refreshToken := strings.TrimSpace(getGJSON(body, "refresh_token"))
|
||||
scope := strings.TrimSpace(getGJSON(body, "scope"))
|
||||
expiresIn := gjson.Get(body, "expires_in").Int()
|
||||
return &linuxDoTokenResponse{
|
||||
AccessToken: accessToken,
|
||||
TokenType: tokenType,
|
||||
ExpiresIn: expiresIn,
|
||||
RefreshToken: refreshToken,
|
||||
Scope: scope,
|
||||
}, true
|
||||
}
|
||||
|
||||
values, err := url.ParseQuery(body)
|
||||
if err != nil {
|
||||
return nil, false
|
||||
}
|
||||
accessToken = strings.TrimSpace(values.Get("access_token"))
|
||||
if accessToken == "" {
|
||||
return nil, false
|
||||
}
|
||||
expiresIn := int64(0)
|
||||
if raw := strings.TrimSpace(values.Get("expires_in")); raw != "" {
|
||||
if v, err := strconv.ParseInt(raw, 10, 64); err == nil {
|
||||
expiresIn = v
|
||||
}
|
||||
}
|
||||
return &linuxDoTokenResponse{
|
||||
AccessToken: accessToken,
|
||||
TokenType: strings.TrimSpace(values.Get("token_type")),
|
||||
ExpiresIn: expiresIn,
|
||||
RefreshToken: strings.TrimSpace(values.Get("refresh_token")),
|
||||
Scope: strings.TrimSpace(values.Get("scope")),
|
||||
}, true
|
||||
}
|
||||
|
||||
func getGJSON(body string, path string) string {
|
||||
path = strings.TrimSpace(path)
|
||||
if path == "" {
|
||||
return ""
|
||||
}
|
||||
res := gjson.Get(body, path)
|
||||
if !res.Exists() {
|
||||
return ""
|
||||
}
|
||||
return res.String()
|
||||
}
|
||||
|
||||
func truncateLogValue(value string, maxLen int) string {
|
||||
value = strings.TrimSpace(value)
|
||||
if value == "" || maxLen <= 0 {
|
||||
return ""
|
||||
}
|
||||
if len(value) <= maxLen {
|
||||
return value
|
||||
}
|
||||
value = value[:maxLen]
|
||||
for !utf8.ValidString(value) {
|
||||
value = value[:len(value)-1]
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
func singleLine(value string) string {
|
||||
value = strings.TrimSpace(value)
|
||||
if value == "" {
|
||||
return ""
|
||||
}
|
||||
return strings.Join(strings.Fields(value), " ")
|
||||
}
|
||||
|
||||
func sanitizeFrontendRedirectPath(path string) string {
|
||||
path = strings.TrimSpace(path)
|
||||
if path == "" {
|
||||
return ""
|
||||
}
|
||||
if len(path) > linuxDoOAuthMaxRedirectLen {
|
||||
return ""
|
||||
}
|
||||
// 只允许同源相对路径(避免开放重定向)。
|
||||
if !strings.HasPrefix(path, "/") {
|
||||
return ""
|
||||
}
|
||||
if strings.HasPrefix(path, "//") {
|
||||
return ""
|
||||
}
|
||||
if strings.Contains(path, "://") {
|
||||
return ""
|
||||
}
|
||||
if strings.ContainsAny(path, "\r\n") {
|
||||
return ""
|
||||
}
|
||||
return path
|
||||
}
|
||||
|
||||
func isRequestHTTPS(c *gin.Context) bool {
|
||||
if c.Request.TLS != nil {
|
||||
return true
|
||||
}
|
||||
proto := strings.ToLower(strings.TrimSpace(c.GetHeader("X-Forwarded-Proto")))
|
||||
return proto == "https"
|
||||
}
|
||||
|
||||
func encodeCookieValue(value string) string {
|
||||
return base64.RawURLEncoding.EncodeToString([]byte(value))
|
||||
}
|
||||
|
||||
func decodeCookieValue(value string) (string, error) {
|
||||
raw, err := base64.RawURLEncoding.DecodeString(value)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(raw), nil
|
||||
}
|
||||
|
||||
func readCookieDecoded(c *gin.Context, name string) (string, error) {
|
||||
ck, err := c.Request.Cookie(name)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return decodeCookieValue(ck.Value)
|
||||
}
|
||||
|
||||
func setCookie(c *gin.Context, name string, value string, maxAgeSec int, secure bool) {
|
||||
http.SetCookie(c.Writer, &http.Cookie{
|
||||
Name: name,
|
||||
Value: value,
|
||||
Path: linuxDoOAuthCookiePath,
|
||||
MaxAge: maxAgeSec,
|
||||
HttpOnly: true,
|
||||
Secure: secure,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
})
|
||||
}
|
||||
|
||||
func clearCookie(c *gin.Context, name string, secure bool) {
|
||||
http.SetCookie(c.Writer, &http.Cookie{
|
||||
Name: name,
|
||||
Value: "",
|
||||
Path: linuxDoOAuthCookiePath,
|
||||
MaxAge: -1,
|
||||
HttpOnly: true,
|
||||
Secure: secure,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
})
|
||||
}
|
||||
|
||||
func truncateFragmentValue(value string) string {
|
||||
value = strings.TrimSpace(value)
|
||||
if value == "" {
|
||||
return ""
|
||||
}
|
||||
if len(value) > linuxDoOAuthMaxFragmentValueLen {
|
||||
value = value[:linuxDoOAuthMaxFragmentValueLen]
|
||||
for !utf8.ValidString(value) {
|
||||
value = value[:len(value)-1]
|
||||
}
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
func buildBearerAuthorization(tokenType, accessToken string) (string, error) {
|
||||
tokenType = strings.TrimSpace(tokenType)
|
||||
if tokenType == "" {
|
||||
tokenType = "Bearer"
|
||||
}
|
||||
if !strings.EqualFold(tokenType, "Bearer") {
|
||||
return "", fmt.Errorf("unsupported token_type: %s", tokenType)
|
||||
}
|
||||
|
||||
accessToken = strings.TrimSpace(accessToken)
|
||||
if accessToken == "" {
|
||||
return "", errors.New("missing access_token")
|
||||
}
|
||||
if strings.ContainsAny(accessToken, " \t\r\n") {
|
||||
return "", errors.New("access_token contains whitespace")
|
||||
}
|
||||
return "Bearer " + accessToken, nil
|
||||
}
|
||||
|
||||
func isSafeLinuxDoSubject(subject string) bool {
|
||||
subject = strings.TrimSpace(subject)
|
||||
if subject == "" || len(subject) > linuxDoOAuthMaxSubjectLen {
|
||||
return false
|
||||
}
|
||||
for _, r := range subject {
|
||||
switch {
|
||||
case r >= '0' && r <= '9':
|
||||
case r >= 'a' && r <= 'z':
|
||||
case r >= 'A' && r <= 'Z':
|
||||
case r == '_' || r == '-':
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func linuxDoSyntheticEmail(subject string) string {
|
||||
subject = strings.TrimSpace(subject)
|
||||
if subject == "" {
|
||||
return ""
|
||||
}
|
||||
return "linuxdo-" + subject + service.LinuxDoConnectSyntheticEmailDomain
|
||||
}
|
||||
108
backend/internal/handler/auth_linuxdo_oauth_test.go
Normal file
108
backend/internal/handler/auth_linuxdo_oauth_test.go
Normal file
@@ -0,0 +1,108 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSanitizeFrontendRedirectPath(t *testing.T) {
|
||||
require.Equal(t, "/dashboard", sanitizeFrontendRedirectPath("/dashboard"))
|
||||
require.Equal(t, "/dashboard", sanitizeFrontendRedirectPath(" /dashboard "))
|
||||
require.Equal(t, "", sanitizeFrontendRedirectPath("dashboard"))
|
||||
require.Equal(t, "", sanitizeFrontendRedirectPath("//evil.com"))
|
||||
require.Equal(t, "", sanitizeFrontendRedirectPath("https://evil.com"))
|
||||
require.Equal(t, "", sanitizeFrontendRedirectPath("/\nfoo"))
|
||||
|
||||
long := "/" + strings.Repeat("a", linuxDoOAuthMaxRedirectLen)
|
||||
require.Equal(t, "", sanitizeFrontendRedirectPath(long))
|
||||
}
|
||||
|
||||
func TestBuildBearerAuthorization(t *testing.T) {
|
||||
auth, err := buildBearerAuthorization("", "token123")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "Bearer token123", auth)
|
||||
|
||||
auth, err = buildBearerAuthorization("bearer", "token123")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "Bearer token123", auth)
|
||||
|
||||
_, err = buildBearerAuthorization("MAC", "token123")
|
||||
require.Error(t, err)
|
||||
|
||||
_, err = buildBearerAuthorization("Bearer", "token 123")
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestLinuxDoParseUserInfoParsesIDAndUsername(t *testing.T) {
|
||||
cfg := config.LinuxDoConnectConfig{
|
||||
UserInfoURL: "https://connect.linux.do/api/user",
|
||||
}
|
||||
|
||||
email, username, subject, err := linuxDoParseUserInfo(`{"id":123,"username":"alice"}`, cfg)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "123", subject)
|
||||
require.Equal(t, "alice", username)
|
||||
require.Equal(t, "linuxdo-123@linuxdo-connect.invalid", email)
|
||||
}
|
||||
|
||||
func TestLinuxDoParseUserInfoDefaultsUsername(t *testing.T) {
|
||||
cfg := config.LinuxDoConnectConfig{
|
||||
UserInfoURL: "https://connect.linux.do/api/user",
|
||||
}
|
||||
|
||||
email, username, subject, err := linuxDoParseUserInfo(`{"id":"123"}`, cfg)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "123", subject)
|
||||
require.Equal(t, "linuxdo_123", username)
|
||||
require.Equal(t, "linuxdo-123@linuxdo-connect.invalid", email)
|
||||
}
|
||||
|
||||
func TestLinuxDoParseUserInfoRejectsUnsafeSubject(t *testing.T) {
|
||||
cfg := config.LinuxDoConnectConfig{
|
||||
UserInfoURL: "https://connect.linux.do/api/user",
|
||||
}
|
||||
|
||||
_, _, _, err := linuxDoParseUserInfo(`{"id":"123@456"}`, cfg)
|
||||
require.Error(t, err)
|
||||
|
||||
tooLong := strings.Repeat("a", linuxDoOAuthMaxSubjectLen+1)
|
||||
_, _, _, err = linuxDoParseUserInfo(`{"id":"`+tooLong+`"}`, cfg)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestParseOAuthProviderErrorJSON(t *testing.T) {
|
||||
code, desc := parseOAuthProviderError(`{"error":"invalid_client","error_description":"bad secret"}`)
|
||||
require.Equal(t, "invalid_client", code)
|
||||
require.Equal(t, "bad secret", desc)
|
||||
}
|
||||
|
||||
func TestParseOAuthProviderErrorForm(t *testing.T) {
|
||||
code, desc := parseOAuthProviderError("error=invalid_request&error_description=Missing+code_verifier")
|
||||
require.Equal(t, "invalid_request", code)
|
||||
require.Equal(t, "Missing code_verifier", desc)
|
||||
}
|
||||
|
||||
func TestParseLinuxDoTokenResponseJSON(t *testing.T) {
|
||||
token, ok := parseLinuxDoTokenResponse(`{"access_token":"t1","token_type":"Bearer","expires_in":3600,"scope":"user"}`)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "t1", token.AccessToken)
|
||||
require.Equal(t, "Bearer", token.TokenType)
|
||||
require.Equal(t, int64(3600), token.ExpiresIn)
|
||||
require.Equal(t, "user", token.Scope)
|
||||
}
|
||||
|
||||
func TestParseLinuxDoTokenResponseForm(t *testing.T) {
|
||||
token, ok := parseLinuxDoTokenResponse("access_token=t2&token_type=bearer&expires_in=60")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "t2", token.AccessToken)
|
||||
require.Equal(t, "bearer", token.TokenType)
|
||||
require.Equal(t, int64(60), token.ExpiresIn)
|
||||
}
|
||||
|
||||
func TestSingleLineStripsWhitespace(t *testing.T) {
|
||||
require.Equal(t, "hello world", singleLine("hello\r\nworld"))
|
||||
require.Equal(t, "", singleLine("\n\t\r"))
|
||||
}
|
||||
@@ -17,6 +17,11 @@ type SystemSettings struct {
|
||||
TurnstileSiteKey string `json:"turnstile_site_key"`
|
||||
TurnstileSecretKeyConfigured bool `json:"turnstile_secret_key_configured"`
|
||||
|
||||
LinuxDoConnectEnabled bool `json:"linuxdo_connect_enabled"`
|
||||
LinuxDoConnectClientID string `json:"linuxdo_connect_client_id"`
|
||||
LinuxDoConnectClientSecretConfigured bool `json:"linuxdo_connect_client_secret_configured"`
|
||||
LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"`
|
||||
|
||||
SiteName string `json:"site_name"`
|
||||
SiteLogo string `json:"site_logo"`
|
||||
SiteSubtitle string `json:"site_subtitle"`
|
||||
@@ -50,5 +55,6 @@ type PublicSettings struct {
|
||||
APIBaseURL string `json:"api_base_url"`
|
||||
ContactInfo string `json:"contact_info"`
|
||||
DocURL string `json:"doc_url"`
|
||||
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
||||
@@ -42,6 +42,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
|
||||
APIBaseURL: settings.APIBaseURL,
|
||||
ContactInfo: settings.ContactInfo,
|
||||
DocURL: settings.DocURL,
|
||||
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
|
||||
Version: h.version,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -5,8 +5,11 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
@@ -22,10 +25,10 @@ func resolveHost(urlStr string) string {
|
||||
return parsed.Host
|
||||
}
|
||||
|
||||
// NewAPIRequest 创建 Antigravity API 请求(v1internal 端点)
|
||||
func NewAPIRequest(ctx context.Context, action, accessToken string, body []byte) (*http.Request, error) {
|
||||
// NewAPIRequestWithURL 使用指定的 base URL 创建 Antigravity API 请求(v1internal 端点)
|
||||
func NewAPIRequestWithURL(ctx context.Context, baseURL, action, accessToken string, body []byte) (*http.Request, error) {
|
||||
// 构建 URL,流式请求添加 ?alt=sse 参数
|
||||
apiURL := fmt.Sprintf("%s/v1internal:%s", BaseURL, action)
|
||||
apiURL := fmt.Sprintf("%s/v1internal:%s", baseURL, action)
|
||||
isStream := action == "streamGenerateContent"
|
||||
if isStream {
|
||||
apiURL += "?alt=sse"
|
||||
@@ -53,11 +56,15 @@ func NewAPIRequest(ctx context.Context, action, accessToken string, body []byte)
|
||||
req.Host = host
|
||||
}
|
||||
|
||||
// 注意:requestType 已在 JSON body 的 V1InternalRequest 中设置,不需要 HTTP Header
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// NewAPIRequest 使用默认 URL 创建 Antigravity API 请求(v1internal 端点)
|
||||
// 向后兼容:仅使用默认 BaseURL
|
||||
func NewAPIRequest(ctx context.Context, action, accessToken string, body []byte) (*http.Request, error) {
|
||||
return NewAPIRequestWithURL(ctx, BaseURL, action, accessToken, body)
|
||||
}
|
||||
|
||||
// TokenResponse Google OAuth token 响应
|
||||
type TokenResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
@@ -164,6 +171,38 @@ func NewClient(proxyURL string) *Client {
|
||||
}
|
||||
}
|
||||
|
||||
// isConnectionError 判断是否为连接错误(网络超时、DNS 失败、连接拒绝)
|
||||
func isConnectionError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查超时错误
|
||||
var netErr net.Error
|
||||
if errors.As(err, &netErr) && netErr.Timeout() {
|
||||
return true
|
||||
}
|
||||
|
||||
// 检查连接错误(DNS 失败、连接拒绝)
|
||||
var opErr *net.OpError
|
||||
if errors.As(err, &opErr) {
|
||||
return true
|
||||
}
|
||||
|
||||
// 检查 URL 错误
|
||||
var urlErr *url.Error
|
||||
return errors.As(err, &urlErr)
|
||||
}
|
||||
|
||||
// shouldFallbackToNextURL 判断是否应切换到下一个 URL
|
||||
// 仅连接错误和 HTTP 429 触发 URL 降级
|
||||
func shouldFallbackToNextURL(err error, statusCode int) bool {
|
||||
if isConnectionError(err) {
|
||||
return true
|
||||
}
|
||||
return statusCode == http.StatusTooManyRequests
|
||||
}
|
||||
|
||||
// ExchangeCode 用 authorization code 交换 token
|
||||
func (c *Client) ExchangeCode(ctx context.Context, code, codeVerifier string) (*TokenResponse, error) {
|
||||
params := url.Values{}
|
||||
@@ -272,6 +311,7 @@ func (c *Client) GetUserInfo(ctx context.Context, accessToken string) (*UserInfo
|
||||
}
|
||||
|
||||
// LoadCodeAssist 获取账户信息,返回解析后的结构体和原始 JSON
|
||||
// 支持 URL fallback:sandbox → daily → prod
|
||||
func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadCodeAssistResponse, map[string]any, error) {
|
||||
reqBody := LoadCodeAssistRequest{}
|
||||
reqBody.Metadata.IDEType = "ANTIGRAVITY"
|
||||
@@ -281,40 +321,65 @@ func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadC
|
||||
return nil, nil, fmt.Errorf("序列化请求失败: %w", err)
|
||||
}
|
||||
|
||||
url := BaseURL + "/v1internal:loadCodeAssist"
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, strings.NewReader(string(bodyBytes)))
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("创建请求失败: %w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", UserAgent)
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("loadCodeAssist 请求失败: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
respBodyBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("读取响应失败: %w", err)
|
||||
// 获取可用的 URL 列表
|
||||
availableURLs := DefaultURLAvailability.GetAvailableURLs()
|
||||
if len(availableURLs) == 0 {
|
||||
availableURLs = BaseURLs // 所有 URL 都不可用时,重试所有
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, nil, fmt.Errorf("loadCodeAssist 失败 (HTTP %d): %s", resp.StatusCode, string(respBodyBytes))
|
||||
var lastErr error
|
||||
for urlIdx, baseURL := range availableURLs {
|
||||
apiURL := baseURL + "/v1internal:loadCodeAssist"
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, strings.NewReader(string(bodyBytes)))
|
||||
if err != nil {
|
||||
lastErr = fmt.Errorf("创建请求失败: %w", err)
|
||||
continue
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", UserAgent)
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
lastErr = fmt.Errorf("loadCodeAssist 请求失败: %w", err)
|
||||
if shouldFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 {
|
||||
DefaultURLAvailability.MarkUnavailable(baseURL)
|
||||
log.Printf("[antigravity] loadCodeAssist URL fallback: %s -> %s", baseURL, availableURLs[urlIdx+1])
|
||||
continue
|
||||
}
|
||||
return nil, nil, lastErr
|
||||
}
|
||||
|
||||
respBodyBytes, err := io.ReadAll(resp.Body)
|
||||
_ = resp.Body.Close() // 立即关闭,避免循环内 defer 导致的资源泄漏
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("读取响应失败: %w", err)
|
||||
}
|
||||
|
||||
// 检查是否需要 URL 降级
|
||||
if shouldFallbackToNextURL(nil, resp.StatusCode) && urlIdx < len(availableURLs)-1 {
|
||||
DefaultURLAvailability.MarkUnavailable(baseURL)
|
||||
log.Printf("[antigravity] loadCodeAssist URL fallback (HTTP %d): %s -> %s", resp.StatusCode, baseURL, availableURLs[urlIdx+1])
|
||||
continue
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, nil, fmt.Errorf("loadCodeAssist 失败 (HTTP %d): %s", resp.StatusCode, string(respBodyBytes))
|
||||
}
|
||||
|
||||
var loadResp LoadCodeAssistResponse
|
||||
if err := json.Unmarshal(respBodyBytes, &loadResp); err != nil {
|
||||
return nil, nil, fmt.Errorf("响应解析失败: %w", err)
|
||||
}
|
||||
|
||||
// 解析原始 JSON 为 map
|
||||
var rawResp map[string]any
|
||||
_ = json.Unmarshal(respBodyBytes, &rawResp)
|
||||
|
||||
return &loadResp, rawResp, nil
|
||||
}
|
||||
|
||||
var loadResp LoadCodeAssistResponse
|
||||
if err := json.Unmarshal(respBodyBytes, &loadResp); err != nil {
|
||||
return nil, nil, fmt.Errorf("响应解析失败: %w", err)
|
||||
}
|
||||
|
||||
// 解析原始 JSON 为 map
|
||||
var rawResp map[string]any
|
||||
_ = json.Unmarshal(respBodyBytes, &rawResp)
|
||||
|
||||
return &loadResp, rawResp, nil
|
||||
return nil, nil, lastErr
|
||||
}
|
||||
|
||||
// ModelQuotaInfo 模型配额信息
|
||||
@@ -339,6 +404,7 @@ type FetchAvailableModelsResponse struct {
|
||||
}
|
||||
|
||||
// FetchAvailableModels 获取可用模型和配额信息,返回解析后的结构体和原始 JSON
|
||||
// 支持 URL fallback:sandbox → daily → prod
|
||||
func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectID string) (*FetchAvailableModelsResponse, map[string]any, error) {
|
||||
reqBody := FetchAvailableModelsRequest{Project: projectID}
|
||||
bodyBytes, err := json.Marshal(reqBody)
|
||||
@@ -346,38 +412,63 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI
|
||||
return nil, nil, fmt.Errorf("序列化请求失败: %w", err)
|
||||
}
|
||||
|
||||
apiURL := BaseURL + "/v1internal:fetchAvailableModels"
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, strings.NewReader(string(bodyBytes)))
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("创建请求失败: %w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", UserAgent)
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("fetchAvailableModels 请求失败: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
respBodyBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("读取响应失败: %w", err)
|
||||
// 获取可用的 URL 列表
|
||||
availableURLs := DefaultURLAvailability.GetAvailableURLs()
|
||||
if len(availableURLs) == 0 {
|
||||
availableURLs = BaseURLs // 所有 URL 都不可用时,重试所有
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, nil, fmt.Errorf("fetchAvailableModels 失败 (HTTP %d): %s", resp.StatusCode, string(respBodyBytes))
|
||||
var lastErr error
|
||||
for urlIdx, baseURL := range availableURLs {
|
||||
apiURL := baseURL + "/v1internal:fetchAvailableModels"
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, strings.NewReader(string(bodyBytes)))
|
||||
if err != nil {
|
||||
lastErr = fmt.Errorf("创建请求失败: %w", err)
|
||||
continue
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", UserAgent)
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
lastErr = fmt.Errorf("fetchAvailableModels 请求失败: %w", err)
|
||||
if shouldFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 {
|
||||
DefaultURLAvailability.MarkUnavailable(baseURL)
|
||||
log.Printf("[antigravity] fetchAvailableModels URL fallback: %s -> %s", baseURL, availableURLs[urlIdx+1])
|
||||
continue
|
||||
}
|
||||
return nil, nil, lastErr
|
||||
}
|
||||
|
||||
respBodyBytes, err := io.ReadAll(resp.Body)
|
||||
_ = resp.Body.Close() // 立即关闭,避免循环内 defer 导致的资源泄漏
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("读取响应失败: %w", err)
|
||||
}
|
||||
|
||||
// 检查是否需要 URL 降级
|
||||
if shouldFallbackToNextURL(nil, resp.StatusCode) && urlIdx < len(availableURLs)-1 {
|
||||
DefaultURLAvailability.MarkUnavailable(baseURL)
|
||||
log.Printf("[antigravity] fetchAvailableModels URL fallback (HTTP %d): %s -> %s", resp.StatusCode, baseURL, availableURLs[urlIdx+1])
|
||||
continue
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, nil, fmt.Errorf("fetchAvailableModels 失败 (HTTP %d): %s", resp.StatusCode, string(respBodyBytes))
|
||||
}
|
||||
|
||||
var modelsResp FetchAvailableModelsResponse
|
||||
if err := json.Unmarshal(respBodyBytes, &modelsResp); err != nil {
|
||||
return nil, nil, fmt.Errorf("响应解析失败: %w", err)
|
||||
}
|
||||
|
||||
// 解析原始 JSON 为 map
|
||||
var rawResp map[string]any
|
||||
_ = json.Unmarshal(respBodyBytes, &rawResp)
|
||||
|
||||
return &modelsResp, rawResp, nil
|
||||
}
|
||||
|
||||
var modelsResp FetchAvailableModelsResponse
|
||||
if err := json.Unmarshal(respBodyBytes, &modelsResp); err != nil {
|
||||
return nil, nil, fmt.Errorf("响应解析失败: %w", err)
|
||||
}
|
||||
|
||||
// 解析原始 JSON 为 map
|
||||
var rawResp map[string]any
|
||||
_ = json.Unmarshal(respBodyBytes, &rawResp)
|
||||
|
||||
return &modelsResp, rawResp, nil
|
||||
return nil, nil, lastErr
|
||||
}
|
||||
|
||||
@@ -32,17 +32,79 @@ const (
|
||||
"https://www.googleapis.com/auth/cclog " +
|
||||
"https://www.googleapis.com/auth/experimentsandconfigs"
|
||||
|
||||
// API 端点
|
||||
// 优先使用 sandbox daily URL,配额更宽松
|
||||
BaseURL = "https://daily-cloudcode-pa.sandbox.googleapis.com"
|
||||
|
||||
// User-Agent(模拟官方客户端)
|
||||
UserAgent = "antigravity/1.104.0 darwin/arm64"
|
||||
|
||||
// Session 过期时间
|
||||
SessionTTL = 30 * time.Minute
|
||||
|
||||
// URL 可用性 TTL(不可用 URL 的恢复时间)
|
||||
URLAvailabilityTTL = 5 * time.Minute
|
||||
)
|
||||
|
||||
// BaseURLs 定义 Antigravity API 端点,按优先级排序
|
||||
// fallback 顺序: sandbox → daily → prod
|
||||
var BaseURLs = []string{
|
||||
"https://daily-cloudcode-pa.sandbox.googleapis.com", // sandbox
|
||||
"https://daily-cloudcode-pa.googleapis.com", // daily
|
||||
"https://cloudcode-pa.googleapis.com", // prod
|
||||
}
|
||||
|
||||
// BaseURL 默认 URL(保持向后兼容)
|
||||
var BaseURL = BaseURLs[0]
|
||||
|
||||
// URLAvailability 管理 URL 可用性状态(带 TTL 自动恢复)
|
||||
type URLAvailability struct {
|
||||
mu sync.RWMutex
|
||||
unavailable map[string]time.Time // URL -> 恢复时间
|
||||
ttl time.Duration
|
||||
}
|
||||
|
||||
// DefaultURLAvailability 全局 URL 可用性管理器
|
||||
var DefaultURLAvailability = NewURLAvailability(URLAvailabilityTTL)
|
||||
|
||||
// NewURLAvailability 创建 URL 可用性管理器
|
||||
func NewURLAvailability(ttl time.Duration) *URLAvailability {
|
||||
return &URLAvailability{
|
||||
unavailable: make(map[string]time.Time),
|
||||
ttl: ttl,
|
||||
}
|
||||
}
|
||||
|
||||
// MarkUnavailable 标记 URL 临时不可用
|
||||
func (u *URLAvailability) MarkUnavailable(url string) {
|
||||
u.mu.Lock()
|
||||
defer u.mu.Unlock()
|
||||
u.unavailable[url] = time.Now().Add(u.ttl)
|
||||
}
|
||||
|
||||
// IsAvailable 检查 URL 是否可用
|
||||
func (u *URLAvailability) IsAvailable(url string) bool {
|
||||
u.mu.RLock()
|
||||
defer u.mu.RUnlock()
|
||||
expiry, exists := u.unavailable[url]
|
||||
if !exists {
|
||||
return true
|
||||
}
|
||||
return time.Now().After(expiry)
|
||||
}
|
||||
|
||||
// GetAvailableURLs 返回可用的 URL 列表(保持优先级顺序)
|
||||
func (u *URLAvailability) GetAvailableURLs() []string {
|
||||
u.mu.RLock()
|
||||
defer u.mu.RUnlock()
|
||||
|
||||
now := time.Now()
|
||||
result := make([]string, 0, len(BaseURLs))
|
||||
for _, url := range BaseURLs {
|
||||
expiry, exists := u.unavailable[url]
|
||||
if !exists || now.After(expiry) {
|
||||
result = append(result, url)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// OAuthSession 保存 OAuth 授权流程的临时状态
|
||||
type OAuthSession struct {
|
||||
State string `json:"state"`
|
||||
|
||||
@@ -27,10 +27,9 @@ const (
|
||||
// https://www.googleapis.com/auth/generative-language.retriever (often with cloud-platform).
|
||||
DefaultAIStudioScopes = "https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/generative-language.retriever"
|
||||
|
||||
// DefaultScopes for Google One (personal Google accounts with Gemini access)
|
||||
// Only used when a custom OAuth client is configured. When using the built-in Gemini CLI client,
|
||||
// Google One uses DefaultCodeAssistScopes (same as code_assist) because the built-in client
|
||||
// cannot request restricted scopes like generative-language.retriever or drive.readonly.
|
||||
// DefaultGoogleOneScopes (DEPRECATED, no longer used)
|
||||
// Google One now always uses the built-in Gemini CLI client with DefaultCodeAssistScopes.
|
||||
// This constant is kept for backward compatibility but is not actively used.
|
||||
DefaultGoogleOneScopes = "https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/generative-language.retriever https://www.googleapis.com/auth/drive.readonly https://www.googleapis.com/auth/userinfo.email https://www.googleapis.com/auth/userinfo.profile"
|
||||
|
||||
// GeminiCLIRedirectURI is the redirect URI used by Gemini CLI for Code Assist OAuth.
|
||||
|
||||
@@ -185,13 +185,9 @@ func EffectiveOAuthConfig(cfg OAuthConfig, oauthType string) (OAuthConfig, error
|
||||
effective.Scopes = DefaultAIStudioScopes
|
||||
}
|
||||
case "google_one":
|
||||
// Google One uses built-in Gemini CLI client (same as code_assist)
|
||||
// Built-in client can't request restricted scopes like generative-language.retriever
|
||||
if isBuiltinClient {
|
||||
effective.Scopes = DefaultCodeAssistScopes
|
||||
} else {
|
||||
effective.Scopes = DefaultGoogleOneScopes
|
||||
}
|
||||
// Google One always uses built-in Gemini CLI client (same as code_assist)
|
||||
// Built-in client can't request restricted scopes like generative-language.retriever or drive.readonly
|
||||
effective.Scopes = DefaultCodeAssistScopes
|
||||
default:
|
||||
// Default to Code Assist scopes
|
||||
effective.Scopes = DefaultCodeAssistScopes
|
||||
|
||||
@@ -23,14 +23,14 @@ func TestEffectiveOAuthConfig_GoogleOne(t *testing.T) {
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Google One with custom client",
|
||||
name: "Google One always uses built-in client (even if custom credentials passed)",
|
||||
input: OAuthConfig{
|
||||
ClientID: "custom-client-id",
|
||||
ClientSecret: "custom-client-secret",
|
||||
},
|
||||
oauthType: "google_one",
|
||||
wantClientID: "custom-client-id",
|
||||
wantScopes: DefaultGoogleOneScopes,
|
||||
wantScopes: DefaultCodeAssistScopes, // Uses code assist scopes even with custom client
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
|
||||
@@ -831,6 +831,11 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates
|
||||
args = append(args, *updates.Status)
|
||||
idx++
|
||||
}
|
||||
if updates.Schedulable != nil {
|
||||
setClauses = append(setClauses, "schedulable = $"+itoa(idx))
|
||||
args = append(args, *updates.Schedulable)
|
||||
idx++
|
||||
}
|
||||
// JSONB 需要合并而非覆盖,使用 raw SQL 保持旧行为。
|
||||
if len(updates.Credentials) > 0 {
|
||||
payload, err := json.Marshal(updates.Credentials)
|
||||
|
||||
@@ -30,14 +30,15 @@ func (c *geminiOAuthClient) ExchangeCode(ctx context.Context, oauthType, code, c
|
||||
|
||||
// Use different OAuth clients based on oauthType:
|
||||
// - code_assist: always use built-in Gemini CLI OAuth client (public)
|
||||
// - google_one: uses configured OAuth client when provided; otherwise falls back to built-in client
|
||||
// - google_one: always use built-in Gemini CLI OAuth client (public)
|
||||
// - ai_studio: requires a user-provided OAuth client
|
||||
oauthCfgInput := geminicli.OAuthConfig{
|
||||
ClientID: c.cfg.Gemini.OAuth.ClientID,
|
||||
ClientSecret: c.cfg.Gemini.OAuth.ClientSecret,
|
||||
Scopes: c.cfg.Gemini.OAuth.Scopes,
|
||||
}
|
||||
if oauthType == "code_assist" {
|
||||
if oauthType == "code_assist" || oauthType == "google_one" {
|
||||
// Force use of built-in Gemini CLI OAuth client
|
||||
oauthCfgInput.ClientID = ""
|
||||
oauthCfgInput.ClientSecret = ""
|
||||
}
|
||||
@@ -78,7 +79,8 @@ func (c *geminiOAuthClient) RefreshToken(ctx context.Context, oauthType, refresh
|
||||
ClientSecret: c.cfg.Gemini.OAuth.ClientSecret,
|
||||
Scopes: c.cfg.Gemini.OAuth.Scopes,
|
||||
}
|
||||
if oauthType == "code_assist" {
|
||||
if oauthType == "code_assist" || oauthType == "google_one" {
|
||||
// Force use of built-in Gemini CLI OAuth client
|
||||
oauthCfgInput.ClientID = ""
|
||||
oauthCfgInput.ClientSecret = ""
|
||||
}
|
||||
|
||||
@@ -112,10 +112,10 @@ func (r *groupRepository) Delete(ctx context.Context, id int64) error {
|
||||
}
|
||||
|
||||
func (r *groupRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Group, *pagination.PaginationResult, error) {
|
||||
return r.ListWithFilters(ctx, params, "", "", nil)
|
||||
return r.ListWithFilters(ctx, params, "", "", "", nil)
|
||||
}
|
||||
|
||||
func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]service.Group, *pagination.PaginationResult, error) {
|
||||
func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]service.Group, *pagination.PaginationResult, error) {
|
||||
q := r.client.Group.Query()
|
||||
|
||||
if platform != "" {
|
||||
@@ -124,6 +124,12 @@ func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination
|
||||
if status != "" {
|
||||
q = q.Where(group.StatusEQ(status))
|
||||
}
|
||||
if search != "" {
|
||||
q = q.Where(group.Or(
|
||||
group.NameContainsFold(search),
|
||||
group.DescriptionContainsFold(search),
|
||||
))
|
||||
}
|
||||
if isExclusive != nil {
|
||||
q = q.Where(group.IsExclusiveEQ(*isExclusive))
|
||||
}
|
||||
|
||||
@@ -131,6 +131,7 @@ func (s *GroupRepoSuite) TestListWithFilters_Platform() {
|
||||
pagination.PaginationParams{Page: 1, PageSize: 10},
|
||||
service.PlatformOpenAI,
|
||||
"",
|
||||
"",
|
||||
nil,
|
||||
)
|
||||
s.Require().NoError(err, "ListWithFilters base")
|
||||
@@ -152,7 +153,7 @@ func (s *GroupRepoSuite) TestListWithFilters_Platform() {
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}))
|
||||
|
||||
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.PlatformOpenAI, "", nil)
|
||||
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.PlatformOpenAI, "", "", nil)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(groups, len(baseGroups)+1)
|
||||
// Verify all groups are OpenAI platform
|
||||
@@ -179,7 +180,7 @@ func (s *GroupRepoSuite) TestListWithFilters_Status() {
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}))
|
||||
|
||||
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", service.StatusDisabled, nil)
|
||||
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", service.StatusDisabled, "", nil)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(groups, 1)
|
||||
s.Require().Equal(service.StatusDisabled, groups[0].Status)
|
||||
@@ -204,12 +205,117 @@ func (s *GroupRepoSuite) TestListWithFilters_IsExclusive() {
|
||||
}))
|
||||
|
||||
isExclusive := true
|
||||
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", &isExclusive)
|
||||
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "", &isExclusive)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(groups, 1)
|
||||
s.Require().True(groups[0].IsExclusive)
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) TestListWithFilters_Search() {
|
||||
newRepo := func() (*groupRepository, context.Context) {
|
||||
tx := testEntTx(s.T())
|
||||
return newGroupRepositoryWithSQL(tx.Client(), tx), context.Background()
|
||||
}
|
||||
|
||||
containsID := func(groups []service.Group, id int64) bool {
|
||||
for i := range groups {
|
||||
if groups[i].ID == id {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
mustCreate := func(repo *groupRepository, ctx context.Context, g *service.Group) *service.Group {
|
||||
s.Require().NoError(repo.Create(ctx, g))
|
||||
s.Require().NotZero(g.ID)
|
||||
return g
|
||||
}
|
||||
|
||||
newGroup := func(name string) *service.Group {
|
||||
return &service.Group{
|
||||
Name: name,
|
||||
Platform: service.PlatformAnthropic,
|
||||
RateMultiplier: 1.0,
|
||||
IsExclusive: false,
|
||||
Status: service.StatusActive,
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}
|
||||
}
|
||||
|
||||
s.Run("search_name_should_match", func() {
|
||||
repo, ctx := newRepo()
|
||||
|
||||
target := mustCreate(repo, ctx, newGroup("it-group-search-name-target"))
|
||||
other := mustCreate(repo, ctx, newGroup("it-group-search-name-other"))
|
||||
|
||||
groups, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", "name-target", nil)
|
||||
s.Require().NoError(err)
|
||||
s.Require().True(containsID(groups, target.ID), "expected target group to match by name")
|
||||
s.Require().False(containsID(groups, other.ID), "expected other group to be filtered out")
|
||||
})
|
||||
|
||||
s.Run("search_description_should_match", func() {
|
||||
repo, ctx := newRepo()
|
||||
|
||||
target := newGroup("it-group-search-desc-target")
|
||||
target.Description = "something about desc-needle in here"
|
||||
target = mustCreate(repo, ctx, target)
|
||||
|
||||
other := newGroup("it-group-search-desc-other")
|
||||
other.Description = "nothing to see here"
|
||||
other = mustCreate(repo, ctx, other)
|
||||
|
||||
groups, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", "desc-needle", nil)
|
||||
s.Require().NoError(err)
|
||||
s.Require().True(containsID(groups, target.ID), "expected target group to match by description")
|
||||
s.Require().False(containsID(groups, other.ID), "expected other group to be filtered out")
|
||||
})
|
||||
|
||||
s.Run("search_nonexistent_should_return_empty", func() {
|
||||
repo, ctx := newRepo()
|
||||
|
||||
_ = mustCreate(repo, ctx, newGroup("it-group-search-nonexistent-baseline"))
|
||||
|
||||
search := s.T().Name() + "__no_such_group__"
|
||||
groups, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", search, nil)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Empty(groups)
|
||||
})
|
||||
|
||||
s.Run("search_should_be_case_insensitive", func() {
|
||||
repo, ctx := newRepo()
|
||||
|
||||
target := mustCreate(repo, ctx, newGroup("MiXeDCaSe-Needle"))
|
||||
other := mustCreate(repo, ctx, newGroup("it-group-search-case-other"))
|
||||
|
||||
groups, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", "mixedcase-needle", nil)
|
||||
s.Require().NoError(err)
|
||||
s.Require().True(containsID(groups, target.ID), "expected case-insensitive match")
|
||||
s.Require().False(containsID(groups, other.ID), "expected other group to be filtered out")
|
||||
})
|
||||
|
||||
s.Run("search_should_escape_like_wildcards", func() {
|
||||
repo, ctx := newRepo()
|
||||
|
||||
percentTarget := mustCreate(repo, ctx, newGroup("it-group-search-100%-target"))
|
||||
percentOther := mustCreate(repo, ctx, newGroup("it-group-search-100X-other"))
|
||||
|
||||
groups, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", "100%", nil)
|
||||
s.Require().NoError(err)
|
||||
s.Require().True(containsID(groups, percentTarget.ID), "expected literal %% match")
|
||||
s.Require().False(containsID(groups, percentOther.ID), "expected %% not to act as wildcard")
|
||||
|
||||
underscoreTarget := mustCreate(repo, ctx, newGroup("it-group-search-ab_cd-target"))
|
||||
underscoreOther := mustCreate(repo, ctx, newGroup("it-group-search-abXcd-other"))
|
||||
|
||||
groups, _, err = repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", "ab_cd", nil)
|
||||
s.Require().NoError(err)
|
||||
s.Require().True(containsID(groups, underscoreTarget.ID), "expected literal _ match")
|
||||
s.Require().False(containsID(groups, underscoreOther.ID), "expected _ not to act as wildcard")
|
||||
})
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) TestListWithFilters_AccountCount() {
|
||||
g1 := &service.Group{
|
||||
Name: "g1",
|
||||
@@ -244,7 +350,7 @@ func (s *GroupRepoSuite) TestListWithFilters_AccountCount() {
|
||||
s.Require().NoError(err)
|
||||
|
||||
isExclusive := true
|
||||
groups, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.PlatformAnthropic, service.StatusActive, &isExclusive)
|
||||
groups, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.PlatformAnthropic, service.StatusActive, "", &isExclusive)
|
||||
s.Require().NoError(err, "ListWithFilters")
|
||||
s.Require().Equal(int64(1), page.Total)
|
||||
s.Require().Len(groups, 1)
|
||||
|
||||
@@ -304,6 +304,10 @@ func TestAPIContracts(t *testing.T) {
|
||||
"turnstile_enabled": true,
|
||||
"turnstile_site_key": "site-key",
|
||||
"turnstile_secret_key_configured": true,
|
||||
"linuxdo_connect_enabled": false,
|
||||
"linuxdo_connect_client_id": "",
|
||||
"linuxdo_connect_client_secret_configured": false,
|
||||
"linuxdo_connect_redirect_url": "",
|
||||
"site_name": "Sub2API",
|
||||
"site_logo": "",
|
||||
"site_subtitle": "Subtitle",
|
||||
@@ -390,7 +394,7 @@ func newContractDeps(t *testing.T) *contractDeps {
|
||||
settingRepo := newStubSettingRepo()
|
||||
settingService := service.NewSettingService(settingRepo, cfg)
|
||||
|
||||
authHandler := handler.NewAuthHandler(cfg, nil, userService)
|
||||
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService)
|
||||
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
||||
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
|
||||
adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil)
|
||||
@@ -583,7 +587,7 @@ func (stubGroupRepo) List(ctx context.Context, params pagination.PaginationParam
|
||||
return nil, nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (stubGroupRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]service.Group, *pagination.PaginationResult, error) {
|
||||
func (stubGroupRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]service.Group, *pagination.PaginationResult, error) {
|
||||
return nil, nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
|
||||
@@ -19,6 +19,8 @@ func RegisterAuthRoutes(
|
||||
auth.POST("/register", h.Auth.Register)
|
||||
auth.POST("/login", h.Auth.Login)
|
||||
auth.POST("/send-verify-code", h.Auth.SendVerifyCode)
|
||||
auth.GET("/oauth/linuxdo/start", h.Auth.LinuxDoOAuthStart)
|
||||
auth.GET("/oauth/linuxdo/callback", h.Auth.LinuxDoOAuthCallback)
|
||||
}
|
||||
|
||||
// 公开设置(无需认证)
|
||||
|
||||
@@ -66,6 +66,7 @@ type AccountBulkUpdate struct {
|
||||
Concurrency *int
|
||||
Priority *int
|
||||
Status *string
|
||||
Schedulable *bool
|
||||
Credentials map[string]any
|
||||
Extra map[string]any
|
||||
}
|
||||
|
||||
@@ -661,13 +661,7 @@ func (s *AccountTestService) processGeminiStream(c *gin.Context, body io.Reader)
|
||||
}
|
||||
if candidates, ok := data["candidates"].([]any); ok && len(candidates) > 0 {
|
||||
if candidate, ok := candidates[0].(map[string]any); ok {
|
||||
// Check for completion
|
||||
if finishReason, ok := candidate["finishReason"].(string); ok && finishReason != "" {
|
||||
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
|
||||
return nil
|
||||
}
|
||||
|
||||
// Extract content
|
||||
// Extract content first (before checking completion)
|
||||
if content, ok := candidate["content"].(map[string]any); ok {
|
||||
if parts, ok := content["parts"].([]any); ok {
|
||||
for _, part := range parts {
|
||||
@@ -679,6 +673,12 @@ func (s *AccountTestService) processGeminiStream(c *gin.Context, body io.Reader)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check for completion after extracting content
|
||||
if finishReason, ok := candidate["finishReason"].(string); ok && finishReason != "" {
|
||||
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ type AdminService interface {
|
||||
GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error)
|
||||
|
||||
// Group management
|
||||
ListGroups(ctx context.Context, page, pageSize int, platform, status string, isExclusive *bool) ([]Group, int64, error)
|
||||
ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool) ([]Group, int64, error)
|
||||
GetAllGroups(ctx context.Context) ([]Group, error)
|
||||
GetAllGroupsByPlatform(ctx context.Context, platform string) ([]Group, error)
|
||||
GetGroup(ctx context.Context, id int64) (*Group, error)
|
||||
@@ -168,6 +168,7 @@ type BulkUpdateAccountsInput struct {
|
||||
Concurrency *int
|
||||
Priority *int
|
||||
Status string
|
||||
Schedulable *bool
|
||||
GroupIDs *[]int64
|
||||
Credentials map[string]any
|
||||
Extra map[string]any
|
||||
@@ -478,9 +479,9 @@ func (s *adminServiceImpl) GetUserUsageStats(ctx context.Context, userID int64,
|
||||
}
|
||||
|
||||
// Group management implementations
|
||||
func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, platform, status string, isExclusive *bool) ([]Group, int64, error) {
|
||||
func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool) ([]Group, int64, error) {
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
groups, result, err := s.groupRepo.ListWithFilters(ctx, params, platform, status, isExclusive)
|
||||
groups, result, err := s.groupRepo.ListWithFilters(ctx, params, platform, status, search, isExclusive)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
@@ -910,6 +911,9 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
|
||||
if input.Status != "" {
|
||||
repoUpdates.Status = &input.Status
|
||||
}
|
||||
if input.Schedulable != nil {
|
||||
repoUpdates.Schedulable = input.Schedulable
|
||||
}
|
||||
|
||||
// Run bulk update for column/jsonb fields first.
|
||||
if _, err := s.accountRepo.BulkUpdate(ctx, input.AccountIDs, repoUpdates); err != nil {
|
||||
|
||||
@@ -124,7 +124,7 @@ func (s *groupRepoStub) List(ctx context.Context, params pagination.PaginationPa
|
||||
panic("unexpected List call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStub) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error) {
|
||||
func (s *groupRepoStub) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error) {
|
||||
panic("unexpected ListWithFilters call")
|
||||
}
|
||||
|
||||
|
||||
@@ -16,6 +16,16 @@ type groupRepoStubForAdmin struct {
|
||||
updated *Group // 记录 Update 调用的参数
|
||||
getByID *Group // GetByID 返回值
|
||||
getErr error // GetByID 返回的错误
|
||||
|
||||
listWithFiltersCalls int
|
||||
listWithFiltersParams pagination.PaginationParams
|
||||
listWithFiltersPlatform string
|
||||
listWithFiltersStatus string
|
||||
listWithFiltersSearch string
|
||||
listWithFiltersIsExclusive *bool
|
||||
listWithFiltersGroups []Group
|
||||
listWithFiltersResult *pagination.PaginationResult
|
||||
listWithFiltersErr error
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForAdmin) Create(_ context.Context, g *Group) error {
|
||||
@@ -47,8 +57,28 @@ func (s *groupRepoStubForAdmin) List(_ context.Context, _ pagination.PaginationP
|
||||
panic("unexpected List call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForAdmin) ListWithFilters(_ context.Context, _ pagination.PaginationParams, _, _ string, _ *bool) ([]Group, *pagination.PaginationResult, error) {
|
||||
panic("unexpected ListWithFilters call")
|
||||
func (s *groupRepoStubForAdmin) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error) {
|
||||
s.listWithFiltersCalls++
|
||||
s.listWithFiltersParams = params
|
||||
s.listWithFiltersPlatform = platform
|
||||
s.listWithFiltersStatus = status
|
||||
s.listWithFiltersSearch = search
|
||||
s.listWithFiltersIsExclusive = isExclusive
|
||||
|
||||
if s.listWithFiltersErr != nil {
|
||||
return nil, nil, s.listWithFiltersErr
|
||||
}
|
||||
|
||||
result := s.listWithFiltersResult
|
||||
if result == nil {
|
||||
result = &pagination.PaginationResult{
|
||||
Total: int64(len(s.listWithFiltersGroups)),
|
||||
Page: params.Page,
|
||||
PageSize: params.PageSize,
|
||||
}
|
||||
}
|
||||
|
||||
return s.listWithFiltersGroups, result, nil
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForAdmin) ListActive(_ context.Context) ([]Group, error) {
|
||||
@@ -195,3 +225,68 @@ func TestAdminService_UpdateGroup_PartialImagePricing(t *testing.T) {
|
||||
require.InDelta(t, 0.15, *repo.updated.ImagePrice2K, 0.0001) // 原值保持
|
||||
require.Nil(t, repo.updated.ImagePrice4K)
|
||||
}
|
||||
|
||||
func TestAdminService_ListGroups_WithSearch(t *testing.T) {
|
||||
// 测试:
|
||||
// 1. search 参数正常传递到 repository 层
|
||||
// 2. search 为空字符串时的行为
|
||||
// 3. search 与其他过滤条件组合使用
|
||||
|
||||
t.Run("search 参数正常传递到 repository 层", func(t *testing.T) {
|
||||
repo := &groupRepoStubForAdmin{
|
||||
listWithFiltersGroups: []Group{{ID: 1, Name: "alpha"}},
|
||||
listWithFiltersResult: &pagination.PaginationResult{Total: 1},
|
||||
}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
groups, total, err := svc.ListGroups(context.Background(), 1, 20, "", "", "alpha", nil)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(1), total)
|
||||
require.Equal(t, []Group{{ID: 1, Name: "alpha"}}, groups)
|
||||
|
||||
require.Equal(t, 1, repo.listWithFiltersCalls)
|
||||
require.Equal(t, pagination.PaginationParams{Page: 1, PageSize: 20}, repo.listWithFiltersParams)
|
||||
require.Equal(t, "alpha", repo.listWithFiltersSearch)
|
||||
require.Nil(t, repo.listWithFiltersIsExclusive)
|
||||
})
|
||||
|
||||
t.Run("search 为空字符串时传递空字符串", func(t *testing.T) {
|
||||
repo := &groupRepoStubForAdmin{
|
||||
listWithFiltersGroups: []Group{},
|
||||
listWithFiltersResult: &pagination.PaginationResult{Total: 0},
|
||||
}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
groups, total, err := svc.ListGroups(context.Background(), 2, 10, "", "", "", nil)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, groups)
|
||||
require.Equal(t, int64(0), total)
|
||||
|
||||
require.Equal(t, 1, repo.listWithFiltersCalls)
|
||||
require.Equal(t, pagination.PaginationParams{Page: 2, PageSize: 10}, repo.listWithFiltersParams)
|
||||
require.Equal(t, "", repo.listWithFiltersSearch)
|
||||
require.Nil(t, repo.listWithFiltersIsExclusive)
|
||||
})
|
||||
|
||||
t.Run("search 与其他过滤条件组合使用", func(t *testing.T) {
|
||||
isExclusive := true
|
||||
repo := &groupRepoStubForAdmin{
|
||||
listWithFiltersGroups: []Group{{ID: 2, Name: "beta"}},
|
||||
listWithFiltersResult: &pagination.PaginationResult{Total: 42},
|
||||
}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
groups, total, err := svc.ListGroups(context.Background(), 3, 50, PlatformAntigravity, StatusActive, "beta", &isExclusive)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(42), total)
|
||||
require.Equal(t, []Group{{ID: 2, Name: "beta"}}, groups)
|
||||
|
||||
require.Equal(t, 1, repo.listWithFiltersCalls)
|
||||
require.Equal(t, pagination.PaginationParams{Page: 3, PageSize: 50}, repo.listWithFiltersParams)
|
||||
require.Equal(t, PlatformAntigravity, repo.listWithFiltersPlatform)
|
||||
require.Equal(t, StatusActive, repo.listWithFiltersStatus)
|
||||
require.Equal(t, "beta", repo.listWithFiltersSearch)
|
||||
require.NotNil(t, repo.listWithFiltersIsExclusive)
|
||||
require.True(t, *repo.listWithFiltersIsExclusive)
|
||||
})
|
||||
}
|
||||
|
||||
238
backend/internal/service/admin_service_search_test.go
Normal file
238
backend/internal/service/admin_service_search_test.go
Normal file
@@ -0,0 +1,238 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type accountRepoStubForAdminList struct {
|
||||
accountRepoStub
|
||||
|
||||
listWithFiltersCalls int
|
||||
listWithFiltersParams pagination.PaginationParams
|
||||
listWithFiltersPlatform string
|
||||
listWithFiltersType string
|
||||
listWithFiltersStatus string
|
||||
listWithFiltersSearch string
|
||||
listWithFiltersAccounts []Account
|
||||
listWithFiltersResult *pagination.PaginationResult
|
||||
listWithFiltersErr error
|
||||
}
|
||||
|
||||
func (s *accountRepoStubForAdminList) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error) {
|
||||
s.listWithFiltersCalls++
|
||||
s.listWithFiltersParams = params
|
||||
s.listWithFiltersPlatform = platform
|
||||
s.listWithFiltersType = accountType
|
||||
s.listWithFiltersStatus = status
|
||||
s.listWithFiltersSearch = search
|
||||
|
||||
if s.listWithFiltersErr != nil {
|
||||
return nil, nil, s.listWithFiltersErr
|
||||
}
|
||||
|
||||
result := s.listWithFiltersResult
|
||||
if result == nil {
|
||||
result = &pagination.PaginationResult{
|
||||
Total: int64(len(s.listWithFiltersAccounts)),
|
||||
Page: params.Page,
|
||||
PageSize: params.PageSize,
|
||||
}
|
||||
}
|
||||
|
||||
return s.listWithFiltersAccounts, result, nil
|
||||
}
|
||||
|
||||
type proxyRepoStubForAdminList struct {
|
||||
proxyRepoStub
|
||||
|
||||
listWithFiltersCalls int
|
||||
listWithFiltersParams pagination.PaginationParams
|
||||
listWithFiltersProtocol string
|
||||
listWithFiltersStatus string
|
||||
listWithFiltersSearch string
|
||||
listWithFiltersProxies []Proxy
|
||||
listWithFiltersResult *pagination.PaginationResult
|
||||
listWithFiltersErr error
|
||||
|
||||
listWithFiltersAndAccountCountCalls int
|
||||
listWithFiltersAndAccountCountParams pagination.PaginationParams
|
||||
listWithFiltersAndAccountCountProtocol string
|
||||
listWithFiltersAndAccountCountStatus string
|
||||
listWithFiltersAndAccountCountSearch string
|
||||
listWithFiltersAndAccountCountProxies []ProxyWithAccountCount
|
||||
listWithFiltersAndAccountCountResult *pagination.PaginationResult
|
||||
listWithFiltersAndAccountCountErr error
|
||||
}
|
||||
|
||||
func (s *proxyRepoStubForAdminList) ListWithFilters(_ context.Context, params pagination.PaginationParams, protocol, status, search string) ([]Proxy, *pagination.PaginationResult, error) {
|
||||
s.listWithFiltersCalls++
|
||||
s.listWithFiltersParams = params
|
||||
s.listWithFiltersProtocol = protocol
|
||||
s.listWithFiltersStatus = status
|
||||
s.listWithFiltersSearch = search
|
||||
|
||||
if s.listWithFiltersErr != nil {
|
||||
return nil, nil, s.listWithFiltersErr
|
||||
}
|
||||
|
||||
result := s.listWithFiltersResult
|
||||
if result == nil {
|
||||
result = &pagination.PaginationResult{
|
||||
Total: int64(len(s.listWithFiltersProxies)),
|
||||
Page: params.Page,
|
||||
PageSize: params.PageSize,
|
||||
}
|
||||
}
|
||||
|
||||
return s.listWithFiltersProxies, result, nil
|
||||
}
|
||||
|
||||
func (s *proxyRepoStubForAdminList) ListWithFiltersAndAccountCount(_ context.Context, params pagination.PaginationParams, protocol, status, search string) ([]ProxyWithAccountCount, *pagination.PaginationResult, error) {
|
||||
s.listWithFiltersAndAccountCountCalls++
|
||||
s.listWithFiltersAndAccountCountParams = params
|
||||
s.listWithFiltersAndAccountCountProtocol = protocol
|
||||
s.listWithFiltersAndAccountCountStatus = status
|
||||
s.listWithFiltersAndAccountCountSearch = search
|
||||
|
||||
if s.listWithFiltersAndAccountCountErr != nil {
|
||||
return nil, nil, s.listWithFiltersAndAccountCountErr
|
||||
}
|
||||
|
||||
result := s.listWithFiltersAndAccountCountResult
|
||||
if result == nil {
|
||||
result = &pagination.PaginationResult{
|
||||
Total: int64(len(s.listWithFiltersAndAccountCountProxies)),
|
||||
Page: params.Page,
|
||||
PageSize: params.PageSize,
|
||||
}
|
||||
}
|
||||
|
||||
return s.listWithFiltersAndAccountCountProxies, result, nil
|
||||
}
|
||||
|
||||
type redeemRepoStubForAdminList struct {
|
||||
redeemRepoStub
|
||||
|
||||
listWithFiltersCalls int
|
||||
listWithFiltersParams pagination.PaginationParams
|
||||
listWithFiltersType string
|
||||
listWithFiltersStatus string
|
||||
listWithFiltersSearch string
|
||||
listWithFiltersCodes []RedeemCode
|
||||
listWithFiltersResult *pagination.PaginationResult
|
||||
listWithFiltersErr error
|
||||
}
|
||||
|
||||
func (s *redeemRepoStubForAdminList) ListWithFilters(_ context.Context, params pagination.PaginationParams, codeType, status, search string) ([]RedeemCode, *pagination.PaginationResult, error) {
|
||||
s.listWithFiltersCalls++
|
||||
s.listWithFiltersParams = params
|
||||
s.listWithFiltersType = codeType
|
||||
s.listWithFiltersStatus = status
|
||||
s.listWithFiltersSearch = search
|
||||
|
||||
if s.listWithFiltersErr != nil {
|
||||
return nil, nil, s.listWithFiltersErr
|
||||
}
|
||||
|
||||
result := s.listWithFiltersResult
|
||||
if result == nil {
|
||||
result = &pagination.PaginationResult{
|
||||
Total: int64(len(s.listWithFiltersCodes)),
|
||||
Page: params.Page,
|
||||
PageSize: params.PageSize,
|
||||
}
|
||||
}
|
||||
|
||||
return s.listWithFiltersCodes, result, nil
|
||||
}
|
||||
|
||||
func TestAdminService_ListAccounts_WithSearch(t *testing.T) {
|
||||
t.Run("search 参数正常传递到 repository 层", func(t *testing.T) {
|
||||
repo := &accountRepoStubForAdminList{
|
||||
listWithFiltersAccounts: []Account{{ID: 1, Name: "acc"}},
|
||||
listWithFiltersResult: &pagination.PaginationResult{Total: 10},
|
||||
}
|
||||
svc := &adminServiceImpl{accountRepo: repo}
|
||||
|
||||
accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformGemini, AccountTypeOAuth, StatusActive, "acc")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(10), total)
|
||||
require.Equal(t, []Account{{ID: 1, Name: "acc"}}, accounts)
|
||||
|
||||
require.Equal(t, 1, repo.listWithFiltersCalls)
|
||||
require.Equal(t, pagination.PaginationParams{Page: 1, PageSize: 20}, repo.listWithFiltersParams)
|
||||
require.Equal(t, PlatformGemini, repo.listWithFiltersPlatform)
|
||||
require.Equal(t, AccountTypeOAuth, repo.listWithFiltersType)
|
||||
require.Equal(t, StatusActive, repo.listWithFiltersStatus)
|
||||
require.Equal(t, "acc", repo.listWithFiltersSearch)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAdminService_ListProxies_WithSearch(t *testing.T) {
|
||||
t.Run("search 参数正常传递到 repository 层", func(t *testing.T) {
|
||||
repo := &proxyRepoStubForAdminList{
|
||||
listWithFiltersProxies: []Proxy{{ID: 2, Name: "p1"}},
|
||||
listWithFiltersResult: &pagination.PaginationResult{Total: 7},
|
||||
}
|
||||
svc := &adminServiceImpl{proxyRepo: repo}
|
||||
|
||||
proxies, total, err := svc.ListProxies(context.Background(), 3, 50, "http", StatusActive, "p1")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(7), total)
|
||||
require.Equal(t, []Proxy{{ID: 2, Name: "p1"}}, proxies)
|
||||
|
||||
require.Equal(t, 1, repo.listWithFiltersCalls)
|
||||
require.Equal(t, pagination.PaginationParams{Page: 3, PageSize: 50}, repo.listWithFiltersParams)
|
||||
require.Equal(t, "http", repo.listWithFiltersProtocol)
|
||||
require.Equal(t, StatusActive, repo.listWithFiltersStatus)
|
||||
require.Equal(t, "p1", repo.listWithFiltersSearch)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAdminService_ListProxiesWithAccountCount_WithSearch(t *testing.T) {
|
||||
t.Run("search 参数正常传递到 repository 层", func(t *testing.T) {
|
||||
repo := &proxyRepoStubForAdminList{
|
||||
listWithFiltersAndAccountCountProxies: []ProxyWithAccountCount{{Proxy: Proxy{ID: 3, Name: "p2"}, AccountCount: 5}},
|
||||
listWithFiltersAndAccountCountResult: &pagination.PaginationResult{Total: 9},
|
||||
}
|
||||
svc := &adminServiceImpl{proxyRepo: repo}
|
||||
|
||||
proxies, total, err := svc.ListProxiesWithAccountCount(context.Background(), 2, 10, "socks5", StatusDisabled, "p2")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(9), total)
|
||||
require.Equal(t, []ProxyWithAccountCount{{Proxy: Proxy{ID: 3, Name: "p2"}, AccountCount: 5}}, proxies)
|
||||
|
||||
require.Equal(t, 1, repo.listWithFiltersAndAccountCountCalls)
|
||||
require.Equal(t, pagination.PaginationParams{Page: 2, PageSize: 10}, repo.listWithFiltersAndAccountCountParams)
|
||||
require.Equal(t, "socks5", repo.listWithFiltersAndAccountCountProtocol)
|
||||
require.Equal(t, StatusDisabled, repo.listWithFiltersAndAccountCountStatus)
|
||||
require.Equal(t, "p2", repo.listWithFiltersAndAccountCountSearch)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAdminService_ListRedeemCodes_WithSearch(t *testing.T) {
|
||||
t.Run("search 参数正常传递到 repository 层", func(t *testing.T) {
|
||||
repo := &redeemRepoStubForAdminList{
|
||||
listWithFiltersCodes: []RedeemCode{{ID: 4, Code: "ABC"}},
|
||||
listWithFiltersResult: &pagination.PaginationResult{Total: 3},
|
||||
}
|
||||
svc := &adminServiceImpl{redeemCodeRepo: repo}
|
||||
|
||||
codes, total, err := svc.ListRedeemCodes(context.Background(), 1, 20, RedeemTypeBalance, StatusUnused, "ABC")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(3), total)
|
||||
require.Equal(t, []RedeemCode{{ID: 4, Code: "ABC"}}, codes)
|
||||
|
||||
require.Equal(t, 1, repo.listWithFiltersCalls)
|
||||
require.Equal(t, pagination.PaginationParams{Page: 1, PageSize: 20}, repo.listWithFiltersParams)
|
||||
require.Equal(t, RedeemTypeBalance, repo.listWithFiltersType)
|
||||
require.Equal(t, StatusUnused, repo.listWithFiltersStatus)
|
||||
require.Equal(t, "ABC", repo.listWithFiltersSearch)
|
||||
})
|
||||
}
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"io"
|
||||
"log"
|
||||
mathrand "math/rand"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
@@ -27,6 +28,32 @@ const (
|
||||
antigravityRetryMaxDelay = 16 * time.Second
|
||||
)
|
||||
|
||||
// isAntigravityConnectionError 判断是否为连接错误(网络超时、DNS 失败、连接拒绝)
|
||||
func isAntigravityConnectionError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查超时错误
|
||||
var netErr net.Error
|
||||
if errors.As(err, &netErr) && netErr.Timeout() {
|
||||
return true
|
||||
}
|
||||
|
||||
// 检查连接错误(DNS 失败、连接拒绝)
|
||||
var opErr *net.OpError
|
||||
return errors.As(err, &opErr)
|
||||
}
|
||||
|
||||
// shouldAntigravityFallbackToNextURL 判断是否应切换到下一个 URL
|
||||
// 仅连接错误和 HTTP 429 触发 URL 降级
|
||||
func shouldAntigravityFallbackToNextURL(err error, statusCode int) bool {
|
||||
if isAntigravityConnectionError(err) {
|
||||
return true
|
||||
}
|
||||
return statusCode == http.StatusTooManyRequests
|
||||
}
|
||||
|
||||
// getSessionID 从 gin.Context 获取 session_id(用于日志追踪)
|
||||
func getSessionID(c *gin.Context) string {
|
||||
if c == nil {
|
||||
@@ -181,45 +208,70 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account
|
||||
return nil, fmt.Errorf("构建请求失败: %w", err)
|
||||
}
|
||||
|
||||
// 构建 HTTP 请求(总是使用流式 endpoint,与官方客户端一致)
|
||||
req, err := antigravity.NewAPIRequest(ctx, "streamGenerateContent", accessToken, requestBody)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 调试日志:Test 请求信息
|
||||
log.Printf("[antigravity-Test] account=%s request_size=%d url=%s", account.Name, len(requestBody), req.URL.String())
|
||||
|
||||
// 代理 URL
|
||||
proxyURL := ""
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
// 发送请求
|
||||
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("请求失败: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
// 读取响应
|
||||
respBody, err := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取响应失败: %w", err)
|
||||
// URL fallback 循环
|
||||
availableURLs := antigravity.DefaultURLAvailability.GetAvailableURLs()
|
||||
if len(availableURLs) == 0 {
|
||||
availableURLs = antigravity.BaseURLs // 所有 URL 都不可用时,重试所有
|
||||
}
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
return nil, fmt.Errorf("API 返回 %d: %s", resp.StatusCode, string(respBody))
|
||||
var lastErr error
|
||||
for urlIdx, baseURL := range availableURLs {
|
||||
// 构建 HTTP 请求(总是使用流式 endpoint,与官方客户端一致)
|
||||
req, err := antigravity.NewAPIRequestWithURL(ctx, baseURL, "streamGenerateContent", accessToken, requestBody)
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
continue
|
||||
}
|
||||
|
||||
// 调试日志:Test 请求信息
|
||||
log.Printf("[antigravity-Test] account=%s request_size=%d url=%s", account.Name, len(requestBody), req.URL.String())
|
||||
|
||||
// 发送请求
|
||||
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
|
||||
if err != nil {
|
||||
lastErr = fmt.Errorf("请求失败: %w", err)
|
||||
if shouldAntigravityFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 {
|
||||
antigravity.DefaultURLAvailability.MarkUnavailable(baseURL)
|
||||
log.Printf("[antigravity-Test] URL fallback: %s -> %s", baseURL, availableURLs[urlIdx+1])
|
||||
continue
|
||||
}
|
||||
return nil, lastErr
|
||||
}
|
||||
|
||||
// 读取响应
|
||||
respBody, err := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
_ = resp.Body.Close() // 立即关闭,避免循环内 defer 导致的资源泄漏
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取响应失败: %w", err)
|
||||
}
|
||||
|
||||
// 检查是否需要 URL 降级
|
||||
if shouldAntigravityFallbackToNextURL(nil, resp.StatusCode) && urlIdx < len(availableURLs)-1 {
|
||||
antigravity.DefaultURLAvailability.MarkUnavailable(baseURL)
|
||||
log.Printf("[antigravity-Test] URL fallback (HTTP %d): %s -> %s", resp.StatusCode, baseURL, availableURLs[urlIdx+1])
|
||||
continue
|
||||
}
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
return nil, fmt.Errorf("API 返回 %d: %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
// 解析流式响应,提取文本
|
||||
text := extractTextFromSSEResponse(respBody)
|
||||
|
||||
return &TestConnectionResult{
|
||||
Text: text,
|
||||
MappedModel: mappedModel,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 解析流式响应,提取文本
|
||||
text := extractTextFromSSEResponse(respBody)
|
||||
|
||||
return &TestConnectionResult{
|
||||
Text: text,
|
||||
MappedModel: mappedModel,
|
||||
}, nil
|
||||
return nil, lastErr
|
||||
}
|
||||
|
||||
// buildGeminiTestRequest 构建 Gemini 格式测试请求
|
||||
@@ -484,62 +536,86 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
// 如果客户端请求非流式,在响应处理阶段会收集完整流式响应后转换返回
|
||||
action := "streamGenerateContent"
|
||||
|
||||
// URL fallback 循环
|
||||
availableURLs := antigravity.DefaultURLAvailability.GetAvailableURLs()
|
||||
if len(availableURLs) == 0 {
|
||||
availableURLs = antigravity.BaseURLs // 所有 URL 都不可用时,重试所有
|
||||
}
|
||||
|
||||
// 重试循环
|
||||
var resp *http.Response
|
||||
for attempt := 1; attempt <= antigravityMaxRetries; attempt++ {
|
||||
// 检查 context 是否已取消(客户端断开连接)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Printf("%s status=context_canceled error=%v", prefix, ctx.Err())
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
urlFallbackLoop:
|
||||
for urlIdx, baseURL := range availableURLs {
|
||||
for attempt := 1; attempt <= antigravityMaxRetries; attempt++ {
|
||||
// 检查 context 是否已取消(客户端断开连接)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Printf("%s status=context_canceled error=%v", prefix, ctx.Err())
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
upstreamReq, err := antigravity.NewAPIRequest(ctx, action, accessToken, geminiBody)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
upstreamReq, err := antigravity.NewAPIRequestWithURL(ctx, baseURL, action, accessToken, geminiBody)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
||||
if err != nil {
|
||||
if attempt < antigravityMaxRetries {
|
||||
log.Printf("%s status=request_failed retry=%d/%d error=%v", prefix, attempt, antigravityMaxRetries, err)
|
||||
if !sleepAntigravityBackoffWithContext(ctx, attempt) {
|
||||
log.Printf("%s status=context_canceled_during_backoff", prefix)
|
||||
return nil, ctx.Err()
|
||||
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
||||
if err != nil {
|
||||
// 检查是否应触发 URL 降级
|
||||
if shouldAntigravityFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 {
|
||||
antigravity.DefaultURLAvailability.MarkUnavailable(baseURL)
|
||||
log.Printf("%s URL fallback (connection error): %s -> %s", prefix, baseURL, availableURLs[urlIdx+1])
|
||||
continue urlFallbackLoop
|
||||
}
|
||||
continue
|
||||
}
|
||||
log.Printf("%s status=request_failed retries_exhausted error=%v", prefix, err)
|
||||
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries")
|
||||
}
|
||||
|
||||
if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(resp.StatusCode) {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
_ = resp.Body.Close()
|
||||
|
||||
if attempt < antigravityMaxRetries {
|
||||
log.Printf("%s status=%d retry=%d/%d body=%s", prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 500))
|
||||
if !sleepAntigravityBackoffWithContext(ctx, attempt) {
|
||||
log.Printf("%s status=context_canceled_during_backoff", prefix)
|
||||
return nil, ctx.Err()
|
||||
if attempt < antigravityMaxRetries {
|
||||
log.Printf("%s status=request_failed retry=%d/%d error=%v", prefix, attempt, antigravityMaxRetries, err)
|
||||
if !sleepAntigravityBackoffWithContext(ctx, attempt) {
|
||||
log.Printf("%s status=context_canceled_during_backoff", prefix)
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
continue
|
||||
}
|
||||
continue
|
||||
log.Printf("%s status=request_failed retries_exhausted error=%v", prefix, err)
|
||||
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries")
|
||||
}
|
||||
// 所有重试都失败,标记限流状态
|
||||
if resp.StatusCode == 429 {
|
||||
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody)
|
||||
}
|
||||
// 最后一次尝试也失败
|
||||
resp = &http.Response{
|
||||
StatusCode: resp.StatusCode,
|
||||
Header: resp.Header.Clone(),
|
||||
Body: io.NopCloser(bytes.NewReader(respBody)),
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
break
|
||||
// 检查是否应触发 URL 降级(仅 429)
|
||||
if resp.StatusCode == http.StatusTooManyRequests && urlIdx < len(availableURLs)-1 {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
_ = resp.Body.Close()
|
||||
antigravity.DefaultURLAvailability.MarkUnavailable(baseURL)
|
||||
log.Printf("%s URL fallback (HTTP 429): %s -> %s body=%s", prefix, baseURL, availableURLs[urlIdx+1], truncateForLog(respBody, 200))
|
||||
continue urlFallbackLoop
|
||||
}
|
||||
|
||||
if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(resp.StatusCode) {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
_ = resp.Body.Close()
|
||||
|
||||
if attempt < antigravityMaxRetries {
|
||||
log.Printf("%s status=%d retry=%d/%d body=%s", prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 500))
|
||||
if !sleepAntigravityBackoffWithContext(ctx, attempt) {
|
||||
log.Printf("%s status=context_canceled_during_backoff", prefix)
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
continue
|
||||
}
|
||||
// 所有重试都失败,标记限流状态
|
||||
if resp.StatusCode == 429 {
|
||||
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody)
|
||||
}
|
||||
// 最后一次尝试也失败
|
||||
resp = &http.Response{
|
||||
StatusCode: resp.StatusCode,
|
||||
Header: resp.Header.Clone(),
|
||||
Body: io.NopCloser(bytes.NewReader(respBody)),
|
||||
}
|
||||
break urlFallbackLoop
|
||||
}
|
||||
|
||||
break urlFallbackLoop
|
||||
}
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
@@ -1003,61 +1079,85 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
|
||||
// 如果客户端请求非流式,在响应处理阶段会收集完整流式响应后返回
|
||||
upstreamAction := "streamGenerateContent"
|
||||
|
||||
// URL fallback 循环
|
||||
availableURLs := antigravity.DefaultURLAvailability.GetAvailableURLs()
|
||||
if len(availableURLs) == 0 {
|
||||
availableURLs = antigravity.BaseURLs // 所有 URL 都不可用时,重试所有
|
||||
}
|
||||
|
||||
// 重试循环
|
||||
var resp *http.Response
|
||||
for attempt := 1; attempt <= antigravityMaxRetries; attempt++ {
|
||||
// 检查 context 是否已取消(客户端断开连接)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Printf("%s status=context_canceled error=%v", prefix, ctx.Err())
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
urlFallbackLoop:
|
||||
for urlIdx, baseURL := range availableURLs {
|
||||
for attempt := 1; attempt <= antigravityMaxRetries; attempt++ {
|
||||
// 检查 context 是否已取消(客户端断开连接)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Printf("%s status=context_canceled error=%v", prefix, ctx.Err())
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
upstreamReq, err := antigravity.NewAPIRequest(ctx, upstreamAction, accessToken, wrappedBody)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
upstreamReq, err := antigravity.NewAPIRequestWithURL(ctx, baseURL, upstreamAction, accessToken, wrappedBody)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
||||
if err != nil {
|
||||
if attempt < antigravityMaxRetries {
|
||||
log.Printf("%s status=request_failed retry=%d/%d error=%v", prefix, attempt, antigravityMaxRetries, err)
|
||||
if !sleepAntigravityBackoffWithContext(ctx, attempt) {
|
||||
log.Printf("%s status=context_canceled_during_backoff", prefix)
|
||||
return nil, ctx.Err()
|
||||
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
||||
if err != nil {
|
||||
// 检查是否应触发 URL 降级
|
||||
if shouldAntigravityFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 {
|
||||
antigravity.DefaultURLAvailability.MarkUnavailable(baseURL)
|
||||
log.Printf("%s URL fallback (connection error): %s -> %s", prefix, baseURL, availableURLs[urlIdx+1])
|
||||
continue urlFallbackLoop
|
||||
}
|
||||
continue
|
||||
}
|
||||
log.Printf("%s status=request_failed retries_exhausted error=%v", prefix, err)
|
||||
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries")
|
||||
}
|
||||
|
||||
if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(resp.StatusCode) {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
_ = resp.Body.Close()
|
||||
|
||||
if attempt < antigravityMaxRetries {
|
||||
log.Printf("%s status=%d retry=%d/%d", prefix, resp.StatusCode, attempt, antigravityMaxRetries)
|
||||
if !sleepAntigravityBackoffWithContext(ctx, attempt) {
|
||||
log.Printf("%s status=context_canceled_during_backoff", prefix)
|
||||
return nil, ctx.Err()
|
||||
if attempt < antigravityMaxRetries {
|
||||
log.Printf("%s status=request_failed retry=%d/%d error=%v", prefix, attempt, antigravityMaxRetries, err)
|
||||
if !sleepAntigravityBackoffWithContext(ctx, attempt) {
|
||||
log.Printf("%s status=context_canceled_during_backoff", prefix)
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
continue
|
||||
}
|
||||
continue
|
||||
log.Printf("%s status=request_failed retries_exhausted error=%v", prefix, err)
|
||||
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries")
|
||||
}
|
||||
// 所有重试都失败,标记限流状态
|
||||
if resp.StatusCode == 429 {
|
||||
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody)
|
||||
}
|
||||
resp = &http.Response{
|
||||
StatusCode: resp.StatusCode,
|
||||
Header: resp.Header.Clone(),
|
||||
Body: io.NopCloser(bytes.NewReader(respBody)),
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
break
|
||||
// 检查是否应触发 URL 降级(仅 429)
|
||||
if resp.StatusCode == http.StatusTooManyRequests && urlIdx < len(availableURLs)-1 {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
_ = resp.Body.Close()
|
||||
antigravity.DefaultURLAvailability.MarkUnavailable(baseURL)
|
||||
log.Printf("%s URL fallback (HTTP 429): %s -> %s body=%s", prefix, baseURL, availableURLs[urlIdx+1], truncateForLog(respBody, 200))
|
||||
continue urlFallbackLoop
|
||||
}
|
||||
|
||||
if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(resp.StatusCode) {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
_ = resp.Body.Close()
|
||||
|
||||
if attempt < antigravityMaxRetries {
|
||||
log.Printf("%s status=%d retry=%d/%d", prefix, resp.StatusCode, attempt, antigravityMaxRetries)
|
||||
if !sleepAntigravityBackoffWithContext(ctx, attempt) {
|
||||
log.Printf("%s status=context_canceled_during_backoff", prefix)
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
continue
|
||||
}
|
||||
// 所有重试都失败,标记限流状态
|
||||
if resp.StatusCode == 429 {
|
||||
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody)
|
||||
}
|
||||
resp = &http.Response{
|
||||
StatusCode: resp.StatusCode,
|
||||
Header: resp.Header.Clone(),
|
||||
Body: io.NopCloser(bytes.NewReader(respBody)),
|
||||
}
|
||||
break urlFallbackLoop
|
||||
}
|
||||
|
||||
break urlFallbackLoop
|
||||
}
|
||||
}
|
||||
defer func() {
|
||||
if resp != nil && resp.Body != nil {
|
||||
|
||||
@@ -2,9 +2,13 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/mail"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
@@ -18,6 +22,7 @@ var (
|
||||
ErrInvalidCredentials = infraerrors.Unauthorized("INVALID_CREDENTIALS", "invalid email or password")
|
||||
ErrUserNotActive = infraerrors.Forbidden("USER_NOT_ACTIVE", "user is not active")
|
||||
ErrEmailExists = infraerrors.Conflict("EMAIL_EXISTS", "email already exists")
|
||||
ErrEmailReserved = infraerrors.BadRequest("EMAIL_RESERVED", "email is reserved")
|
||||
ErrInvalidToken = infraerrors.Unauthorized("INVALID_TOKEN", "invalid token")
|
||||
ErrTokenExpired = infraerrors.Unauthorized("TOKEN_EXPIRED", "token has expired")
|
||||
ErrTokenTooLarge = infraerrors.BadRequest("TOKEN_TOO_LARGE", "token too large")
|
||||
@@ -75,21 +80,30 @@ func (s *AuthService) Register(ctx context.Context, email, password string) (str
|
||||
|
||||
// RegisterWithVerification 用户注册(支持邮件验证),返回token和用户
|
||||
func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode string) (string, *User, error) {
|
||||
// 检查是否开放注册
|
||||
if s.settingService != nil && !s.settingService.IsRegistrationEnabled(ctx) {
|
||||
// 检查是否开放注册(默认关闭:settingService 未配置时不允许注册)
|
||||
if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
|
||||
return "", nil, ErrRegDisabled
|
||||
}
|
||||
|
||||
// 防止用户注册 LinuxDo OAuth 合成邮箱,避免第三方登录与本地账号发生碰撞。
|
||||
if isReservedEmail(email) {
|
||||
return "", nil, ErrEmailReserved
|
||||
}
|
||||
|
||||
// 检查是否需要邮件验证
|
||||
if s.settingService != nil && s.settingService.IsEmailVerifyEnabled(ctx) {
|
||||
// 如果邮件验证已开启但邮件服务未配置,拒绝注册
|
||||
// 这是一个配置错误,不应该允许绕过验证
|
||||
if s.emailService == nil {
|
||||
log.Println("[Auth] Email verification enabled but email service not configured, rejecting registration")
|
||||
return "", nil, ErrServiceUnavailable
|
||||
}
|
||||
if verifyCode == "" {
|
||||
return "", nil, ErrEmailVerifyRequired
|
||||
}
|
||||
// 验证邮箱验证码
|
||||
if s.emailService != nil {
|
||||
if err := s.emailService.VerifyCode(ctx, email, verifyCode); err != nil {
|
||||
return "", nil, fmt.Errorf("verify code: %w", err)
|
||||
}
|
||||
if err := s.emailService.VerifyCode(ctx, email, verifyCode); err != nil {
|
||||
return "", nil, fmt.Errorf("verify code: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -128,6 +142,10 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
|
||||
}
|
||||
|
||||
if err := s.userRepo.Create(ctx, user); err != nil {
|
||||
// 优先检查邮箱冲突错误(竞态条件下可能发生)
|
||||
if errors.Is(err, ErrEmailExists) {
|
||||
return "", nil, ErrEmailExists
|
||||
}
|
||||
log.Printf("[Auth] Database error creating user: %v", err)
|
||||
return "", nil, ErrServiceUnavailable
|
||||
}
|
||||
@@ -148,11 +166,15 @@ type SendVerifyCodeResult struct {
|
||||
|
||||
// SendVerifyCode 发送邮箱验证码(同步方式)
|
||||
func (s *AuthService) SendVerifyCode(ctx context.Context, email string) error {
|
||||
// 检查是否开放注册
|
||||
if s.settingService != nil && !s.settingService.IsRegistrationEnabled(ctx) {
|
||||
// 检查是否开放注册(默认关闭)
|
||||
if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
|
||||
return ErrRegDisabled
|
||||
}
|
||||
|
||||
if isReservedEmail(email) {
|
||||
return ErrEmailReserved
|
||||
}
|
||||
|
||||
// 检查邮箱是否已存在
|
||||
existsEmail, err := s.userRepo.ExistsByEmail(ctx, email)
|
||||
if err != nil {
|
||||
@@ -181,12 +203,16 @@ func (s *AuthService) SendVerifyCode(ctx context.Context, email string) error {
|
||||
func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*SendVerifyCodeResult, error) {
|
||||
log.Printf("[Auth] SendVerifyCodeAsync called for email: %s", email)
|
||||
|
||||
// 检查是否开放注册
|
||||
if s.settingService != nil && !s.settingService.IsRegistrationEnabled(ctx) {
|
||||
// 检查是否开放注册(默认关闭)
|
||||
if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
|
||||
log.Println("[Auth] Registration is disabled")
|
||||
return nil, ErrRegDisabled
|
||||
}
|
||||
|
||||
if isReservedEmail(email) {
|
||||
return nil, ErrEmailReserved
|
||||
}
|
||||
|
||||
// 检查邮箱是否已存在
|
||||
existsEmail, err := s.userRepo.ExistsByEmail(ctx, email)
|
||||
if err != nil {
|
||||
@@ -266,7 +292,7 @@ func (s *AuthService) IsTurnstileEnabled(ctx context.Context) bool {
|
||||
// IsRegistrationEnabled 检查是否开放注册
|
||||
func (s *AuthService) IsRegistrationEnabled(ctx context.Context) bool {
|
||||
if s.settingService == nil {
|
||||
return true
|
||||
return false // 安全默认:settingService 未配置时关闭注册
|
||||
}
|
||||
return s.settingService.IsRegistrationEnabled(ctx)
|
||||
}
|
||||
@@ -311,6 +337,102 @@ func (s *AuthService) Login(ctx context.Context, email, password string) (string
|
||||
return token, user, nil
|
||||
}
|
||||
|
||||
// LoginOrRegisterOAuth 用于第三方 OAuth/SSO 登录:
|
||||
// - 如果邮箱已存在:直接登录(不需要本地密码)
|
||||
// - 如果邮箱不存在:创建新用户并登录
|
||||
//
|
||||
// 注意:该函数用于“终端用户登录 Sub2API 本身”的场景(不同于上游账号的 OAuth,例如 OpenAI/Gemini)。
|
||||
// 为了满足现有数据库约束(需要密码哈希),新用户会生成随机密码并进行哈希保存。
|
||||
func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username string) (string, *User, error) {
|
||||
email = strings.TrimSpace(email)
|
||||
if email == "" || len(email) > 255 {
|
||||
return "", nil, infraerrors.BadRequest("INVALID_EMAIL", "invalid email")
|
||||
}
|
||||
if _, err := mail.ParseAddress(email); err != nil {
|
||||
return "", nil, infraerrors.BadRequest("INVALID_EMAIL", "invalid email")
|
||||
}
|
||||
|
||||
username = strings.TrimSpace(username)
|
||||
if len([]rune(username)) > 100 {
|
||||
username = string([]rune(username)[:100])
|
||||
}
|
||||
|
||||
user, err := s.userRepo.GetByEmail(ctx, email)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrUserNotFound) {
|
||||
// OAuth 首次登录视为注册。
|
||||
if s.settingService != nil && !s.settingService.IsRegistrationEnabled(ctx) {
|
||||
return "", nil, ErrRegDisabled
|
||||
}
|
||||
|
||||
randomPassword, err := randomHexString(32)
|
||||
if err != nil {
|
||||
log.Printf("[Auth] Failed to generate random password for oauth signup: %v", err)
|
||||
return "", nil, ErrServiceUnavailable
|
||||
}
|
||||
hashedPassword, err := s.HashPassword(randomPassword)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("hash password: %w", err)
|
||||
}
|
||||
|
||||
// 新用户默认值。
|
||||
defaultBalance := s.cfg.Default.UserBalance
|
||||
defaultConcurrency := s.cfg.Default.UserConcurrency
|
||||
if s.settingService != nil {
|
||||
defaultBalance = s.settingService.GetDefaultBalance(ctx)
|
||||
defaultConcurrency = s.settingService.GetDefaultConcurrency(ctx)
|
||||
}
|
||||
|
||||
newUser := &User{
|
||||
Email: email,
|
||||
Username: username,
|
||||
PasswordHash: hashedPassword,
|
||||
Role: RoleUser,
|
||||
Balance: defaultBalance,
|
||||
Concurrency: defaultConcurrency,
|
||||
Status: StatusActive,
|
||||
}
|
||||
|
||||
if err := s.userRepo.Create(ctx, newUser); err != nil {
|
||||
if errors.Is(err, ErrEmailExists) {
|
||||
// 并发场景:GetByEmail 与 Create 之间用户被创建。
|
||||
user, err = s.userRepo.GetByEmail(ctx, email)
|
||||
if err != nil {
|
||||
log.Printf("[Auth] Database error getting user after conflict: %v", err)
|
||||
return "", nil, ErrServiceUnavailable
|
||||
}
|
||||
} else {
|
||||
log.Printf("[Auth] Database error creating oauth user: %v", err)
|
||||
return "", nil, ErrServiceUnavailable
|
||||
}
|
||||
} else {
|
||||
user = newUser
|
||||
}
|
||||
} else {
|
||||
log.Printf("[Auth] Database error during oauth login: %v", err)
|
||||
return "", nil, ErrServiceUnavailable
|
||||
}
|
||||
}
|
||||
|
||||
if !user.IsActive() {
|
||||
return "", nil, ErrUserNotActive
|
||||
}
|
||||
|
||||
// 尽力补全:当用户名为空时,使用第三方返回的用户名回填。
|
||||
if user.Username == "" && username != "" {
|
||||
user.Username = username
|
||||
if err := s.userRepo.Update(ctx, user); err != nil {
|
||||
log.Printf("[Auth] Failed to update username after oauth login: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
token, err := s.GenerateToken(user)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("generate token: %w", err)
|
||||
}
|
||||
return token, user, nil
|
||||
}
|
||||
|
||||
// ValidateToken 验证JWT token并返回用户声明
|
||||
func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) {
|
||||
// 先做长度校验,尽早拒绝异常超长 token,降低 DoS 风险。
|
||||
@@ -336,6 +458,11 @@ func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) {
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, jwt.ErrTokenExpired) {
|
||||
// token 过期但仍返回 claims(用于 RefreshToken 等场景)
|
||||
// jwt-go 在解析时即使遇到过期错误,token.Claims 仍会被填充
|
||||
if claims, ok := token.Claims.(*JWTClaims); ok {
|
||||
return claims, ErrTokenExpired
|
||||
}
|
||||
return nil, ErrTokenExpired
|
||||
}
|
||||
return nil, ErrInvalidToken
|
||||
@@ -348,6 +475,22 @@ func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) {
|
||||
return nil, ErrInvalidToken
|
||||
}
|
||||
|
||||
func randomHexString(byteLength int) (string, error) {
|
||||
if byteLength <= 0 {
|
||||
byteLength = 16
|
||||
}
|
||||
buf := make([]byte, byteLength)
|
||||
if _, err := rand.Read(buf); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(buf), nil
|
||||
}
|
||||
|
||||
func isReservedEmail(email string) bool {
|
||||
normalized := strings.ToLower(strings.TrimSpace(email))
|
||||
return strings.HasSuffix(normalized, LinuxDoConnectSyntheticEmailDomain)
|
||||
}
|
||||
|
||||
// GenerateToken 生成JWT token
|
||||
func (s *AuthService) GenerateToken(user *User) (string, error) {
|
||||
now := time.Now()
|
||||
|
||||
@@ -113,13 +113,36 @@ func TestAuthService_Register_Disabled(t *testing.T) {
|
||||
require.ErrorIs(t, err, ErrRegDisabled)
|
||||
}
|
||||
|
||||
func TestAuthService_Register_EmailVerifyRequired(t *testing.T) {
|
||||
func TestAuthService_Register_DisabledByDefault(t *testing.T) {
|
||||
// 当 settings 为 nil(设置项不存在)时,注册应该默认关闭
|
||||
repo := &userRepoStub{}
|
||||
service := newAuthService(repo, nil, nil)
|
||||
|
||||
_, _, err := service.Register(context.Background(), "user@test.com", "password")
|
||||
require.ErrorIs(t, err, ErrRegDisabled)
|
||||
}
|
||||
|
||||
func TestAuthService_Register_EmailVerifyEnabledButServiceNotConfigured(t *testing.T) {
|
||||
repo := &userRepoStub{}
|
||||
// 邮件验证开启但 emailCache 为 nil(emailService 未配置)
|
||||
service := newAuthService(repo, map[string]string{
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
SettingKeyEmailVerifyEnabled: "true",
|
||||
}, nil)
|
||||
|
||||
// 应返回服务不可用错误,而不是允许绕过验证
|
||||
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "any-code")
|
||||
require.ErrorIs(t, err, ErrServiceUnavailable)
|
||||
}
|
||||
|
||||
func TestAuthService_Register_EmailVerifyRequired(t *testing.T) {
|
||||
repo := &userRepoStub{}
|
||||
cache := &emailCacheStub{} // 配置 emailService
|
||||
service := newAuthService(repo, map[string]string{
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
SettingKeyEmailVerifyEnabled: "true",
|
||||
}, cache)
|
||||
|
||||
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "")
|
||||
require.ErrorIs(t, err, ErrEmailVerifyRequired)
|
||||
}
|
||||
@@ -141,7 +164,9 @@ func TestAuthService_Register_EmailVerifyInvalid(t *testing.T) {
|
||||
|
||||
func TestAuthService_Register_EmailExists(t *testing.T) {
|
||||
repo := &userRepoStub{exists: true}
|
||||
service := newAuthService(repo, nil, nil)
|
||||
service := newAuthService(repo, map[string]string{
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
}, nil)
|
||||
|
||||
_, _, err := service.Register(context.Background(), "user@test.com", "password")
|
||||
require.ErrorIs(t, err, ErrEmailExists)
|
||||
@@ -149,23 +174,50 @@ func TestAuthService_Register_EmailExists(t *testing.T) {
|
||||
|
||||
func TestAuthService_Register_CheckEmailError(t *testing.T) {
|
||||
repo := &userRepoStub{existsErr: errors.New("db down")}
|
||||
service := newAuthService(repo, nil, nil)
|
||||
service := newAuthService(repo, map[string]string{
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
}, nil)
|
||||
|
||||
_, _, err := service.Register(context.Background(), "user@test.com", "password")
|
||||
require.ErrorIs(t, err, ErrServiceUnavailable)
|
||||
}
|
||||
|
||||
func TestAuthService_Register_ReservedEmail(t *testing.T) {
|
||||
repo := &userRepoStub{}
|
||||
service := newAuthService(repo, map[string]string{
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
}, nil)
|
||||
|
||||
_, _, err := service.Register(context.Background(), "linuxdo-123@linuxdo-connect.invalid", "password")
|
||||
require.ErrorIs(t, err, ErrEmailReserved)
|
||||
}
|
||||
|
||||
func TestAuthService_Register_CreateError(t *testing.T) {
|
||||
repo := &userRepoStub{createErr: errors.New("create failed")}
|
||||
service := newAuthService(repo, nil, nil)
|
||||
service := newAuthService(repo, map[string]string{
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
}, nil)
|
||||
|
||||
_, _, err := service.Register(context.Background(), "user@test.com", "password")
|
||||
require.ErrorIs(t, err, ErrServiceUnavailable)
|
||||
}
|
||||
|
||||
func TestAuthService_Register_CreateEmailExistsRace(t *testing.T) {
|
||||
// 模拟竞态条件:ExistsByEmail 返回 false,但 Create 时因唯一约束失败
|
||||
repo := &userRepoStub{createErr: ErrEmailExists}
|
||||
service := newAuthService(repo, map[string]string{
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
}, nil)
|
||||
|
||||
_, _, err := service.Register(context.Background(), "user@test.com", "password")
|
||||
require.ErrorIs(t, err, ErrEmailExists)
|
||||
}
|
||||
|
||||
func TestAuthService_Register_Success(t *testing.T) {
|
||||
repo := &userRepoStub{nextID: 5}
|
||||
service := newAuthService(repo, nil, nil)
|
||||
service := newAuthService(repo, map[string]string{
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
}, nil)
|
||||
|
||||
token, user, err := service.Register(context.Background(), "user@test.com", "password")
|
||||
require.NoError(t, err)
|
||||
@@ -180,3 +232,63 @@ func TestAuthService_Register_Success(t *testing.T) {
|
||||
require.Len(t, repo.created, 1)
|
||||
require.True(t, user.CheckPassword("password"))
|
||||
}
|
||||
|
||||
func TestAuthService_ValidateToken_ExpiredReturnsClaimsWithError(t *testing.T) {
|
||||
repo := &userRepoStub{}
|
||||
service := newAuthService(repo, nil, nil)
|
||||
|
||||
// 创建用户并生成 token
|
||||
user := &User{
|
||||
ID: 1,
|
||||
Email: "test@test.com",
|
||||
Role: RoleUser,
|
||||
Status: StatusActive,
|
||||
TokenVersion: 1,
|
||||
}
|
||||
token, err := service.GenerateToken(user)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 验证有效 token
|
||||
claims, err := service.ValidateToken(token)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, claims)
|
||||
require.Equal(t, int64(1), claims.UserID)
|
||||
|
||||
// 模拟过期 token(通过创建一个过期很久的 token)
|
||||
service.cfg.JWT.ExpireHour = -1 // 设置为负数使 token 立即过期
|
||||
expiredToken, err := service.GenerateToken(user)
|
||||
require.NoError(t, err)
|
||||
service.cfg.JWT.ExpireHour = 1 // 恢复
|
||||
|
||||
// 验证过期 token 应返回 claims 和 ErrTokenExpired
|
||||
claims, err = service.ValidateToken(expiredToken)
|
||||
require.ErrorIs(t, err, ErrTokenExpired)
|
||||
require.NotNil(t, claims, "claims should not be nil when token is expired")
|
||||
require.Equal(t, int64(1), claims.UserID)
|
||||
require.Equal(t, "test@test.com", claims.Email)
|
||||
}
|
||||
|
||||
func TestAuthService_RefreshToken_ExpiredTokenNoPanic(t *testing.T) {
|
||||
user := &User{
|
||||
ID: 1,
|
||||
Email: "test@test.com",
|
||||
Role: RoleUser,
|
||||
Status: StatusActive,
|
||||
TokenVersion: 1,
|
||||
}
|
||||
repo := &userRepoStub{user: user}
|
||||
service := newAuthService(repo, nil, nil)
|
||||
|
||||
// 创建过期 token
|
||||
service.cfg.JWT.ExpireHour = -1
|
||||
expiredToken, err := service.GenerateToken(user)
|
||||
require.NoError(t, err)
|
||||
service.cfg.JWT.ExpireHour = 1
|
||||
|
||||
// RefreshToken 使用过期 token 不应 panic
|
||||
require.NotPanics(t, func() {
|
||||
newToken, err := service.RefreshToken(context.Background(), expiredToken)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, newToken)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -105,7 +105,17 @@ const (
|
||||
// Request identity patch (Claude -> Gemini systemInstruction injection)
|
||||
SettingKeyEnableIdentityPatch = "enable_identity_patch"
|
||||
SettingKeyIdentityPatchPrompt = "identity_patch_prompt"
|
||||
|
||||
// LinuxDo Connect OAuth 登录(终端用户 SSO)
|
||||
SettingKeyLinuxDoConnectEnabled = "linuxdo_connect_enabled"
|
||||
SettingKeyLinuxDoConnectClientID = "linuxdo_connect_client_id"
|
||||
SettingKeyLinuxDoConnectClientSecret = "linuxdo_connect_client_secret"
|
||||
SettingKeyLinuxDoConnectRedirectURL = "linuxdo_connect_redirect_url"
|
||||
)
|
||||
|
||||
// LinuxDoConnectSyntheticEmailDomain 是 LinuxDo Connect 用户的合成邮箱后缀(RFC 保留域名)。
|
||||
// 目的:避免第三方登录返回的用户标识与本地真实邮箱发生碰撞,进而造成账号被接管的风险。
|
||||
const LinuxDoConnectSyntheticEmailDomain = "@linuxdo-connect.invalid"
|
||||
|
||||
// AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys).
|
||||
const AdminAPIKeyPrefix = "admin-"
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"log"
|
||||
"math/big"
|
||||
"net/smtp"
|
||||
"strconv"
|
||||
@@ -256,7 +257,9 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error
|
||||
// 验证码不匹配
|
||||
if data.Code != code {
|
||||
data.Attempts++
|
||||
_ = s.cache.SetVerificationCode(ctx, email, data, verifyCodeTTL)
|
||||
if err := s.cache.SetVerificationCode(ctx, email, data, verifyCodeTTL); err != nil {
|
||||
log.Printf("[Email] Failed to update verification attempt count: %v", err)
|
||||
}
|
||||
if data.Attempts >= maxVerifyCodeAttempts {
|
||||
return ErrVerifyCodeMaxAttempts
|
||||
}
|
||||
@@ -264,7 +267,9 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error
|
||||
}
|
||||
|
||||
// 验证成功,删除验证码
|
||||
_ = s.cache.DeleteVerificationCode(ctx, email)
|
||||
if err := s.cache.DeleteVerificationCode(ctx, email); err != nil {
|
||||
log.Printf("[Email] Failed to delete verification code after success: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -166,7 +166,7 @@ func (m *mockGroupRepoForGemini) DeleteCascade(ctx context.Context, id int64) ([
|
||||
func (m *mockGroupRepoForGemini) List(ctx context.Context, params pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (m *mockGroupRepoForGemini) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error) {
|
||||
func (m *mockGroupRepoForGemini) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (m *mockGroupRepoForGemini) ListActive(ctx context.Context) ([]Group, error) { return nil, nil }
|
||||
|
||||
@@ -120,15 +120,16 @@ func (s *GeminiOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64
|
||||
}
|
||||
|
||||
// OAuth client selection:
|
||||
// - code_assist: always use built-in Gemini CLI OAuth client (public), regardless of configured client_id/secret.
|
||||
// - google_one: uses configured OAuth client when provided; otherwise falls back to built-in client.
|
||||
// - ai_studio: requires a user-provided OAuth client.
|
||||
// - code_assist: always use built-in Gemini CLI OAuth client (public)
|
||||
// - google_one: always use built-in Gemini CLI OAuth client (public)
|
||||
// - ai_studio: requires a user-provided OAuth client
|
||||
oauthCfg := geminicli.OAuthConfig{
|
||||
ClientID: s.cfg.Gemini.OAuth.ClientID,
|
||||
ClientSecret: s.cfg.Gemini.OAuth.ClientSecret,
|
||||
Scopes: s.cfg.Gemini.OAuth.Scopes,
|
||||
}
|
||||
if oauthType == "code_assist" {
|
||||
if oauthType == "code_assist" || oauthType == "google_one" {
|
||||
// Force use of built-in Gemini CLI OAuth client
|
||||
oauthCfg.ClientID = ""
|
||||
oauthCfg.ClientSecret = ""
|
||||
}
|
||||
@@ -576,6 +577,20 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
|
||||
|
||||
case "google_one":
|
||||
log.Printf("[GeminiOAuth] Processing google_one OAuth type")
|
||||
|
||||
// Google One accounts use cloudaicompanion API, which requires a project_id.
|
||||
// For personal accounts, Google auto-assigns a project_id via the LoadCodeAssist API.
|
||||
if projectID == "" {
|
||||
log.Printf("[GeminiOAuth] No project_id provided, attempting to fetch from LoadCodeAssist API...")
|
||||
var err error
|
||||
projectID, _, err = s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL)
|
||||
if err != nil {
|
||||
log.Printf("[GeminiOAuth] ERROR: Failed to fetch project_id: %v", err)
|
||||
return nil, fmt.Errorf("google One accounts require a project_id, failed to auto-detect: %w", err)
|
||||
}
|
||||
log.Printf("[GeminiOAuth] Successfully fetched project_id: %s", projectID)
|
||||
}
|
||||
|
||||
log.Printf("[GeminiOAuth] Attempting to fetch Google One tier from Drive API...")
|
||||
// Attempt to fetch Drive storage tier
|
||||
var storageInfo *geminicli.DriveStorageInfo
|
||||
|
||||
@@ -40,7 +40,7 @@ func TestGeminiOAuthService_GenerateAuthURL_RedirectURIStrategy(t *testing.T) {
|
||||
wantProjectID: "",
|
||||
},
|
||||
{
|
||||
name: "google_one uses custom client when configured and redirects to localhost",
|
||||
name: "google_one always forces built-in client even when custom client configured",
|
||||
cfg: &config.Config{
|
||||
Gemini: config.GeminiConfig{
|
||||
OAuth: config.GeminiOAuthConfig{
|
||||
@@ -50,9 +50,9 @@ func TestGeminiOAuthService_GenerateAuthURL_RedirectURIStrategy(t *testing.T) {
|
||||
},
|
||||
},
|
||||
oauthType: "google_one",
|
||||
wantClientID: "custom-client-id",
|
||||
wantRedirect: geminicli.AIStudioOAuthRedirectURI,
|
||||
wantScope: geminicli.DefaultGoogleOneScopes,
|
||||
wantClientID: geminicli.GeminiCLIOAuthClientID,
|
||||
wantRedirect: geminicli.GeminiCLIRedirectURI,
|
||||
wantScope: geminicli.DefaultCodeAssistScopes,
|
||||
wantProjectID: "",
|
||||
},
|
||||
{
|
||||
|
||||
@@ -21,7 +21,7 @@ type GroupRepository interface {
|
||||
DeleteCascade(ctx context.Context, id int64) ([]int64, error)
|
||||
|
||||
List(ctx context.Context, params pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error)
|
||||
ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error)
|
||||
ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error)
|
||||
ListActive(ctx context.Context) ([]Group, error)
|
||||
ListActiveByPlatform(ctx context.Context, platform string) ([]Group, error)
|
||||
|
||||
|
||||
@@ -540,10 +540,19 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
bodyModified = true
|
||||
}
|
||||
|
||||
// For OAuth accounts using ChatGPT internal API, add store: false
|
||||
// For OAuth accounts using ChatGPT internal API:
|
||||
// 1. Add store: false
|
||||
// 2. Normalize input format for Codex API compatibility
|
||||
if account.Type == AccountTypeOAuth {
|
||||
reqBody["store"] = false
|
||||
bodyModified = true
|
||||
|
||||
// Normalize input format: convert AI SDK multi-part content format to simplified format
|
||||
// AI SDK sends: {"content": [{"type": "input_text", "text": "..."}]}
|
||||
// Codex API expects: {"content": "..."}
|
||||
if normalizeInputForCodexAPI(reqBody) {
|
||||
bodyModified = true
|
||||
}
|
||||
}
|
||||
|
||||
// Re-serialize body only if modified
|
||||
@@ -1085,6 +1094,101 @@ func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel
|
||||
return newBody
|
||||
}
|
||||
|
||||
// normalizeInputForCodexAPI converts AI SDK multi-part content format to simplified format
|
||||
// that the ChatGPT internal Codex API expects.
|
||||
//
|
||||
// AI SDK sends content as an array of typed objects:
|
||||
//
|
||||
// {"content": [{"type": "input_text", "text": "hello"}]}
|
||||
//
|
||||
// ChatGPT Codex API expects content as a simple string:
|
||||
//
|
||||
// {"content": "hello"}
|
||||
//
|
||||
// This function modifies reqBody in-place and returns true if any modification was made.
|
||||
func normalizeInputForCodexAPI(reqBody map[string]any) bool {
|
||||
input, ok := reqBody["input"]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
// Handle case where input is a simple string (already compatible)
|
||||
if _, isString := input.(string); isString {
|
||||
return false
|
||||
}
|
||||
|
||||
// Handle case where input is an array of messages
|
||||
inputArray, ok := input.([]any)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
modified := false
|
||||
for _, item := range inputArray {
|
||||
message, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
content, ok := message["content"]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
// If content is already a string, no conversion needed
|
||||
if _, isString := content.(string); isString {
|
||||
continue
|
||||
}
|
||||
|
||||
// If content is an array (AI SDK format), convert to string
|
||||
contentArray, ok := content.([]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
// Extract text from content array
|
||||
var textParts []string
|
||||
for _, part := range contentArray {
|
||||
partMap, ok := part.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
// Handle different content types
|
||||
partType, _ := partMap["type"].(string)
|
||||
switch partType {
|
||||
case "input_text", "text":
|
||||
// Extract text from input_text or text type
|
||||
if text, ok := partMap["text"].(string); ok {
|
||||
textParts = append(textParts, text)
|
||||
}
|
||||
case "input_image", "image":
|
||||
// For images, we need to preserve the original format
|
||||
// as ChatGPT Codex API may support images in a different way
|
||||
// For now, skip image parts (they will be lost in conversion)
|
||||
// TODO: Consider preserving image data or handling it separately
|
||||
continue
|
||||
case "input_file", "file":
|
||||
// Similar to images, file inputs may need special handling
|
||||
continue
|
||||
default:
|
||||
// For unknown types, try to extract text if available
|
||||
if text, ok := partMap["text"].(string); ok {
|
||||
textParts = append(textParts, text)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Convert content array to string
|
||||
if len(textParts) > 0 {
|
||||
message["content"] = strings.Join(textParts, "\n")
|
||||
modified = true
|
||||
}
|
||||
}
|
||||
|
||||
return modified
|
||||
}
|
||||
|
||||
// OpenAIRecordUsageInput input for recording usage
|
||||
type OpenAIRecordUsageInput struct {
|
||||
Result *OpenAIForwardResult
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
@@ -64,6 +65,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
|
||||
SettingKeyAPIBaseURL,
|
||||
SettingKeyContactInfo,
|
||||
SettingKeyDocURL,
|
||||
SettingKeyLinuxDoConnectEnabled,
|
||||
}
|
||||
|
||||
settings, err := s.settingRepo.GetMultiple(ctx, keys)
|
||||
@@ -71,6 +73,13 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
|
||||
return nil, fmt.Errorf("get public settings: %w", err)
|
||||
}
|
||||
|
||||
linuxDoEnabled := false
|
||||
if raw, ok := settings[SettingKeyLinuxDoConnectEnabled]; ok {
|
||||
linuxDoEnabled = raw == "true"
|
||||
} else {
|
||||
linuxDoEnabled = s.cfg != nil && s.cfg.LinuxDo.Enabled
|
||||
}
|
||||
|
||||
return &PublicSettings{
|
||||
RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true",
|
||||
EmailVerifyEnabled: settings[SettingKeyEmailVerifyEnabled] == "true",
|
||||
@@ -82,6 +91,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
|
||||
APIBaseURL: settings[SettingKeyAPIBaseURL],
|
||||
ContactInfo: settings[SettingKeyContactInfo],
|
||||
DocURL: settings[SettingKeyDocURL],
|
||||
LinuxDoOAuthEnabled: linuxDoEnabled,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -111,6 +121,14 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
|
||||
updates[SettingKeyTurnstileSecretKey] = settings.TurnstileSecretKey
|
||||
}
|
||||
|
||||
// LinuxDo Connect OAuth 登录(终端用户 SSO)
|
||||
updates[SettingKeyLinuxDoConnectEnabled] = strconv.FormatBool(settings.LinuxDoConnectEnabled)
|
||||
updates[SettingKeyLinuxDoConnectClientID] = settings.LinuxDoConnectClientID
|
||||
updates[SettingKeyLinuxDoConnectRedirectURL] = settings.LinuxDoConnectRedirectURL
|
||||
if settings.LinuxDoConnectClientSecret != "" {
|
||||
updates[SettingKeyLinuxDoConnectClientSecret] = settings.LinuxDoConnectClientSecret
|
||||
}
|
||||
|
||||
// OEM设置
|
||||
updates[SettingKeySiteName] = settings.SiteName
|
||||
updates[SettingKeySiteLogo] = settings.SiteLogo
|
||||
@@ -141,8 +159,8 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
|
||||
func (s *SettingService) IsRegistrationEnabled(ctx context.Context) bool {
|
||||
value, err := s.settingRepo.GetValue(ctx, SettingKeyRegistrationEnabled)
|
||||
if err != nil {
|
||||
// 默认开放注册
|
||||
return true
|
||||
// 安全默认:如果设置不存在或查询出错,默认关闭注册
|
||||
return false
|
||||
}
|
||||
return value == "true"
|
||||
}
|
||||
@@ -271,6 +289,38 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
|
||||
result.SMTPPassword = settings[SettingKeySMTPPassword]
|
||||
result.TurnstileSecretKey = settings[SettingKeyTurnstileSecretKey]
|
||||
|
||||
// LinuxDo Connect 设置:
|
||||
// - 兼容 config.yaml/env(避免老部署因为未迁移到数据库设置而被意外关闭)
|
||||
// - 支持在后台“系统设置”中覆盖并持久化(存储于 DB)
|
||||
linuxDoBase := config.LinuxDoConnectConfig{}
|
||||
if s.cfg != nil {
|
||||
linuxDoBase = s.cfg.LinuxDo
|
||||
}
|
||||
|
||||
if raw, ok := settings[SettingKeyLinuxDoConnectEnabled]; ok {
|
||||
result.LinuxDoConnectEnabled = raw == "true"
|
||||
} else {
|
||||
result.LinuxDoConnectEnabled = linuxDoBase.Enabled
|
||||
}
|
||||
|
||||
if v, ok := settings[SettingKeyLinuxDoConnectClientID]; ok && strings.TrimSpace(v) != "" {
|
||||
result.LinuxDoConnectClientID = strings.TrimSpace(v)
|
||||
} else {
|
||||
result.LinuxDoConnectClientID = linuxDoBase.ClientID
|
||||
}
|
||||
|
||||
if v, ok := settings[SettingKeyLinuxDoConnectRedirectURL]; ok && strings.TrimSpace(v) != "" {
|
||||
result.LinuxDoConnectRedirectURL = strings.TrimSpace(v)
|
||||
} else {
|
||||
result.LinuxDoConnectRedirectURL = linuxDoBase.RedirectURL
|
||||
}
|
||||
|
||||
result.LinuxDoConnectClientSecret = strings.TrimSpace(settings[SettingKeyLinuxDoConnectClientSecret])
|
||||
if result.LinuxDoConnectClientSecret == "" {
|
||||
result.LinuxDoConnectClientSecret = strings.TrimSpace(linuxDoBase.ClientSecret)
|
||||
}
|
||||
result.LinuxDoConnectClientSecretConfigured = result.LinuxDoConnectClientSecret != ""
|
||||
|
||||
// Model fallback settings
|
||||
result.EnableModelFallback = settings[SettingKeyEnableModelFallback] == "true"
|
||||
result.FallbackModelAnthropic = s.getStringOrDefault(settings, SettingKeyFallbackModelAnthropic, "claude-3-5-sonnet-20241022")
|
||||
@@ -289,6 +339,99 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
|
||||
return result
|
||||
}
|
||||
|
||||
// GetLinuxDoConnectOAuthConfig 返回用于登录的“最终生效” LinuxDo Connect 配置。
|
||||
//
|
||||
// 优先级:
|
||||
// - 若对应系统设置键存在,则覆盖 config.yaml/env 的值
|
||||
// - 否则回退到 config.yaml/env 的值
|
||||
func (s *SettingService) GetLinuxDoConnectOAuthConfig(ctx context.Context) (config.LinuxDoConnectConfig, error) {
|
||||
if s == nil || s.cfg == nil {
|
||||
return config.LinuxDoConnectConfig{}, infraerrors.ServiceUnavailable("CONFIG_NOT_READY", "config not loaded")
|
||||
}
|
||||
|
||||
effective := s.cfg.LinuxDo
|
||||
|
||||
keys := []string{
|
||||
SettingKeyLinuxDoConnectEnabled,
|
||||
SettingKeyLinuxDoConnectClientID,
|
||||
SettingKeyLinuxDoConnectClientSecret,
|
||||
SettingKeyLinuxDoConnectRedirectURL,
|
||||
}
|
||||
settings, err := s.settingRepo.GetMultiple(ctx, keys)
|
||||
if err != nil {
|
||||
return config.LinuxDoConnectConfig{}, fmt.Errorf("get linuxdo connect settings: %w", err)
|
||||
}
|
||||
|
||||
if raw, ok := settings[SettingKeyLinuxDoConnectEnabled]; ok {
|
||||
effective.Enabled = raw == "true"
|
||||
}
|
||||
if v, ok := settings[SettingKeyLinuxDoConnectClientID]; ok && strings.TrimSpace(v) != "" {
|
||||
effective.ClientID = strings.TrimSpace(v)
|
||||
}
|
||||
if v, ok := settings[SettingKeyLinuxDoConnectClientSecret]; ok && strings.TrimSpace(v) != "" {
|
||||
effective.ClientSecret = strings.TrimSpace(v)
|
||||
}
|
||||
if v, ok := settings[SettingKeyLinuxDoConnectRedirectURL]; ok && strings.TrimSpace(v) != "" {
|
||||
effective.RedirectURL = strings.TrimSpace(v)
|
||||
}
|
||||
|
||||
if !effective.Enabled {
|
||||
return config.LinuxDoConnectConfig{}, infraerrors.NotFound("OAUTH_DISABLED", "oauth login is disabled")
|
||||
}
|
||||
|
||||
// 基础健壮性校验(避免把用户重定向到一个必然失败或不安全的 OAuth 流程里)。
|
||||
if strings.TrimSpace(effective.ClientID) == "" {
|
||||
return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth client id not configured")
|
||||
}
|
||||
if strings.TrimSpace(effective.AuthorizeURL) == "" {
|
||||
return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth authorize url not configured")
|
||||
}
|
||||
if strings.TrimSpace(effective.TokenURL) == "" {
|
||||
return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth token url not configured")
|
||||
}
|
||||
if strings.TrimSpace(effective.UserInfoURL) == "" {
|
||||
return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth userinfo url not configured")
|
||||
}
|
||||
if strings.TrimSpace(effective.RedirectURL) == "" {
|
||||
return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth redirect url not configured")
|
||||
}
|
||||
if strings.TrimSpace(effective.FrontendRedirectURL) == "" {
|
||||
return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth frontend redirect url not configured")
|
||||
}
|
||||
|
||||
if err := config.ValidateAbsoluteHTTPURL(effective.AuthorizeURL); err != nil {
|
||||
return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth authorize url invalid")
|
||||
}
|
||||
if err := config.ValidateAbsoluteHTTPURL(effective.TokenURL); err != nil {
|
||||
return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth token url invalid")
|
||||
}
|
||||
if err := config.ValidateAbsoluteHTTPURL(effective.UserInfoURL); err != nil {
|
||||
return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth userinfo url invalid")
|
||||
}
|
||||
if err := config.ValidateAbsoluteHTTPURL(effective.RedirectURL); err != nil {
|
||||
return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth redirect url invalid")
|
||||
}
|
||||
if err := config.ValidateFrontendRedirectURL(effective.FrontendRedirectURL); err != nil {
|
||||
return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth frontend redirect url invalid")
|
||||
}
|
||||
|
||||
method := strings.ToLower(strings.TrimSpace(effective.TokenAuthMethod))
|
||||
switch method {
|
||||
case "", "client_secret_post", "client_secret_basic":
|
||||
if strings.TrimSpace(effective.ClientSecret) == "" {
|
||||
return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth client secret not configured")
|
||||
}
|
||||
case "none":
|
||||
if !effective.UsePKCE {
|
||||
return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth pkce must be enabled when token_auth_method=none")
|
||||
}
|
||||
default:
|
||||
return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth token_auth_method invalid")
|
||||
}
|
||||
|
||||
return effective, nil
|
||||
}
|
||||
|
||||
// getStringOrDefault 获取字符串值或默认值
|
||||
func (s *SettingService) getStringOrDefault(settings map[string]string, key, defaultValue string) string {
|
||||
if value, ok := settings[key]; ok && value != "" {
|
||||
|
||||
@@ -18,6 +18,13 @@ type SystemSettings struct {
|
||||
TurnstileSecretKey string
|
||||
TurnstileSecretKeyConfigured bool
|
||||
|
||||
// LinuxDo Connect OAuth 登录(终端用户 SSO)
|
||||
LinuxDoConnectEnabled bool
|
||||
LinuxDoConnectClientID string
|
||||
LinuxDoConnectClientSecret string
|
||||
LinuxDoConnectClientSecretConfigured bool
|
||||
LinuxDoConnectRedirectURL string
|
||||
|
||||
SiteName string
|
||||
SiteLogo string
|
||||
SiteSubtitle string
|
||||
@@ -51,5 +58,6 @@ type PublicSettings struct {
|
||||
APIBaseURL string
|
||||
ContactInfo string
|
||||
DocURL string
|
||||
LinuxDoOAuthEnabled bool
|
||||
Version string
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user