merge upstream main into fix/bug-cleanup-main

This commit is contained in:
IanShaw027
2026-04-09 21:35:48 +08:00
60 changed files with 6146 additions and 949 deletions

View File

@@ -65,6 +65,7 @@ type Config struct {
JWT JWTConfig `mapstructure:"jwt"`
Totp TotpConfig `mapstructure:"totp"`
LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"`
OIDC OIDCConnectConfig `mapstructure:"oidc_connect"`
Default DefaultConfig `mapstructure:"default"`
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
Pricing PricingConfig `mapstructure:"pricing"`
@@ -184,6 +185,34 @@ type LinuxDoConnectConfig struct {
UserInfoUsernamePath string `mapstructure:"userinfo_username_path"`
}
type OIDCConnectConfig struct {
Enabled bool `mapstructure:"enabled"`
ProviderName string `mapstructure:"provider_name"` // 显示名: "Keycloak" 等
ClientID string `mapstructure:"client_id"`
ClientSecret string `mapstructure:"client_secret"`
IssuerURL string `mapstructure:"issuer_url"`
DiscoveryURL string `mapstructure:"discovery_url"`
AuthorizeURL string `mapstructure:"authorize_url"`
TokenURL string `mapstructure:"token_url"`
UserInfoURL string `mapstructure:"userinfo_url"`
JWKSURL string `mapstructure:"jwks_url"`
Scopes string `mapstructure:"scopes"` // 默认 "openid email profile"
RedirectURL string `mapstructure:"redirect_url"` // 后端回调地址(需在提供方后台登记)
FrontendRedirectURL string `mapstructure:"frontend_redirect_url"` // 前端接收 token 的路由(默认:/auth/oidc/callback
TokenAuthMethod string `mapstructure:"token_auth_method"` // client_secret_post / client_secret_basic / none
UsePKCE bool `mapstructure:"use_pkce"`
ValidateIDToken bool `mapstructure:"validate_id_token"`
AllowedSigningAlgs string `mapstructure:"allowed_signing_algs"` // 默认 "RS256,ES256,PS256"
ClockSkewSeconds int `mapstructure:"clock_skew_seconds"` // 默认 120
RequireEmailVerified bool `mapstructure:"require_email_verified"` // 默认 false
// 可选:用于从 userinfo JSON 中提取字段的 gjson 路径。
// 为空时,服务端会尝试一组常见字段名。
UserInfoEmailPath string `mapstructure:"userinfo_email_path"`
UserInfoIDPath string `mapstructure:"userinfo_id_path"`
UserInfoUsernamePath string `mapstructure:"userinfo_username_path"`
}
// TokenRefreshConfig OAuth token自动刷新配置
type TokenRefreshConfig struct {
// 是否启用自动刷新
@@ -318,6 +347,12 @@ type GatewayConfig struct {
// ForceCodexCLI: 强制将 OpenAI `/v1/responses` 请求按 Codex CLI 处理。
// 用于网关未透传/改写 User-Agent 时的兼容兜底(默认关闭,避免影响其他客户端)。
ForceCodexCLI bool `mapstructure:"force_codex_cli"`
// ForcedCodexInstructionsTemplateFile: 服务端强制附加到 Codex 顶层 instructions 的模板文件路径。
// 模板渲染后会直接覆盖最终 instructions若需要保留客户端 system 转换结果,请在模板中显式引用 {{ .ExistingInstructions }}。
ForcedCodexInstructionsTemplateFile string `mapstructure:"forced_codex_instructions_template_file"`
// ForcedCodexInstructionsTemplate: 启动时从模板文件读取并缓存的模板内容。
// 该字段不直接参与配置反序列化,仅用于请求热路径避免重复读盘。
ForcedCodexInstructionsTemplate string `mapstructure:"-"`
// OpenAIPassthroughAllowTimeoutHeaders: OpenAI 透传模式是否放行客户端超时头
// 关闭(默认)可避免 x-stainless-timeout 等头导致上游提前断流。
OpenAIPassthroughAllowTimeoutHeaders bool `mapstructure:"openai_passthrough_allow_timeout_headers"`
@@ -972,6 +1007,23 @@ func load(allowMissingJWTSecret bool) (*Config, error) {
cfg.LinuxDo.UserInfoEmailPath = strings.TrimSpace(cfg.LinuxDo.UserInfoEmailPath)
cfg.LinuxDo.UserInfoIDPath = strings.TrimSpace(cfg.LinuxDo.UserInfoIDPath)
cfg.LinuxDo.UserInfoUsernamePath = strings.TrimSpace(cfg.LinuxDo.UserInfoUsernamePath)
cfg.OIDC.ProviderName = strings.TrimSpace(cfg.OIDC.ProviderName)
cfg.OIDC.ClientID = strings.TrimSpace(cfg.OIDC.ClientID)
cfg.OIDC.ClientSecret = strings.TrimSpace(cfg.OIDC.ClientSecret)
cfg.OIDC.IssuerURL = strings.TrimSpace(cfg.OIDC.IssuerURL)
cfg.OIDC.DiscoveryURL = strings.TrimSpace(cfg.OIDC.DiscoveryURL)
cfg.OIDC.AuthorizeURL = strings.TrimSpace(cfg.OIDC.AuthorizeURL)
cfg.OIDC.TokenURL = strings.TrimSpace(cfg.OIDC.TokenURL)
cfg.OIDC.UserInfoURL = strings.TrimSpace(cfg.OIDC.UserInfoURL)
cfg.OIDC.JWKSURL = strings.TrimSpace(cfg.OIDC.JWKSURL)
cfg.OIDC.Scopes = strings.TrimSpace(cfg.OIDC.Scopes)
cfg.OIDC.RedirectURL = strings.TrimSpace(cfg.OIDC.RedirectURL)
cfg.OIDC.FrontendRedirectURL = strings.TrimSpace(cfg.OIDC.FrontendRedirectURL)
cfg.OIDC.TokenAuthMethod = strings.ToLower(strings.TrimSpace(cfg.OIDC.TokenAuthMethod))
cfg.OIDC.AllowedSigningAlgs = strings.TrimSpace(cfg.OIDC.AllowedSigningAlgs)
cfg.OIDC.UserInfoEmailPath = strings.TrimSpace(cfg.OIDC.UserInfoEmailPath)
cfg.OIDC.UserInfoIDPath = strings.TrimSpace(cfg.OIDC.UserInfoIDPath)
cfg.OIDC.UserInfoUsernamePath = strings.TrimSpace(cfg.OIDC.UserInfoUsernamePath)
cfg.Dashboard.KeyPrefix = strings.TrimSpace(cfg.Dashboard.KeyPrefix)
cfg.CORS.AllowedOrigins = normalizeStringSlice(cfg.CORS.AllowedOrigins)
cfg.Security.ResponseHeaders.AdditionalAllowed = normalizeStringSlice(cfg.Security.ResponseHeaders.AdditionalAllowed)
@@ -983,6 +1035,14 @@ func load(allowMissingJWTSecret bool) (*Config, error) {
cfg.Log.Environment = strings.TrimSpace(cfg.Log.Environment)
cfg.Log.StacktraceLevel = strings.ToLower(strings.TrimSpace(cfg.Log.StacktraceLevel))
cfg.Log.Output.FilePath = strings.TrimSpace(cfg.Log.Output.FilePath)
cfg.Gateway.ForcedCodexInstructionsTemplateFile = strings.TrimSpace(cfg.Gateway.ForcedCodexInstructionsTemplateFile)
if cfg.Gateway.ForcedCodexInstructionsTemplateFile != "" {
content, err := os.ReadFile(cfg.Gateway.ForcedCodexInstructionsTemplateFile)
if err != nil {
return nil, fmt.Errorf("read forced codex instructions template %q: %w", cfg.Gateway.ForcedCodexInstructionsTemplateFile, err)
}
cfg.Gateway.ForcedCodexInstructionsTemplate = string(content)
}
// 兼容旧键 gateway.openai_ws.sticky_previous_response_ttl_seconds。
// 新键未配置(<=0时回退旧键新键优先。
@@ -1142,6 +1202,30 @@ func setDefaults() {
viper.SetDefault("linuxdo_connect.userinfo_id_path", "")
viper.SetDefault("linuxdo_connect.userinfo_username_path", "")
// Generic OIDC OAuth 登录
viper.SetDefault("oidc_connect.enabled", false)
viper.SetDefault("oidc_connect.provider_name", "OIDC")
viper.SetDefault("oidc_connect.client_id", "")
viper.SetDefault("oidc_connect.client_secret", "")
viper.SetDefault("oidc_connect.issuer_url", "")
viper.SetDefault("oidc_connect.discovery_url", "")
viper.SetDefault("oidc_connect.authorize_url", "")
viper.SetDefault("oidc_connect.token_url", "")
viper.SetDefault("oidc_connect.userinfo_url", "")
viper.SetDefault("oidc_connect.jwks_url", "")
viper.SetDefault("oidc_connect.scopes", "openid email profile")
viper.SetDefault("oidc_connect.redirect_url", "")
viper.SetDefault("oidc_connect.frontend_redirect_url", "/auth/oidc/callback")
viper.SetDefault("oidc_connect.token_auth_method", "client_secret_post")
viper.SetDefault("oidc_connect.use_pkce", false)
viper.SetDefault("oidc_connect.validate_id_token", true)
viper.SetDefault("oidc_connect.allowed_signing_algs", "RS256,ES256,PS256")
viper.SetDefault("oidc_connect.clock_skew_seconds", 120)
viper.SetDefault("oidc_connect.require_email_verified", false)
viper.SetDefault("oidc_connect.userinfo_email_path", "")
viper.SetDefault("oidc_connect.userinfo_id_path", "")
viper.SetDefault("oidc_connect.userinfo_username_path", "")
// Database
viper.SetDefault("database.host", "localhost")
viper.SetDefault("database.port", 5432)
@@ -1578,6 +1662,87 @@ func (c *Config) Validate() error {
warnIfInsecureURL("linuxdo_connect.redirect_url", c.LinuxDo.RedirectURL)
warnIfInsecureURL("linuxdo_connect.frontend_redirect_url", c.LinuxDo.FrontendRedirectURL)
}
if c.OIDC.Enabled {
if strings.TrimSpace(c.OIDC.ClientID) == "" {
return fmt.Errorf("oidc_connect.client_id is required when oidc_connect.enabled=true")
}
if strings.TrimSpace(c.OIDC.IssuerURL) == "" {
return fmt.Errorf("oidc_connect.issuer_url is required when oidc_connect.enabled=true")
}
if strings.TrimSpace(c.OIDC.RedirectURL) == "" {
return fmt.Errorf("oidc_connect.redirect_url is required when oidc_connect.enabled=true")
}
if strings.TrimSpace(c.OIDC.FrontendRedirectURL) == "" {
return fmt.Errorf("oidc_connect.frontend_redirect_url is required when oidc_connect.enabled=true")
}
if !scopeContainsOpenID(c.OIDC.Scopes) {
return fmt.Errorf("oidc_connect.scopes must contain openid")
}
method := strings.ToLower(strings.TrimSpace(c.OIDC.TokenAuthMethod))
switch method {
case "", "client_secret_post", "client_secret_basic", "none":
default:
return fmt.Errorf("oidc_connect.token_auth_method must be one of: client_secret_post/client_secret_basic/none")
}
if method == "none" && !c.OIDC.UsePKCE {
return fmt.Errorf("oidc_connect.use_pkce must be true when oidc_connect.token_auth_method=none")
}
if (method == "" || method == "client_secret_post" || method == "client_secret_basic") &&
strings.TrimSpace(c.OIDC.ClientSecret) == "" {
return fmt.Errorf("oidc_connect.client_secret is required when oidc_connect.enabled=true and token_auth_method is client_secret_post/client_secret_basic")
}
if c.OIDC.ClockSkewSeconds < 0 || c.OIDC.ClockSkewSeconds > 600 {
return fmt.Errorf("oidc_connect.clock_skew_seconds must be between 0 and 600")
}
if c.OIDC.ValidateIDToken && strings.TrimSpace(c.OIDC.AllowedSigningAlgs) == "" {
return fmt.Errorf("oidc_connect.allowed_signing_algs is required when oidc_connect.validate_id_token=true")
}
if err := ValidateAbsoluteHTTPURL(c.OIDC.IssuerURL); err != nil {
return fmt.Errorf("oidc_connect.issuer_url invalid: %w", err)
}
if v := strings.TrimSpace(c.OIDC.DiscoveryURL); v != "" {
if err := ValidateAbsoluteHTTPURL(v); err != nil {
return fmt.Errorf("oidc_connect.discovery_url invalid: %w", err)
}
}
if v := strings.TrimSpace(c.OIDC.AuthorizeURL); v != "" {
if err := ValidateAbsoluteHTTPURL(v); err != nil {
return fmt.Errorf("oidc_connect.authorize_url invalid: %w", err)
}
}
if v := strings.TrimSpace(c.OIDC.TokenURL); v != "" {
if err := ValidateAbsoluteHTTPURL(v); err != nil {
return fmt.Errorf("oidc_connect.token_url invalid: %w", err)
}
}
if v := strings.TrimSpace(c.OIDC.UserInfoURL); v != "" {
if err := ValidateAbsoluteHTTPURL(v); err != nil {
return fmt.Errorf("oidc_connect.userinfo_url invalid: %w", err)
}
}
if v := strings.TrimSpace(c.OIDC.JWKSURL); v != "" {
if err := ValidateAbsoluteHTTPURL(v); err != nil {
return fmt.Errorf("oidc_connect.jwks_url invalid: %w", err)
}
}
if err := ValidateAbsoluteHTTPURL(c.OIDC.RedirectURL); err != nil {
return fmt.Errorf("oidc_connect.redirect_url invalid: %w", err)
}
if err := ValidateFrontendRedirectURL(c.OIDC.FrontendRedirectURL); err != nil {
return fmt.Errorf("oidc_connect.frontend_redirect_url invalid: %w", err)
}
warnIfInsecureURL("oidc_connect.issuer_url", c.OIDC.IssuerURL)
warnIfInsecureURL("oidc_connect.discovery_url", c.OIDC.DiscoveryURL)
warnIfInsecureURL("oidc_connect.authorize_url", c.OIDC.AuthorizeURL)
warnIfInsecureURL("oidc_connect.token_url", c.OIDC.TokenURL)
warnIfInsecureURL("oidc_connect.userinfo_url", c.OIDC.UserInfoURL)
warnIfInsecureURL("oidc_connect.jwks_url", c.OIDC.JWKSURL)
warnIfInsecureURL("oidc_connect.redirect_url", c.OIDC.RedirectURL)
warnIfInsecureURL("oidc_connect.frontend_redirect_url", c.OIDC.FrontendRedirectURL)
}
if c.Billing.CircuitBreaker.Enabled {
if c.Billing.CircuitBreaker.FailureThreshold <= 0 {
return fmt.Errorf("billing.circuit_breaker.failure_threshold must be positive")
@@ -2196,6 +2361,15 @@ func ValidateFrontendRedirectURL(raw string) error {
return nil
}
func scopeContainsOpenID(scopes string) bool {
for _, scope := range strings.Fields(strings.ToLower(strings.TrimSpace(scopes))) {
if scope == "openid" {
return true
}
}
return false
}
// isHTTPScheme 检查是否为 HTTP 或 HTTPS 协议
func isHTTPScheme(scheme string) bool {
return strings.EqualFold(scheme, "http") || strings.EqualFold(scheme, "https")

View File

@@ -1,6 +1,8 @@
package config
import (
"os"
"path/filepath"
"strings"
"testing"
"time"
@@ -223,6 +225,23 @@ func TestLoadSchedulingConfigFromEnv(t *testing.T) {
}
}
func TestLoadForcedCodexInstructionsTemplate(t *testing.T) {
resetViperWithJWTSecret(t)
tempDir := t.TempDir()
templatePath := filepath.Join(tempDir, "codex-instructions.md.tmpl")
configPath := filepath.Join(tempDir, "config.yaml")
require.NoError(t, os.WriteFile(templatePath, []byte("server-prefix\n\n{{ .ExistingInstructions }}"), 0o644))
require.NoError(t, os.WriteFile(configPath, []byte("gateway:\n forced_codex_instructions_template_file: \""+templatePath+"\"\n"), 0o644))
t.Setenv("DATA_DIR", tempDir)
cfg, err := Load()
require.NoError(t, err)
require.Equal(t, templatePath, cfg.Gateway.ForcedCodexInstructionsTemplateFile)
require.Equal(t, "server-prefix\n\n{{ .ExistingInstructions }}", cfg.Gateway.ForcedCodexInstructionsTemplate)
}
func TestLoadDefaultSecurityToggles(t *testing.T) {
resetViperWithJWTSecret(t)
@@ -351,6 +370,60 @@ func TestValidateLinuxDoPKCERequiredForPublicClient(t *testing.T) {
}
}
func TestValidateOIDCScopesMustContainOpenID(t *testing.T) {
resetViperWithJWTSecret(t)
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
cfg.OIDC.Enabled = true
cfg.OIDC.ClientID = "oidc-client"
cfg.OIDC.ClientSecret = "oidc-secret"
cfg.OIDC.IssuerURL = "https://issuer.example.com"
cfg.OIDC.AuthorizeURL = "https://issuer.example.com/auth"
cfg.OIDC.TokenURL = "https://issuer.example.com/token"
cfg.OIDC.JWKSURL = "https://issuer.example.com/jwks"
cfg.OIDC.RedirectURL = "https://example.com/api/v1/auth/oauth/oidc/callback"
cfg.OIDC.FrontendRedirectURL = "/auth/oidc/callback"
cfg.OIDC.Scopes = "profile email"
err = cfg.Validate()
if err == nil {
t.Fatalf("Validate() expected error when scopes do not include openid, got nil")
}
if !strings.Contains(err.Error(), "oidc_connect.scopes") {
t.Fatalf("Validate() expected oidc_connect.scopes error, got: %v", err)
}
}
func TestValidateOIDCAllowsIssuerOnlyEndpointsWithDiscoveryFallback(t *testing.T) {
resetViperWithJWTSecret(t)
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
cfg.OIDC.Enabled = true
cfg.OIDC.ClientID = "oidc-client"
cfg.OIDC.ClientSecret = "oidc-secret"
cfg.OIDC.IssuerURL = "https://issuer.example.com"
cfg.OIDC.AuthorizeURL = ""
cfg.OIDC.TokenURL = ""
cfg.OIDC.JWKSURL = ""
cfg.OIDC.RedirectURL = "https://example.com/api/v1/auth/oauth/oidc/callback"
cfg.OIDC.FrontendRedirectURL = "/auth/oidc/callback"
cfg.OIDC.Scopes = "openid email profile"
cfg.OIDC.ValidateIDToken = true
err = cfg.Validate()
if err != nil {
t.Fatalf("Validate() expected issuer-only OIDC config to pass with discovery fallback, got: %v", err)
}
}
func TestLoadDefaultDashboardCacheConfig(t *testing.T) {
resetViperWithJWTSecret(t)

View File

@@ -0,0 +1,10 @@
package domain
// OpenAIMessagesDispatchModelConfig controls how Anthropic /v1/messages
// requests are mapped onto OpenAI/Codex models.
type OpenAIMessagesDispatchModelConfig struct {
OpusMappedModel string `json:"opus_mapped_model,omitempty"`
SonnetMappedModel string `json:"sonnet_mapped_model,omitempty"`
HaikuMappedModel string `json:"haiku_mapped_model,omitempty"`
ExactModelMappings map[string]string `json:"exact_model_mappings,omitempty"`
}

View File

@@ -105,10 +105,11 @@ type CreateGroupRequest struct {
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes []string `json:"supported_model_scopes"`
// OpenAI Messages 调度配置(仅 openai 平台使用)
AllowMessagesDispatch bool `json:"allow_messages_dispatch"`
RequireOAuthOnly bool `json:"require_oauth_only"`
RequirePrivacySet bool `json:"require_privacy_set"`
DefaultMappedModel string `json:"default_mapped_model"`
AllowMessagesDispatch bool `json:"allow_messages_dispatch"`
RequireOAuthOnly bool `json:"require_oauth_only"`
RequirePrivacySet bool `json:"require_privacy_set"`
DefaultMappedModel string `json:"default_mapped_model"`
MessagesDispatchModelConfig service.OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config"`
// 从指定分组复制账号(创建后自动绑定)
CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"`
}
@@ -139,10 +140,11 @@ type UpdateGroupRequest struct {
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes *[]string `json:"supported_model_scopes"`
// OpenAI Messages 调度配置(仅 openai 平台使用)
AllowMessagesDispatch *bool `json:"allow_messages_dispatch"`
RequireOAuthOnly *bool `json:"require_oauth_only"`
RequirePrivacySet *bool `json:"require_privacy_set"`
DefaultMappedModel *string `json:"default_mapped_model"`
AllowMessagesDispatch *bool `json:"allow_messages_dispatch"`
RequireOAuthOnly *bool `json:"require_oauth_only"`
RequirePrivacySet *bool `json:"require_privacy_set"`
DefaultMappedModel *string `json:"default_mapped_model"`
MessagesDispatchModelConfig *service.OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config"`
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"`
}
@@ -259,6 +261,7 @@ func (h *GroupHandler) Create(c *gin.Context) {
RequireOAuthOnly: req.RequireOAuthOnly,
RequirePrivacySet: req.RequirePrivacySet,
DefaultMappedModel: req.DefaultMappedModel,
MessagesDispatchModelConfig: req.MessagesDispatchModelConfig,
CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs,
})
if err != nil {
@@ -309,6 +312,7 @@ func (h *GroupHandler) Update(c *gin.Context) {
RequireOAuthOnly: req.RequireOAuthOnly,
RequirePrivacySet: req.RequirePrivacySet,
DefaultMappedModel: req.DefaultMappedModel,
MessagesDispatchModelConfig: req.MessagesDispatchModelConfig,
CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs,
})
if err != nil {

View File

@@ -35,6 +35,15 @@ func generateMenuItemID() (string, error) {
return hex.EncodeToString(b), nil
}
func scopesContainOpenID(scopes string) bool {
for _, scope := range strings.Fields(strings.ToLower(strings.TrimSpace(scopes))) {
if scope == "openid" {
return true
}
}
return false
}
// SettingHandler 系统设置处理器
type SettingHandler struct {
settingService *service.SettingService
@@ -96,6 +105,28 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
LinuxDoConnectClientID: settings.LinuxDoConnectClientID,
LinuxDoConnectClientSecretConfigured: settings.LinuxDoConnectClientSecretConfigured,
LinuxDoConnectRedirectURL: settings.LinuxDoConnectRedirectURL,
OIDCConnectEnabled: settings.OIDCConnectEnabled,
OIDCConnectProviderName: settings.OIDCConnectProviderName,
OIDCConnectClientID: settings.OIDCConnectClientID,
OIDCConnectClientSecretConfigured: settings.OIDCConnectClientSecretConfigured,
OIDCConnectIssuerURL: settings.OIDCConnectIssuerURL,
OIDCConnectDiscoveryURL: settings.OIDCConnectDiscoveryURL,
OIDCConnectAuthorizeURL: settings.OIDCConnectAuthorizeURL,
OIDCConnectTokenURL: settings.OIDCConnectTokenURL,
OIDCConnectUserInfoURL: settings.OIDCConnectUserInfoURL,
OIDCConnectJWKSURL: settings.OIDCConnectJWKSURL,
OIDCConnectScopes: settings.OIDCConnectScopes,
OIDCConnectRedirectURL: settings.OIDCConnectRedirectURL,
OIDCConnectFrontendRedirectURL: settings.OIDCConnectFrontendRedirectURL,
OIDCConnectTokenAuthMethod: settings.OIDCConnectTokenAuthMethod,
OIDCConnectUsePKCE: settings.OIDCConnectUsePKCE,
OIDCConnectValidateIDToken: settings.OIDCConnectValidateIDToken,
OIDCConnectAllowedSigningAlgs: settings.OIDCConnectAllowedSigningAlgs,
OIDCConnectClockSkewSeconds: settings.OIDCConnectClockSkewSeconds,
OIDCConnectRequireEmailVerified: settings.OIDCConnectRequireEmailVerified,
OIDCConnectUserInfoEmailPath: settings.OIDCConnectUserInfoEmailPath,
OIDCConnectUserInfoIDPath: settings.OIDCConnectUserInfoIDPath,
OIDCConnectUserInfoUsernamePath: settings.OIDCConnectUserInfoUsernamePath,
SiteName: settings.SiteName,
SiteLogo: settings.SiteLogo,
SiteSubtitle: settings.SiteSubtitle,
@@ -166,6 +197,30 @@ type UpdateSettingsRequest struct {
LinuxDoConnectClientSecret string `json:"linuxdo_connect_client_secret"`
LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"`
// Generic OIDC OAuth 登录
OIDCConnectEnabled bool `json:"oidc_connect_enabled"`
OIDCConnectProviderName string `json:"oidc_connect_provider_name"`
OIDCConnectClientID string `json:"oidc_connect_client_id"`
OIDCConnectClientSecret string `json:"oidc_connect_client_secret"`
OIDCConnectIssuerURL string `json:"oidc_connect_issuer_url"`
OIDCConnectDiscoveryURL string `json:"oidc_connect_discovery_url"`
OIDCConnectAuthorizeURL string `json:"oidc_connect_authorize_url"`
OIDCConnectTokenURL string `json:"oidc_connect_token_url"`
OIDCConnectUserInfoURL string `json:"oidc_connect_userinfo_url"`
OIDCConnectJWKSURL string `json:"oidc_connect_jwks_url"`
OIDCConnectScopes string `json:"oidc_connect_scopes"`
OIDCConnectRedirectURL string `json:"oidc_connect_redirect_url"`
OIDCConnectFrontendRedirectURL string `json:"oidc_connect_frontend_redirect_url"`
OIDCConnectTokenAuthMethod string `json:"oidc_connect_token_auth_method"`
OIDCConnectUsePKCE bool `json:"oidc_connect_use_pkce"`
OIDCConnectValidateIDToken bool `json:"oidc_connect_validate_id_token"`
OIDCConnectAllowedSigningAlgs string `json:"oidc_connect_allowed_signing_algs"`
OIDCConnectClockSkewSeconds int `json:"oidc_connect_clock_skew_seconds"`
OIDCConnectRequireEmailVerified bool `json:"oidc_connect_require_email_verified"`
OIDCConnectUserInfoEmailPath string `json:"oidc_connect_userinfo_email_path"`
OIDCConnectUserInfoIDPath string `json:"oidc_connect_userinfo_id_path"`
OIDCConnectUserInfoUsernamePath string `json:"oidc_connect_userinfo_username_path"`
// OEM设置
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
@@ -335,6 +390,122 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
}
// Generic OIDC 参数验证
if req.OIDCConnectEnabled {
req.OIDCConnectProviderName = strings.TrimSpace(req.OIDCConnectProviderName)
req.OIDCConnectClientID = strings.TrimSpace(req.OIDCConnectClientID)
req.OIDCConnectClientSecret = strings.TrimSpace(req.OIDCConnectClientSecret)
req.OIDCConnectIssuerURL = strings.TrimSpace(req.OIDCConnectIssuerURL)
req.OIDCConnectDiscoveryURL = strings.TrimSpace(req.OIDCConnectDiscoveryURL)
req.OIDCConnectAuthorizeURL = strings.TrimSpace(req.OIDCConnectAuthorizeURL)
req.OIDCConnectTokenURL = strings.TrimSpace(req.OIDCConnectTokenURL)
req.OIDCConnectUserInfoURL = strings.TrimSpace(req.OIDCConnectUserInfoURL)
req.OIDCConnectJWKSURL = strings.TrimSpace(req.OIDCConnectJWKSURL)
req.OIDCConnectScopes = strings.TrimSpace(req.OIDCConnectScopes)
req.OIDCConnectRedirectURL = strings.TrimSpace(req.OIDCConnectRedirectURL)
req.OIDCConnectFrontendRedirectURL = strings.TrimSpace(req.OIDCConnectFrontendRedirectURL)
req.OIDCConnectTokenAuthMethod = strings.ToLower(strings.TrimSpace(req.OIDCConnectTokenAuthMethod))
req.OIDCConnectAllowedSigningAlgs = strings.TrimSpace(req.OIDCConnectAllowedSigningAlgs)
req.OIDCConnectUserInfoEmailPath = strings.TrimSpace(req.OIDCConnectUserInfoEmailPath)
req.OIDCConnectUserInfoIDPath = strings.TrimSpace(req.OIDCConnectUserInfoIDPath)
req.OIDCConnectUserInfoUsernamePath = strings.TrimSpace(req.OIDCConnectUserInfoUsernamePath)
if req.OIDCConnectProviderName == "" {
req.OIDCConnectProviderName = "OIDC"
}
if req.OIDCConnectClientID == "" {
response.BadRequest(c, "OIDC Client ID is required when enabled")
return
}
if req.OIDCConnectIssuerURL == "" {
response.BadRequest(c, "OIDC Issuer URL is required when enabled")
return
}
if err := config.ValidateAbsoluteHTTPURL(req.OIDCConnectIssuerURL); err != nil {
response.BadRequest(c, "OIDC Issuer URL must be an absolute http(s) URL")
return
}
if req.OIDCConnectDiscoveryURL != "" {
if err := config.ValidateAbsoluteHTTPURL(req.OIDCConnectDiscoveryURL); err != nil {
response.BadRequest(c, "OIDC Discovery URL must be an absolute http(s) URL")
return
}
}
if req.OIDCConnectAuthorizeURL != "" {
if err := config.ValidateAbsoluteHTTPURL(req.OIDCConnectAuthorizeURL); err != nil {
response.BadRequest(c, "OIDC Authorize URL must be an absolute http(s) URL")
return
}
}
if req.OIDCConnectTokenURL != "" {
if err := config.ValidateAbsoluteHTTPURL(req.OIDCConnectTokenURL); err != nil {
response.BadRequest(c, "OIDC Token URL must be an absolute http(s) URL")
return
}
}
if req.OIDCConnectUserInfoURL != "" {
if err := config.ValidateAbsoluteHTTPURL(req.OIDCConnectUserInfoURL); err != nil {
response.BadRequest(c, "OIDC UserInfo URL must be an absolute http(s) URL")
return
}
}
if req.OIDCConnectRedirectURL == "" {
response.BadRequest(c, "OIDC Redirect URL is required when enabled")
return
}
if err := config.ValidateAbsoluteHTTPURL(req.OIDCConnectRedirectURL); err != nil {
response.BadRequest(c, "OIDC Redirect URL must be an absolute http(s) URL")
return
}
if req.OIDCConnectFrontendRedirectURL == "" {
response.BadRequest(c, "OIDC Frontend Redirect URL is required when enabled")
return
}
if err := config.ValidateFrontendRedirectURL(req.OIDCConnectFrontendRedirectURL); err != nil {
response.BadRequest(c, "OIDC Frontend Redirect URL is invalid")
return
}
if !scopesContainOpenID(req.OIDCConnectScopes) {
response.BadRequest(c, "OIDC scopes must contain openid")
return
}
switch req.OIDCConnectTokenAuthMethod {
case "", "client_secret_post", "client_secret_basic", "none":
default:
response.BadRequest(c, "OIDC Token Auth Method must be one of client_secret_post/client_secret_basic/none")
return
}
if req.OIDCConnectTokenAuthMethod == "none" && !req.OIDCConnectUsePKCE {
response.BadRequest(c, "OIDC PKCE must be enabled when token_auth_method=none")
return
}
if req.OIDCConnectClockSkewSeconds < 0 || req.OIDCConnectClockSkewSeconds > 600 {
response.BadRequest(c, "OIDC clock skew seconds must be between 0 and 600")
return
}
if req.OIDCConnectValidateIDToken {
if req.OIDCConnectAllowedSigningAlgs == "" {
response.BadRequest(c, "OIDC Allowed Signing Algs is required when validate_id_token=true")
return
}
}
if req.OIDCConnectJWKSURL != "" {
if err := config.ValidateAbsoluteHTTPURL(req.OIDCConnectJWKSURL); err != nil {
response.BadRequest(c, "OIDC JWKS URL must be an absolute http(s) URL")
return
}
}
if req.OIDCConnectTokenAuthMethod == "" || req.OIDCConnectTokenAuthMethod == "client_secret_post" || req.OIDCConnectTokenAuthMethod == "client_secret_basic" {
if req.OIDCConnectClientSecret == "" {
if previousSettings.OIDCConnectClientSecret == "" {
response.BadRequest(c, "OIDC Client Secret is required when enabled")
return
}
req.OIDCConnectClientSecret = previousSettings.OIDCConnectClientSecret
}
}
}
// “购买订阅”页面配置验证
purchaseEnabled := previousSettings.PurchaseSubscriptionEnabled
if req.PurchaseSubscriptionEnabled != nil {
@@ -565,6 +736,28 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
LinuxDoConnectClientID: req.LinuxDoConnectClientID,
LinuxDoConnectClientSecret: req.LinuxDoConnectClientSecret,
LinuxDoConnectRedirectURL: req.LinuxDoConnectRedirectURL,
OIDCConnectEnabled: req.OIDCConnectEnabled,
OIDCConnectProviderName: req.OIDCConnectProviderName,
OIDCConnectClientID: req.OIDCConnectClientID,
OIDCConnectClientSecret: req.OIDCConnectClientSecret,
OIDCConnectIssuerURL: req.OIDCConnectIssuerURL,
OIDCConnectDiscoveryURL: req.OIDCConnectDiscoveryURL,
OIDCConnectAuthorizeURL: req.OIDCConnectAuthorizeURL,
OIDCConnectTokenURL: req.OIDCConnectTokenURL,
OIDCConnectUserInfoURL: req.OIDCConnectUserInfoURL,
OIDCConnectJWKSURL: req.OIDCConnectJWKSURL,
OIDCConnectScopes: req.OIDCConnectScopes,
OIDCConnectRedirectURL: req.OIDCConnectRedirectURL,
OIDCConnectFrontendRedirectURL: req.OIDCConnectFrontendRedirectURL,
OIDCConnectTokenAuthMethod: req.OIDCConnectTokenAuthMethod,
OIDCConnectUsePKCE: req.OIDCConnectUsePKCE,
OIDCConnectValidateIDToken: req.OIDCConnectValidateIDToken,
OIDCConnectAllowedSigningAlgs: req.OIDCConnectAllowedSigningAlgs,
OIDCConnectClockSkewSeconds: req.OIDCConnectClockSkewSeconds,
OIDCConnectRequireEmailVerified: req.OIDCConnectRequireEmailVerified,
OIDCConnectUserInfoEmailPath: req.OIDCConnectUserInfoEmailPath,
OIDCConnectUserInfoIDPath: req.OIDCConnectUserInfoIDPath,
OIDCConnectUserInfoUsernamePath: req.OIDCConnectUserInfoUsernamePath,
SiteName: req.SiteName,
SiteLogo: req.SiteLogo,
SiteSubtitle: req.SiteSubtitle,
@@ -682,6 +875,28 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
LinuxDoConnectClientID: updatedSettings.LinuxDoConnectClientID,
LinuxDoConnectClientSecretConfigured: updatedSettings.LinuxDoConnectClientSecretConfigured,
LinuxDoConnectRedirectURL: updatedSettings.LinuxDoConnectRedirectURL,
OIDCConnectEnabled: updatedSettings.OIDCConnectEnabled,
OIDCConnectProviderName: updatedSettings.OIDCConnectProviderName,
OIDCConnectClientID: updatedSettings.OIDCConnectClientID,
OIDCConnectClientSecretConfigured: updatedSettings.OIDCConnectClientSecretConfigured,
OIDCConnectIssuerURL: updatedSettings.OIDCConnectIssuerURL,
OIDCConnectDiscoveryURL: updatedSettings.OIDCConnectDiscoveryURL,
OIDCConnectAuthorizeURL: updatedSettings.OIDCConnectAuthorizeURL,
OIDCConnectTokenURL: updatedSettings.OIDCConnectTokenURL,
OIDCConnectUserInfoURL: updatedSettings.OIDCConnectUserInfoURL,
OIDCConnectJWKSURL: updatedSettings.OIDCConnectJWKSURL,
OIDCConnectScopes: updatedSettings.OIDCConnectScopes,
OIDCConnectRedirectURL: updatedSettings.OIDCConnectRedirectURL,
OIDCConnectFrontendRedirectURL: updatedSettings.OIDCConnectFrontendRedirectURL,
OIDCConnectTokenAuthMethod: updatedSettings.OIDCConnectTokenAuthMethod,
OIDCConnectUsePKCE: updatedSettings.OIDCConnectUsePKCE,
OIDCConnectValidateIDToken: updatedSettings.OIDCConnectValidateIDToken,
OIDCConnectAllowedSigningAlgs: updatedSettings.OIDCConnectAllowedSigningAlgs,
OIDCConnectClockSkewSeconds: updatedSettings.OIDCConnectClockSkewSeconds,
OIDCConnectRequireEmailVerified: updatedSettings.OIDCConnectRequireEmailVerified,
OIDCConnectUserInfoEmailPath: updatedSettings.OIDCConnectUserInfoEmailPath,
OIDCConnectUserInfoIDPath: updatedSettings.OIDCConnectUserInfoIDPath,
OIDCConnectUserInfoUsernamePath: updatedSettings.OIDCConnectUserInfoUsernamePath,
SiteName: updatedSettings.SiteName,
SiteLogo: updatedSettings.SiteLogo,
SiteSubtitle: updatedSettings.SiteSubtitle,
@@ -802,6 +1017,72 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.LinuxDoConnectRedirectURL != after.LinuxDoConnectRedirectURL {
changed = append(changed, "linuxdo_connect_redirect_url")
}
if before.OIDCConnectEnabled != after.OIDCConnectEnabled {
changed = append(changed, "oidc_connect_enabled")
}
if before.OIDCConnectProviderName != after.OIDCConnectProviderName {
changed = append(changed, "oidc_connect_provider_name")
}
if before.OIDCConnectClientID != after.OIDCConnectClientID {
changed = append(changed, "oidc_connect_client_id")
}
if req.OIDCConnectClientSecret != "" {
changed = append(changed, "oidc_connect_client_secret")
}
if before.OIDCConnectIssuerURL != after.OIDCConnectIssuerURL {
changed = append(changed, "oidc_connect_issuer_url")
}
if before.OIDCConnectDiscoveryURL != after.OIDCConnectDiscoveryURL {
changed = append(changed, "oidc_connect_discovery_url")
}
if before.OIDCConnectAuthorizeURL != after.OIDCConnectAuthorizeURL {
changed = append(changed, "oidc_connect_authorize_url")
}
if before.OIDCConnectTokenURL != after.OIDCConnectTokenURL {
changed = append(changed, "oidc_connect_token_url")
}
if before.OIDCConnectUserInfoURL != after.OIDCConnectUserInfoURL {
changed = append(changed, "oidc_connect_userinfo_url")
}
if before.OIDCConnectJWKSURL != after.OIDCConnectJWKSURL {
changed = append(changed, "oidc_connect_jwks_url")
}
if before.OIDCConnectScopes != after.OIDCConnectScopes {
changed = append(changed, "oidc_connect_scopes")
}
if before.OIDCConnectRedirectURL != after.OIDCConnectRedirectURL {
changed = append(changed, "oidc_connect_redirect_url")
}
if before.OIDCConnectFrontendRedirectURL != after.OIDCConnectFrontendRedirectURL {
changed = append(changed, "oidc_connect_frontend_redirect_url")
}
if before.OIDCConnectTokenAuthMethod != after.OIDCConnectTokenAuthMethod {
changed = append(changed, "oidc_connect_token_auth_method")
}
if before.OIDCConnectUsePKCE != after.OIDCConnectUsePKCE {
changed = append(changed, "oidc_connect_use_pkce")
}
if before.OIDCConnectValidateIDToken != after.OIDCConnectValidateIDToken {
changed = append(changed, "oidc_connect_validate_id_token")
}
if before.OIDCConnectAllowedSigningAlgs != after.OIDCConnectAllowedSigningAlgs {
changed = append(changed, "oidc_connect_allowed_signing_algs")
}
if before.OIDCConnectClockSkewSeconds != after.OIDCConnectClockSkewSeconds {
changed = append(changed, "oidc_connect_clock_skew_seconds")
}
if before.OIDCConnectRequireEmailVerified != after.OIDCConnectRequireEmailVerified {
changed = append(changed, "oidc_connect_require_email_verified")
}
if before.OIDCConnectUserInfoEmailPath != after.OIDCConnectUserInfoEmailPath {
changed = append(changed, "oidc_connect_userinfo_email_path")
}
if before.OIDCConnectUserInfoIDPath != after.OIDCConnectUserInfoIDPath {
changed = append(changed, "oidc_connect_userinfo_id_path")
}
if before.OIDCConnectUserInfoUsernamePath != after.OIDCConnectUserInfoUsernamePath {
changed = append(changed, "oidc_connect_userinfo_username_path")
}
if before.SiteName != after.SiteName {
changed = append(changed, "site_name")
}

View File

@@ -0,0 +1,873 @@
package handler
import (
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rsa"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"log"
"math/big"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt/v5"
"github.com/imroc/req/v3"
"github.com/tidwall/gjson"
)
const (
oidcOAuthCookiePath = "/api/v1/auth/oauth/oidc"
oidcOAuthStateCookieName = "oidc_oauth_state"
oidcOAuthVerifierCookie = "oidc_oauth_verifier"
oidcOAuthRedirectCookie = "oidc_oauth_redirect"
oidcOAuthNonceCookie = "oidc_oauth_nonce"
oidcOAuthCookieMaxAgeSec = 10 * 60 // 10 minutes
oidcOAuthDefaultRedirectTo = "/dashboard"
oidcOAuthDefaultFrontendCB = "/auth/oidc/callback"
)
type oidcTokenResponse struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
ExpiresIn int64 `json:"expires_in"`
RefreshToken string `json:"refresh_token,omitempty"`
Scope string `json:"scope,omitempty"`
IDToken string `json:"id_token,omitempty"`
}
type oidcTokenExchangeError struct {
StatusCode int
ProviderError string
ProviderDescription string
Body string
}
func (e *oidcTokenExchangeError) Error() string {
if e == nil {
return ""
}
parts := []string{fmt.Sprintf("token exchange status=%d", e.StatusCode)}
if strings.TrimSpace(e.ProviderError) != "" {
parts = append(parts, "error="+strings.TrimSpace(e.ProviderError))
}
if strings.TrimSpace(e.ProviderDescription) != "" {
parts = append(parts, "error_description="+strings.TrimSpace(e.ProviderDescription))
}
return strings.Join(parts, " ")
}
type oidcIDTokenClaims struct {
Email string `json:"email,omitempty"`
EmailVerified *bool `json:"email_verified,omitempty"`
PreferredUsername string `json:"preferred_username,omitempty"`
Name string `json:"name,omitempty"`
Nonce string `json:"nonce,omitempty"`
Azp string `json:"azp,omitempty"`
jwt.RegisteredClaims
}
type oidcUserInfoClaims struct {
Email string
Username string
Subject string
EmailVerified *bool
}
type oidcJWKSet struct {
Keys []oidcJWK `json:"keys"`
}
type oidcJWK struct {
Kty string `json:"kty"`
Kid string `json:"kid"`
Use string `json:"use"`
Alg string `json:"alg"`
N string `json:"n"`
E string `json:"e"`
Crv string `json:"crv"`
X string `json:"x"`
Y string `json:"y"`
}
// OIDCOAuthStart 启动通用 OIDC OAuth 登录流程。
// GET /api/v1/auth/oauth/oidc/start?redirect=/dashboard
func (h *AuthHandler) OIDCOAuthStart(c *gin.Context) {
cfg, err := h.getOIDCOAuthConfig(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
state, err := oauth.GenerateState()
if err != nil {
response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_STATE_GEN_FAILED", "failed to generate oauth state").WithCause(err))
return
}
redirectTo := sanitizeFrontendRedirectPath(c.Query("redirect"))
if redirectTo == "" {
redirectTo = oidcOAuthDefaultRedirectTo
}
secureCookie := isRequestHTTPS(c)
oidcSetCookie(c, oidcOAuthStateCookieName, encodeCookieValue(state), oidcOAuthCookieMaxAgeSec, secureCookie)
oidcSetCookie(c, oidcOAuthRedirectCookie, encodeCookieValue(redirectTo), oidcOAuthCookieMaxAgeSec, secureCookie)
codeChallenge := ""
if cfg.UsePKCE {
verifier, genErr := oauth.GenerateCodeVerifier()
if genErr != nil {
response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_PKCE_GEN_FAILED", "failed to generate pkce verifier").WithCause(genErr))
return
}
codeChallenge = oauth.GenerateCodeChallenge(verifier)
oidcSetCookie(c, oidcOAuthVerifierCookie, encodeCookieValue(verifier), oidcOAuthCookieMaxAgeSec, secureCookie)
}
nonce := ""
if cfg.ValidateIDToken {
nonce, err = oauth.GenerateState()
if err != nil {
response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_NONCE_GEN_FAILED", "failed to generate oauth nonce").WithCause(err))
return
}
oidcSetCookie(c, oidcOAuthNonceCookie, encodeCookieValue(nonce), oidcOAuthCookieMaxAgeSec, secureCookie)
}
redirectURI := strings.TrimSpace(cfg.RedirectURL)
if redirectURI == "" {
response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth redirect url not configured"))
return
}
authURL, err := buildOIDCAuthorizeURL(cfg, state, nonce, codeChallenge, redirectURI)
if err != nil {
response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_BUILD_URL_FAILED", "failed to build oauth authorization url").WithCause(err))
return
}
c.Redirect(http.StatusFound, authURL)
}
// OIDCOAuthCallback 处理 OIDC 回调:校验 id_token、创建/登录用户并重定向到前端。
// GET /api/v1/auth/oauth/oidc/callback?code=...&state=...
func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
cfg, cfgErr := h.getOIDCOAuthConfig(c.Request.Context())
if cfgErr != nil {
response.ErrorFrom(c, cfgErr)
return
}
frontendCallback := strings.TrimSpace(cfg.FrontendRedirectURL)
if frontendCallback == "" {
frontendCallback = oidcOAuthDefaultFrontendCB
}
if providerErr := strings.TrimSpace(c.Query("error")); providerErr != "" {
redirectOAuthError(c, frontendCallback, "provider_error", providerErr, c.Query("error_description"))
return
}
code := strings.TrimSpace(c.Query("code"))
state := strings.TrimSpace(c.Query("state"))
if code == "" || state == "" {
redirectOAuthError(c, frontendCallback, "missing_params", "missing code/state", "")
return
}
secureCookie := isRequestHTTPS(c)
defer func() {
oidcClearCookie(c, oidcOAuthStateCookieName, secureCookie)
oidcClearCookie(c, oidcOAuthVerifierCookie, secureCookie)
oidcClearCookie(c, oidcOAuthRedirectCookie, secureCookie)
oidcClearCookie(c, oidcOAuthNonceCookie, secureCookie)
}()
expectedState, err := readCookieDecoded(c, oidcOAuthStateCookieName)
if err != nil || expectedState == "" || state != expectedState {
redirectOAuthError(c, frontendCallback, "invalid_state", "invalid oauth state", "")
return
}
redirectTo, _ := readCookieDecoded(c, oidcOAuthRedirectCookie)
redirectTo = sanitizeFrontendRedirectPath(redirectTo)
if redirectTo == "" {
redirectTo = oidcOAuthDefaultRedirectTo
}
codeVerifier := ""
if cfg.UsePKCE {
codeVerifier, _ = readCookieDecoded(c, oidcOAuthVerifierCookie)
if codeVerifier == "" {
redirectOAuthError(c, frontendCallback, "missing_verifier", "missing pkce verifier", "")
return
}
}
expectedNonce := ""
if cfg.ValidateIDToken {
expectedNonce, _ = readCookieDecoded(c, oidcOAuthNonceCookie)
if expectedNonce == "" {
redirectOAuthError(c, frontendCallback, "missing_nonce", "missing oauth nonce", "")
return
}
}
redirectURI := strings.TrimSpace(cfg.RedirectURL)
if redirectURI == "" {
redirectOAuthError(c, frontendCallback, "config_error", "oauth redirect url not configured", "")
return
}
tokenResp, err := oidcExchangeCode(c.Request.Context(), cfg, code, redirectURI, codeVerifier)
if err != nil {
description := ""
var exchangeErr *oidcTokenExchangeError
if errors.As(err, &exchangeErr) && exchangeErr != nil {
log.Printf(
"[OIDC OAuth] token exchange failed: status=%d provider_error=%q provider_description=%q body=%s",
exchangeErr.StatusCode,
exchangeErr.ProviderError,
exchangeErr.ProviderDescription,
truncateLogValue(exchangeErr.Body, 2048),
)
description = exchangeErr.Error()
} else {
log.Printf("[OIDC OAuth] token exchange failed: %v", err)
description = err.Error()
}
redirectOAuthError(c, frontendCallback, "token_exchange_failed", "failed to exchange oauth code", singleLine(description))
return
}
if cfg.ValidateIDToken && strings.TrimSpace(tokenResp.IDToken) == "" {
redirectOAuthError(c, frontendCallback, "missing_id_token", "missing id_token", "")
return
}
idClaims, err := oidcParseAndValidateIDToken(c.Request.Context(), cfg, tokenResp.IDToken, expectedNonce)
if err != nil {
log.Printf("[OIDC OAuth] id_token validation failed: %v", err)
redirectOAuthError(c, frontendCallback, "invalid_id_token", "failed to validate id_token", "")
return
}
userInfoClaims, err := oidcFetchUserInfo(c.Request.Context(), cfg, tokenResp)
if err != nil {
log.Printf("[OIDC OAuth] userinfo fetch failed: %v", err)
redirectOAuthError(c, frontendCallback, "userinfo_failed", "failed to fetch user info", "")
return
}
subject := strings.TrimSpace(idClaims.Subject)
if subject == "" {
subject = strings.TrimSpace(userInfoClaims.Subject)
}
if subject == "" {
redirectOAuthError(c, frontendCallback, "missing_subject", "missing subject claim", "")
return
}
issuer := strings.TrimSpace(idClaims.Issuer)
if issuer == "" {
issuer = strings.TrimSpace(cfg.IssuerURL)
}
if issuer == "" {
redirectOAuthError(c, frontendCallback, "missing_issuer", "missing issuer claim", "")
return
}
emailVerified := userInfoClaims.EmailVerified
if emailVerified == nil {
emailVerified = idClaims.EmailVerified
}
if cfg.RequireEmailVerified {
if emailVerified == nil || !*emailVerified {
redirectOAuthError(c, frontendCallback, "email_not_verified", "email is not verified", "")
return
}
}
identityKey := oidcIdentityKey(issuer, subject)
email := oidcSelectLoginEmail(userInfoClaims.Email, idClaims.Email, identityKey)
username := firstNonEmpty(
userInfoClaims.Username,
idClaims.PreferredUsername,
idClaims.Name,
oidcFallbackUsername(subject),
)
// 传入空邀请码;如果需要邀请码,服务层返回 ErrOAuthInvitationRequired
tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "")
if err != nil {
if errors.Is(err, service.ErrOAuthInvitationRequired) {
pendingToken, tokenErr := h.authService.CreatePendingOAuthToken(email, username)
if tokenErr != nil {
redirectOAuthError(c, frontendCallback, "login_failed", "service_error", "")
return
}
fragment := url.Values{}
fragment.Set("error", "invitation_required")
fragment.Set("pending_oauth_token", pendingToken)
fragment.Set("redirect", redirectTo)
redirectWithFragment(c, frontendCallback, fragment)
return
}
redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err))
return
}
fragment := url.Values{}
fragment.Set("access_token", tokenPair.AccessToken)
fragment.Set("refresh_token", tokenPair.RefreshToken)
fragment.Set("expires_in", fmt.Sprintf("%d", tokenPair.ExpiresIn))
fragment.Set("token_type", "Bearer")
fragment.Set("redirect", redirectTo)
redirectWithFragment(c, frontendCallback, fragment)
}
type completeOIDCOAuthRequest struct {
PendingOAuthToken string `json:"pending_oauth_token" binding:"required"`
InvitationCode string `json:"invitation_code" binding:"required"`
}
// CompleteOIDCOAuthRegistration completes a pending OAuth registration by validating
// the invitation code and creating the user account.
// POST /api/v1/auth/oauth/oidc/complete-registration
func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) {
var req completeOIDCOAuthRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "INVALID_REQUEST", "message": err.Error()})
return
}
email, username, err := h.authService.VerifyPendingOAuthToken(req.PendingOAuthToken)
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": "INVALID_TOKEN", "message": "invalid or expired registration token"})
return
}
tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode)
if err != nil {
response.ErrorFrom(c, err)
return
}
c.JSON(http.StatusOK, gin.H{
"access_token": tokenPair.AccessToken,
"refresh_token": tokenPair.RefreshToken,
"expires_in": tokenPair.ExpiresIn,
"token_type": "Bearer",
})
}
func (h *AuthHandler) getOIDCOAuthConfig(ctx context.Context) (config.OIDCConnectConfig, error) {
if h != nil && h.settingSvc != nil {
return h.settingSvc.GetOIDCConnectOAuthConfig(ctx)
}
if h == nil || h.cfg == nil {
return config.OIDCConnectConfig{}, infraerrors.ServiceUnavailable("CONFIG_NOT_READY", "config not loaded")
}
if !h.cfg.OIDC.Enabled {
return config.OIDCConnectConfig{}, infraerrors.NotFound("OAUTH_DISABLED", "oauth login is disabled")
}
return h.cfg.OIDC, nil
}
func oidcExchangeCode(
ctx context.Context,
cfg config.OIDCConnectConfig,
code string,
redirectURI string,
codeVerifier string,
) (*oidcTokenResponse, error) {
client := req.C().SetTimeout(30 * time.Second)
form := url.Values{}
form.Set("grant_type", "authorization_code")
form.Set("client_id", cfg.ClientID)
form.Set("code", code)
form.Set("redirect_uri", redirectURI)
if cfg.UsePKCE {
form.Set("code_verifier", codeVerifier)
}
r := client.R().
SetContext(ctx).
SetHeader("Accept", "application/json")
switch strings.ToLower(strings.TrimSpace(cfg.TokenAuthMethod)) {
case "", "client_secret_post":
form.Set("client_secret", cfg.ClientSecret)
case "client_secret_basic":
r.SetBasicAuth(cfg.ClientID, cfg.ClientSecret)
case "none":
default:
return nil, fmt.Errorf("unsupported token_auth_method: %s", cfg.TokenAuthMethod)
}
resp, err := r.SetFormDataFromValues(form).Post(cfg.TokenURL)
if err != nil {
return nil, fmt.Errorf("request token: %w", err)
}
body := strings.TrimSpace(resp.String())
if !resp.IsSuccessState() {
providerErr, providerDesc := parseOAuthProviderError(body)
return nil, &oidcTokenExchangeError{
StatusCode: resp.StatusCode,
ProviderError: providerErr,
ProviderDescription: providerDesc,
Body: body,
}
}
tokenResp, ok := oidcParseTokenResponse(body)
if !ok {
return nil, &oidcTokenExchangeError{StatusCode: resp.StatusCode, Body: body}
}
if strings.TrimSpace(tokenResp.TokenType) == "" {
tokenResp.TokenType = "Bearer"
}
if strings.TrimSpace(tokenResp.AccessToken) == "" && strings.TrimSpace(tokenResp.IDToken) == "" {
return nil, &oidcTokenExchangeError{StatusCode: resp.StatusCode, Body: body}
}
return tokenResp, nil
}
func oidcParseTokenResponse(body string) (*oidcTokenResponse, bool) {
body = strings.TrimSpace(body)
if body == "" {
return nil, false
}
accessToken := strings.TrimSpace(getGJSON(body, "access_token"))
idToken := strings.TrimSpace(getGJSON(body, "id_token"))
if accessToken != "" || idToken != "" {
tokenType := strings.TrimSpace(getGJSON(body, "token_type"))
refreshToken := strings.TrimSpace(getGJSON(body, "refresh_token"))
scope := strings.TrimSpace(getGJSON(body, "scope"))
expiresIn := gjson.Get(body, "expires_in").Int()
return &oidcTokenResponse{
AccessToken: accessToken,
TokenType: tokenType,
ExpiresIn: expiresIn,
RefreshToken: refreshToken,
Scope: scope,
IDToken: idToken,
}, true
}
values, err := url.ParseQuery(body)
if err != nil {
return nil, false
}
accessToken = strings.TrimSpace(values.Get("access_token"))
idToken = strings.TrimSpace(values.Get("id_token"))
if accessToken == "" && idToken == "" {
return nil, false
}
expiresIn := int64(0)
if raw := strings.TrimSpace(values.Get("expires_in")); raw != "" {
if v, parseErr := strconv.ParseInt(raw, 10, 64); parseErr == nil {
expiresIn = v
}
}
return &oidcTokenResponse{
AccessToken: accessToken,
TokenType: strings.TrimSpace(values.Get("token_type")),
ExpiresIn: expiresIn,
RefreshToken: strings.TrimSpace(values.Get("refresh_token")),
Scope: strings.TrimSpace(values.Get("scope")),
IDToken: idToken,
}, true
}
func oidcFetchUserInfo(
ctx context.Context,
cfg config.OIDCConnectConfig,
token *oidcTokenResponse,
) (*oidcUserInfoClaims, error) {
if strings.TrimSpace(cfg.UserInfoURL) == "" {
return &oidcUserInfoClaims{}, nil
}
if token == nil || strings.TrimSpace(token.AccessToken) == "" {
return nil, errors.New("missing access_token for userinfo request")
}
client := req.C().SetTimeout(30 * time.Second)
authorization, err := buildBearerAuthorization(token.TokenType, token.AccessToken)
if err != nil {
return nil, fmt.Errorf("invalid token for userinfo request: %w", err)
}
resp, err := client.R().
SetContext(ctx).
SetHeader("Accept", "application/json").
SetHeader("Authorization", authorization).
Get(cfg.UserInfoURL)
if err != nil {
return nil, fmt.Errorf("request userinfo: %w", err)
}
if !resp.IsSuccessState() {
return nil, fmt.Errorf("userinfo status=%d", resp.StatusCode)
}
return oidcParseUserInfo(resp.String(), cfg), nil
}
func oidcParseUserInfo(body string, cfg config.OIDCConnectConfig) *oidcUserInfoClaims {
claims := &oidcUserInfoClaims{}
claims.Email = firstNonEmpty(
getGJSON(body, cfg.UserInfoEmailPath),
getGJSON(body, "email"),
getGJSON(body, "user.email"),
getGJSON(body, "data.email"),
getGJSON(body, "attributes.email"),
)
claims.Username = firstNonEmpty(
getGJSON(body, cfg.UserInfoUsernamePath),
getGJSON(body, "preferred_username"),
getGJSON(body, "username"),
getGJSON(body, "name"),
getGJSON(body, "user.username"),
getGJSON(body, "user.name"),
)
claims.Subject = firstNonEmpty(
getGJSON(body, cfg.UserInfoIDPath),
getGJSON(body, "sub"),
getGJSON(body, "id"),
getGJSON(body, "user_id"),
getGJSON(body, "uid"),
getGJSON(body, "user.id"),
)
if verified, ok := getGJSONBool(body, "email_verified"); ok {
claims.EmailVerified = &verified
}
claims.Email = strings.TrimSpace(claims.Email)
claims.Username = strings.TrimSpace(claims.Username)
claims.Subject = strings.TrimSpace(claims.Subject)
return claims
}
func getGJSONBool(body string, path string) (bool, bool) {
path = strings.TrimSpace(path)
if path == "" {
return false, false
}
res := gjson.Get(body, path)
if !res.Exists() {
return false, false
}
return res.Bool(), true
}
func buildOIDCAuthorizeURL(cfg config.OIDCConnectConfig, state, nonce, codeChallenge, redirectURI string) (string, error) {
u, err := url.Parse(cfg.AuthorizeURL)
if err != nil {
return "", fmt.Errorf("parse authorize_url: %w", err)
}
q := u.Query()
q.Set("response_type", "code")
q.Set("client_id", cfg.ClientID)
q.Set("redirect_uri", redirectURI)
if strings.TrimSpace(cfg.Scopes) != "" {
q.Set("scope", cfg.Scopes)
}
q.Set("state", state)
if strings.TrimSpace(nonce) != "" {
q.Set("nonce", nonce)
}
if cfg.UsePKCE {
q.Set("code_challenge", codeChallenge)
q.Set("code_challenge_method", "S256")
}
u.RawQuery = q.Encode()
return u.String(), nil
}
func oidcParseAndValidateIDToken(ctx context.Context, cfg config.OIDCConnectConfig, idToken string, expectedNonce string) (*oidcIDTokenClaims, error) {
idToken = strings.TrimSpace(idToken)
if idToken == "" {
return nil, errors.New("missing id_token")
}
allowed := oidcAllowedSigningAlgs(cfg.AllowedSigningAlgs)
if len(allowed) == 0 {
return nil, errors.New("empty allowed signing algorithms")
}
jwks, err := oidcFetchJWKSet(ctx, cfg.JWKSURL)
if err != nil {
return nil, err
}
leeway := time.Duration(cfg.ClockSkewSeconds) * time.Second
claims := &oidcIDTokenClaims{}
parsed, err := jwt.ParseWithClaims(
idToken,
claims,
func(token *jwt.Token) (any, error) {
alg := strings.TrimSpace(token.Method.Alg())
if !containsString(allowed, alg) {
return nil, fmt.Errorf("unexpected signing algorithm: %s", alg)
}
kid, _ := token.Header["kid"].(string)
return oidcFindPublicKey(jwks, strings.TrimSpace(kid), alg)
},
jwt.WithValidMethods(allowed),
jwt.WithAudience(cfg.ClientID),
jwt.WithIssuer(cfg.IssuerURL),
jwt.WithLeeway(leeway),
)
if err != nil {
return nil, err
}
if !parsed.Valid {
return nil, errors.New("id_token invalid")
}
if strings.TrimSpace(claims.Subject) == "" {
return nil, errors.New("id_token missing sub")
}
if expectedNonce != "" && strings.TrimSpace(claims.Nonce) != strings.TrimSpace(expectedNonce) {
return nil, errors.New("id_token nonce mismatch")
}
if len(claims.Audience) > 1 {
if strings.TrimSpace(claims.Azp) == "" || strings.TrimSpace(claims.Azp) != strings.TrimSpace(cfg.ClientID) {
return nil, errors.New("id_token azp mismatch")
}
}
return claims, nil
}
func oidcAllowedSigningAlgs(raw string) []string {
if strings.TrimSpace(raw) == "" {
return []string{"RS256", "ES256", "PS256"}
}
seen := make(map[string]struct{})
out := make([]string, 0, 4)
for _, part := range strings.Split(raw, ",") {
alg := strings.ToUpper(strings.TrimSpace(part))
if alg == "" {
continue
}
if _, ok := seen[alg]; ok {
continue
}
seen[alg] = struct{}{}
out = append(out, alg)
}
return out
}
func oidcFetchJWKSet(ctx context.Context, jwksURL string) (*oidcJWKSet, error) {
jwksURL = strings.TrimSpace(jwksURL)
if jwksURL == "" {
return nil, errors.New("missing jwks_url")
}
resp, err := req.C().
SetTimeout(30*time.Second).
R().
SetContext(ctx).
SetHeader("Accept", "application/json").
Get(jwksURL)
if err != nil {
return nil, fmt.Errorf("request jwks: %w", err)
}
if !resp.IsSuccessState() {
return nil, fmt.Errorf("jwks status=%d", resp.StatusCode)
}
set := &oidcJWKSet{}
if err := json.Unmarshal(resp.Bytes(), set); err != nil {
return nil, fmt.Errorf("parse jwks: %w", err)
}
if len(set.Keys) == 0 {
return nil, errors.New("jwks empty keys")
}
return set, nil
}
func oidcFindPublicKey(set *oidcJWKSet, kid, alg string) (any, error) {
if set == nil {
return nil, errors.New("jwks not loaded")
}
alg = strings.ToUpper(strings.TrimSpace(alg))
kid = strings.TrimSpace(kid)
var lastErr error
for i := range set.Keys {
k := set.Keys[i]
if strings.TrimSpace(k.Use) != "" && !strings.EqualFold(strings.TrimSpace(k.Use), "sig") {
continue
}
if kid != "" && strings.TrimSpace(k.Kid) != kid {
continue
}
if strings.TrimSpace(k.Alg) != "" && !strings.EqualFold(strings.TrimSpace(k.Alg), alg) {
continue
}
pk, err := k.publicKey()
if err != nil {
lastErr = err
continue
}
if pk != nil {
return pk, nil
}
}
if lastErr != nil {
return nil, lastErr
}
if kid != "" {
return nil, fmt.Errorf("jwk not found for kid=%s", kid)
}
return nil, errors.New("jwk not found")
}
func (k oidcJWK) publicKey() (any, error) {
switch strings.ToUpper(strings.TrimSpace(k.Kty)) {
case "RSA":
n, err := decodeBase64URLBigInt(k.N)
if err != nil {
return nil, fmt.Errorf("decode rsa n: %w", err)
}
eBytes, err := base64.RawURLEncoding.DecodeString(strings.TrimSpace(k.E))
if err != nil {
return nil, fmt.Errorf("decode rsa e: %w", err)
}
if len(eBytes) == 0 {
return nil, errors.New("empty rsa e")
}
e := 0
for _, b := range eBytes {
e = (e << 8) | int(b)
}
if e <= 0 {
return nil, errors.New("invalid rsa exponent")
}
if n.Sign() <= 0 {
return nil, errors.New("invalid rsa modulus")
}
return &rsa.PublicKey{N: n, E: e}, nil
case "EC":
var curve elliptic.Curve
switch strings.TrimSpace(k.Crv) {
case "P-256":
curve = elliptic.P256()
case "P-384":
curve = elliptic.P384()
case "P-521":
curve = elliptic.P521()
default:
return nil, fmt.Errorf("unsupported ec curve: %s", k.Crv)
}
x, err := decodeBase64URLBigInt(k.X)
if err != nil {
return nil, fmt.Errorf("decode ec x: %w", err)
}
y, err := decodeBase64URLBigInt(k.Y)
if err != nil {
return nil, fmt.Errorf("decode ec y: %w", err)
}
if !curve.IsOnCurve(x, y) {
return nil, errors.New("ec point is not on curve")
}
return &ecdsa.PublicKey{Curve: curve, X: x, Y: y}, nil
default:
return nil, fmt.Errorf("unsupported jwk kty: %s", k.Kty)
}
}
func decodeBase64URLBigInt(raw string) (*big.Int, error) {
buf, err := base64.RawURLEncoding.DecodeString(strings.TrimSpace(raw))
if err != nil {
return nil, err
}
if len(buf) == 0 {
return nil, errors.New("empty value")
}
return new(big.Int).SetBytes(buf), nil
}
func containsString(values []string, target string) bool {
target = strings.TrimSpace(target)
for _, v := range values {
if strings.EqualFold(strings.TrimSpace(v), target) {
return true
}
}
return false
}
func oidcIdentityKey(issuer, subject string) string {
issuer = strings.TrimSpace(strings.ToLower(issuer))
subject = strings.TrimSpace(subject)
return issuer + "\x1f" + subject
}
func oidcSyntheticEmailFromIdentityKey(identityKey string) string {
identityKey = strings.TrimSpace(identityKey)
if identityKey == "" {
return ""
}
sum := sha256.Sum256([]byte(identityKey))
return "oidc-" + hex.EncodeToString(sum[:16]) + service.OIDCConnectSyntheticEmailDomain
}
func oidcSelectLoginEmail(userInfoEmail, idTokenEmail, identityKey string) string {
email := strings.TrimSpace(firstNonEmpty(userInfoEmail, idTokenEmail))
if email != "" {
return email
}
return oidcSyntheticEmailFromIdentityKey(identityKey)
}
func oidcFallbackUsername(subject string) string {
subject = strings.TrimSpace(subject)
if subject == "" {
return "oidc_user"
}
sum := sha256.Sum256([]byte(subject))
return "oidc_" + hex.EncodeToString(sum[:])[:12]
}
func oidcSetCookie(c *gin.Context, name, value string, maxAgeSec int, secure bool) {
http.SetCookie(c.Writer, &http.Cookie{
Name: name,
Value: value,
Path: oidcOAuthCookiePath,
MaxAge: maxAgeSec,
HttpOnly: true,
Secure: secure,
SameSite: http.SameSiteLaxMode,
})
}
func oidcClearCookie(c *gin.Context, name string, secure bool) {
http.SetCookie(c.Writer, &http.Cookie{
Name: name,
Value: "",
Path: oidcOAuthCookiePath,
MaxAge: -1,
HttpOnly: true,
Secure: secure,
SameSite: http.SameSiteLaxMode,
})
}

View File

@@ -0,0 +1,120 @@
package handler
import (
"context"
"crypto/rand"
"crypto/rsa"
"encoding/base64"
"encoding/json"
"math/big"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/require"
)
func TestOIDCSyntheticEmailStableAndDistinct(t *testing.T) {
k1 := oidcIdentityKey("https://issuer.example.com", "subject-a")
k2 := oidcIdentityKey("https://issuer.example.com", "subject-b")
e1 := oidcSyntheticEmailFromIdentityKey(k1)
e1Again := oidcSyntheticEmailFromIdentityKey(k1)
e2 := oidcSyntheticEmailFromIdentityKey(k2)
require.Equal(t, e1, e1Again)
require.NotEqual(t, e1, e2)
require.Contains(t, e1, "@oidc-connect.invalid")
}
func TestOIDCSelectLoginEmailPrefersRealEmail(t *testing.T) {
identityKey := oidcIdentityKey("https://issuer.example.com", "subject-a")
email := oidcSelectLoginEmail("user@example.com", "idtoken@example.com", identityKey)
require.Equal(t, "user@example.com", email)
email = oidcSelectLoginEmail("", "idtoken@example.com", identityKey)
require.Equal(t, "idtoken@example.com", email)
email = oidcSelectLoginEmail("", "", identityKey)
require.Contains(t, email, "@oidc-connect.invalid")
require.Equal(t, oidcSyntheticEmailFromIdentityKey(identityKey), email)
}
func TestBuildOIDCAuthorizeURLIncludesNonceAndPKCE(t *testing.T) {
cfg := config.OIDCConnectConfig{
AuthorizeURL: "https://issuer.example.com/auth",
ClientID: "cid",
Scopes: "openid email profile",
UsePKCE: true,
}
u, err := buildOIDCAuthorizeURL(cfg, "state123", "nonce123", "challenge123", "https://app.example.com/callback")
require.NoError(t, err)
require.Contains(t, u, "nonce=nonce123")
require.Contains(t, u, "code_challenge=challenge123")
require.Contains(t, u, "code_challenge_method=S256")
require.Contains(t, u, "scope=openid+email+profile")
}
func TestOIDCParseAndValidateIDToken(t *testing.T) {
priv, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
kid := "kid-1"
jwks := oidcJWKSet{Keys: []oidcJWK{buildRSAJWK(kid, &priv.PublicKey)}}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.NoError(t, json.NewEncoder(w).Encode(jwks))
}))
defer srv.Close()
now := time.Now()
claims := oidcIDTokenClaims{
Nonce: "nonce-ok",
Azp: "client-1",
RegisteredClaims: jwt.RegisteredClaims{
Issuer: "https://issuer.example.com",
Subject: "subject-1",
Audience: jwt.ClaimStrings{"client-1", "another-aud"},
IssuedAt: jwt.NewNumericDate(now),
NotBefore: jwt.NewNumericDate(now.Add(-30 * time.Second)),
ExpiresAt: jwt.NewNumericDate(now.Add(5 * time.Minute)),
},
}
tok := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
tok.Header["kid"] = kid
signed, err := tok.SignedString(priv)
require.NoError(t, err)
cfg := config.OIDCConnectConfig{
ClientID: "client-1",
IssuerURL: "https://issuer.example.com",
JWKSURL: srv.URL,
AllowedSigningAlgs: "RS256",
ClockSkewSeconds: 120,
}
parsed, err := oidcParseAndValidateIDToken(context.Background(), cfg, signed, "nonce-ok")
require.NoError(t, err)
require.Equal(t, "subject-1", parsed.Subject)
require.Equal(t, "https://issuer.example.com", parsed.Issuer)
_, err = oidcParseAndValidateIDToken(context.Background(), cfg, signed, "bad-nonce")
require.Error(t, err)
}
func buildRSAJWK(kid string, pub *rsa.PublicKey) oidcJWK {
n := base64.RawURLEncoding.EncodeToString(pub.N.Bytes())
e := base64.RawURLEncoding.EncodeToString(big.NewInt(int64(pub.E)).Bytes())
return oidcJWK{
Kty: "RSA",
Kid: kid,
Use: "sig",
Alg: "RS256",
N: n,
E: e,
}
}

View File

@@ -133,16 +133,17 @@ func GroupFromServiceAdmin(g *service.Group) *AdminGroup {
return nil
}
out := &AdminGroup{
Group: groupFromServiceBase(g),
ModelRouting: g.ModelRouting,
ModelRoutingEnabled: g.ModelRoutingEnabled,
MCPXMLInject: g.MCPXMLInject,
DefaultMappedModel: g.DefaultMappedModel,
SupportedModelScopes: g.SupportedModelScopes,
AccountCount: g.AccountCount,
ActiveAccountCount: g.ActiveAccountCount,
RateLimitedAccountCount: g.RateLimitedAccountCount,
SortOrder: g.SortOrder,
Group: groupFromServiceBase(g),
ModelRouting: g.ModelRouting,
ModelRoutingEnabled: g.ModelRoutingEnabled,
MCPXMLInject: g.MCPXMLInject,
DefaultMappedModel: g.DefaultMappedModel,
MessagesDispatchModelConfig: g.MessagesDispatchModelConfig,
SupportedModelScopes: g.SupportedModelScopes,
AccountCount: g.AccountCount,
ActiveAccountCount: g.ActiveAccountCount,
RateLimitedAccountCount: g.RateLimitedAccountCount,
SortOrder: g.SortOrder,
}
if len(g.AccountGroups) > 0 {
out.AccountGroups = make([]AccountGroup, 0, len(g.AccountGroups))

View File

@@ -51,6 +51,29 @@ type SystemSettings struct {
LinuxDoConnectClientSecretConfigured bool `json:"linuxdo_connect_client_secret_configured"`
LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"`
OIDCConnectEnabled bool `json:"oidc_connect_enabled"`
OIDCConnectProviderName string `json:"oidc_connect_provider_name"`
OIDCConnectClientID string `json:"oidc_connect_client_id"`
OIDCConnectClientSecretConfigured bool `json:"oidc_connect_client_secret_configured"`
OIDCConnectIssuerURL string `json:"oidc_connect_issuer_url"`
OIDCConnectDiscoveryURL string `json:"oidc_connect_discovery_url"`
OIDCConnectAuthorizeURL string `json:"oidc_connect_authorize_url"`
OIDCConnectTokenURL string `json:"oidc_connect_token_url"`
OIDCConnectUserInfoURL string `json:"oidc_connect_userinfo_url"`
OIDCConnectJWKSURL string `json:"oidc_connect_jwks_url"`
OIDCConnectScopes string `json:"oidc_connect_scopes"`
OIDCConnectRedirectURL string `json:"oidc_connect_redirect_url"`
OIDCConnectFrontendRedirectURL string `json:"oidc_connect_frontend_redirect_url"`
OIDCConnectTokenAuthMethod string `json:"oidc_connect_token_auth_method"`
OIDCConnectUsePKCE bool `json:"oidc_connect_use_pkce"`
OIDCConnectValidateIDToken bool `json:"oidc_connect_validate_id_token"`
OIDCConnectAllowedSigningAlgs string `json:"oidc_connect_allowed_signing_algs"`
OIDCConnectClockSkewSeconds int `json:"oidc_connect_clock_skew_seconds"`
OIDCConnectRequireEmailVerified bool `json:"oidc_connect_require_email_verified"`
OIDCConnectUserInfoEmailPath string `json:"oidc_connect_userinfo_email_path"`
OIDCConnectUserInfoIDPath string `json:"oidc_connect_userinfo_id_path"`
OIDCConnectUserInfoUsernamePath string `json:"oidc_connect_userinfo_username_path"`
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
@@ -132,6 +155,9 @@ type PublicSettings struct {
CustomMenuItems []CustomMenuItem `json:"custom_menu_items"`
CustomEndpoints []CustomEndpoint `json:"custom_endpoints"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
OIDCOAuthEnabled bool `json:"oidc_oauth_enabled"`
OIDCOAuthProviderName string `json:"oidc_oauth_provider_name"`
SoraClientEnabled bool `json:"sora_client_enabled"`
BackendModeEnabled bool `json:"backend_mode_enabled"`
Version string `json:"version"`
}

View File

@@ -1,6 +1,10 @@
package dto
import "time"
import (
"time"
"github.com/Wei-Shaw/sub2api/internal/domain"
)
type User struct {
ID int64 `json:"id"`
@@ -112,7 +116,8 @@ type AdminGroup struct {
MCPXMLInject bool `json:"mcp_xml_inject"`
// OpenAI Messages 调度配置(仅 openai 平台使用)
DefaultMappedModel string `json:"default_mapped_model"`
DefaultMappedModel string `json:"default_mapped_model"`
MessagesDispatchModelConfig domain.OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config"`
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes []string `json:"supported_model_scopes"`

View File

@@ -47,6 +47,13 @@ func resolveOpenAIForwardDefaultMappedModel(apiKey *service.APIKey, fallbackMode
return strings.TrimSpace(apiKey.Group.DefaultMappedModel)
}
func resolveOpenAIMessagesDispatchMappedModel(apiKey *service.APIKey, requestedModel string) string {
if apiKey == nil || apiKey.Group == nil {
return ""
}
return strings.TrimSpace(apiKey.Group.ResolveMessagesDispatchModel(requestedModel))
}
// NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler
func NewOpenAIGatewayHandler(
gatewayService *service.OpenAIGatewayService,
@@ -551,6 +558,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
}
reqModel := modelResult.String()
routingModel := service.NormalizeOpenAICompatRequestedModel(reqModel)
preferredMappedModel := resolveOpenAIMessagesDispatchMappedModel(apiKey, reqModel)
reqStream := gjson.GetBytes(body, "stream").Bool()
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
@@ -609,17 +617,20 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
failedAccountIDs := make(map[int64]struct{})
sameAccountRetryCount := make(map[int64]int)
var lastFailoverErr *service.UpstreamFailoverError
effectiveMappedModel := preferredMappedModel
for {
// 清除上一次迭代的降级模型标记,避免残留影响本次迭代
c.Set("openai_messages_fallback_model", "")
currentRoutingModel := routingModel
if effectiveMappedModel != "" {
currentRoutingModel = effectiveMappedModel
}
reqLog.Debug("openai_messages.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs)))
selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler(
c.Request.Context(),
apiKey.GroupID,
"", // no previous_response_id
sessionHash,
routingModel,
currentRoutingModel,
failedAccountIDs,
service.OpenAIUpstreamTransportAny,
)
@@ -628,29 +639,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
zap.Error(err),
zap.Int("excluded_account_count", len(failedAccountIDs)),
)
// 首次调度失败 + 有默认映射模型 → 用默认模型重试
if len(failedAccountIDs) == 0 {
defaultModel := ""
if apiKey.Group != nil {
defaultModel = apiKey.Group.DefaultMappedModel
}
if defaultModel != "" && defaultModel != routingModel {
reqLog.Info("openai_messages.fallback_to_default_model",
zap.String("default_mapped_model", defaultModel),
)
selection, scheduleDecision, err = h.gatewayService.SelectAccountWithScheduler(
c.Request.Context(),
apiKey.GroupID,
"",
sessionHash,
defaultModel,
failedAccountIDs,
service.OpenAIUpstreamTransportAny,
)
if err == nil && selection != nil {
c.Set("openai_messages_fallback_model", defaultModel)
}
}
if err != nil {
h.anthropicStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
return
@@ -682,9 +671,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
forwardStart := time.Now()
// Forward 层需要始终拿到 group 默认映射模型,这样未命中账号级映射的
// Claude 兼容模型才不会在后续 Codex 规范化中意外退化到 gpt-5.1。
defaultMappedModel := resolveOpenAIForwardDefaultMappedModel(apiKey, c.GetString("openai_messages_fallback_model"))
defaultMappedModel := strings.TrimSpace(effectiveMappedModel)
// 应用渠道模型映射到请求体
forwardBody := body
if channelMappingMsg.Mapped {

View File

@@ -360,7 +360,7 @@ func TestResolveOpenAIForwardDefaultMappedModel(t *testing.T) {
require.Equal(t, "gpt-5.2", resolveOpenAIForwardDefaultMappedModel(apiKey, " gpt-5.2 "))
})
t.Run("uses_group_default_on_normal_path", func(t *testing.T) {
t.Run("uses_group_default_when_explicit_fallback_absent", func(t *testing.T) {
apiKey := &service.APIKey{
Group: &service.Group{DefaultMappedModel: "gpt-5.4"},
}
@@ -376,6 +376,45 @@ func TestResolveOpenAIForwardDefaultMappedModel(t *testing.T) {
})
}
func TestResolveOpenAIMessagesDispatchMappedModel(t *testing.T) {
t.Run("exact_claude_model_override_wins", func(t *testing.T) {
apiKey := &service.APIKey{
Group: &service.Group{
MessagesDispatchModelConfig: service.OpenAIMessagesDispatchModelConfig{
SonnetMappedModel: "gpt-5.2",
ExactModelMappings: map[string]string{
"claude-sonnet-4-5-20250929": "gpt-5.4-mini-high",
},
},
},
}
require.Equal(t, "gpt-5.4-mini", resolveOpenAIMessagesDispatchMappedModel(apiKey, "claude-sonnet-4-5-20250929"))
})
t.Run("uses_family_default_when_no_override", func(t *testing.T) {
apiKey := &service.APIKey{Group: &service.Group{}}
require.Equal(t, "gpt-5.4", resolveOpenAIMessagesDispatchMappedModel(apiKey, "claude-opus-4-6"))
require.Equal(t, "gpt-5.3-codex", resolveOpenAIMessagesDispatchMappedModel(apiKey, "claude-sonnet-4-5-20250929"))
require.Equal(t, "gpt-5.4-mini", resolveOpenAIMessagesDispatchMappedModel(apiKey, "claude-haiku-4-5-20251001"))
})
t.Run("returns_empty_for_non_claude_or_missing_group", func(t *testing.T) {
require.Empty(t, resolveOpenAIMessagesDispatchMappedModel(nil, "claude-sonnet-4-5-20250929"))
require.Empty(t, resolveOpenAIMessagesDispatchMappedModel(&service.APIKey{}, "claude-sonnet-4-5-20250929"))
require.Empty(t, resolveOpenAIMessagesDispatchMappedModel(&service.APIKey{Group: &service.Group{}}, "gpt-5.4"))
})
t.Run("does_not_fall_back_to_group_default_mapped_model", func(t *testing.T) {
apiKey := &service.APIKey{
Group: &service.Group{
DefaultMappedModel: "gpt-5.4",
},
}
require.Empty(t, resolveOpenAIMessagesDispatchMappedModel(apiKey, "gpt-5.4"))
require.Equal(t, "gpt-5.3-codex", resolveOpenAIMessagesDispatchMappedModel(apiKey, "claude-sonnet-4-5-20250929"))
})
}
func TestOpenAIResponses_MissingDependencies_ReturnsServiceUnavailable(t *testing.T) {
gin.SetMode(gin.TestMode)

View File

@@ -54,6 +54,8 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
CustomMenuItems: dto.ParseUserVisibleMenuItems(settings.CustomMenuItems),
CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints),
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
OIDCOAuthEnabled: settings.OIDCOAuthEnabled,
OIDCOAuthProviderName: settings.OIDCOAuthProviderName,
BackendModeEnabled: settings.BackendModeEnabled,
Version: h.version,
})

View File

@@ -28,7 +28,7 @@ type AnthropicRequest struct {
// AnthropicOutputConfig controls output generation parameters.
type AnthropicOutputConfig struct {
Effort string `json:"effort,omitempty"` // "low" | "medium" | "high"
Effort string `json:"effort,omitempty"` // "low" | "medium" | "high" | "max"
}
// AnthropicThinking configures extended thinking in the Anthropic API.
@@ -167,7 +167,7 @@ type ResponsesRequest struct {
// ResponsesReasoning configures reasoning effort in the Responses API.
type ResponsesReasoning struct {
Effort string `json:"effort"` // "low" | "medium" | "high"
Effort string `json:"effort"` // "low" | "medium" | "high" | "xhigh"
Summary string `json:"summary,omitempty"` // "auto" | "concise" | "detailed"
}
@@ -345,7 +345,7 @@ type ChatCompletionsRequest struct {
StreamOptions *ChatStreamOptions `json:"stream_options,omitempty"`
Tools []ChatTool `json:"tools,omitempty"`
ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
ReasoningEffort string `json:"reasoning_effort,omitempty"` // "low" | "medium" | "high"
ReasoningEffort string `json:"reasoning_effort,omitempty"` // "low" | "medium" | "high" | "xhigh"
ServiceTier string `json:"service_tier,omitempty"`
Stop json.RawMessage `json:"stop,omitempty"` // string or []string

View File

@@ -61,7 +61,8 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
SetAllowMessagesDispatch(groupIn.AllowMessagesDispatch).
SetRequireOauthOnly(groupIn.RequireOAuthOnly).
SetRequirePrivacySet(groupIn.RequirePrivacySet).
SetDefaultMappedModel(groupIn.DefaultMappedModel)
SetDefaultMappedModel(groupIn.DefaultMappedModel).
SetMessagesDispatchModelConfig(groupIn.MessagesDispatchModelConfig)
// 设置模型路由配置
if groupIn.ModelRouting != nil {
@@ -127,7 +128,8 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
SetAllowMessagesDispatch(groupIn.AllowMessagesDispatch).
SetRequireOauthOnly(groupIn.RequireOAuthOnly).
SetRequirePrivacySet(groupIn.RequirePrivacySet).
SetDefaultMappedModel(groupIn.DefaultMappedModel)
SetDefaultMappedModel(groupIn.DefaultMappedModel).
SetMessagesDispatchModelConfig(groupIn.MessagesDispatchModelConfig)
// 显式处理可空字段nil 需要 clear非 nil 需要 set。
if groupIn.DailyLimitUSD != nil {

View File

@@ -378,6 +378,7 @@ func buildSchedulerMetadataAccount(account service.Account) service.Account {
Platform: account.Platform,
Type: account.Type,
Concurrency: account.Concurrency,
LoadFactor: account.LoadFactor,
Priority: account.Priority,
RateMultiplier: account.RateMultiplier,
Status: account.Status,

View File

@@ -462,6 +462,28 @@ func TestAPIContracts(t *testing.T) {
service.SettingKeyTurnstileSiteKey: "site-key",
service.SettingKeyTurnstileSecretKey: "secret-key",
service.SettingKeyOIDCConnectEnabled: "false",
service.SettingKeyOIDCConnectProviderName: "OIDC",
service.SettingKeyOIDCConnectClientID: "",
service.SettingKeyOIDCConnectIssuerURL: "",
service.SettingKeyOIDCConnectDiscoveryURL: "",
service.SettingKeyOIDCConnectAuthorizeURL: "",
service.SettingKeyOIDCConnectTokenURL: "",
service.SettingKeyOIDCConnectUserInfoURL: "",
service.SettingKeyOIDCConnectJWKSURL: "",
service.SettingKeyOIDCConnectScopes: "openid email profile",
service.SettingKeyOIDCConnectRedirectURL: "",
service.SettingKeyOIDCConnectFrontendRedirectURL: "/auth/oidc/callback",
service.SettingKeyOIDCConnectTokenAuthMethod: "client_secret_post",
service.SettingKeyOIDCConnectUsePKCE: "false",
service.SettingKeyOIDCConnectValidateIDToken: "true",
service.SettingKeyOIDCConnectAllowedSigningAlgs: "RS256,ES256,PS256",
service.SettingKeyOIDCConnectClockSkewSeconds: "120",
service.SettingKeyOIDCConnectRequireEmailVerified: "false",
service.SettingKeyOIDCConnectUserInfoEmailPath: "",
service.SettingKeyOIDCConnectUserInfoIDPath: "",
service.SettingKeyOIDCConnectUserInfoUsernamePath: "",
service.SettingKeySiteName: "Sub2API",
service.SettingKeySiteLogo: "",
service.SettingKeySiteSubtitle: "Subtitle",
@@ -503,10 +525,32 @@ func TestAPIContracts(t *testing.T) {
"turnstile_enabled": true,
"turnstile_site_key": "site-key",
"turnstile_secret_key_configured": true,
"linuxdo_connect_enabled": false,
"linuxdo_connect_enabled": false,
"linuxdo_connect_client_id": "",
"linuxdo_connect_client_secret_configured": false,
"linuxdo_connect_redirect_url": "",
"oidc_connect_enabled": false,
"oidc_connect_provider_name": "OIDC",
"oidc_connect_client_id": "",
"oidc_connect_client_secret_configured": false,
"oidc_connect_issuer_url": "",
"oidc_connect_discovery_url": "",
"oidc_connect_authorize_url": "",
"oidc_connect_token_url": "",
"oidc_connect_userinfo_url": "",
"oidc_connect_jwks_url": "",
"oidc_connect_scopes": "openid email profile",
"oidc_connect_redirect_url": "",
"oidc_connect_frontend_redirect_url": "/auth/oidc/callback",
"oidc_connect_token_auth_method": "client_secret_post",
"oidc_connect_use_pkce": false,
"oidc_connect_validate_id_token": true,
"oidc_connect_allowed_signing_algs": "RS256,ES256,PS256",
"oidc_connect_clock_skew_seconds": 120,
"oidc_connect_require_email_verified": false,
"oidc_connect_userinfo_email_path": "",
"oidc_connect_userinfo_id_path": "",
"oidc_connect_userinfo_username_path": "",
"ops_monitoring_enabled": false,
"ops_realtime_monitoring_enabled": true,
"ops_query_mode_default": "auto",

View File

@@ -70,6 +70,14 @@ func RegisterAuthRoutes(
}),
h.Auth.CompleteLinuxDoOAuthRegistration,
)
auth.GET("/oauth/oidc/start", h.Auth.OIDCOAuthStart)
auth.GET("/oauth/oidc/callback", h.Auth.OIDCOAuthCallback)
auth.POST("/oauth/oidc/complete-registration",
rateLimiter.LimitWithOptions("oauth-oidc-complete", 10, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose,
}),
h.Auth.CompleteOIDCOAuthRegistration,
)
}
// 公开设置(无需认证)

View File

@@ -152,10 +152,11 @@ type CreateGroupInput struct {
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes []string
// OpenAI Messages 调度配置(仅 openai 平台使用)
AllowMessagesDispatch bool
DefaultMappedModel string
RequireOAuthOnly bool
RequirePrivacySet bool
AllowMessagesDispatch bool
DefaultMappedModel string
RequireOAuthOnly bool
RequirePrivacySet bool
MessagesDispatchModelConfig OpenAIMessagesDispatchModelConfig
// 从指定分组复制账号(创建分组后在同一事务内绑定)
CopyAccountsFromGroupIDs []int64
}
@@ -186,10 +187,11 @@ type UpdateGroupInput struct {
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes *[]string
// OpenAI Messages 调度配置(仅 openai 平台使用)
AllowMessagesDispatch *bool
DefaultMappedModel *string
RequireOAuthOnly *bool
RequirePrivacySet *bool
AllowMessagesDispatch *bool
DefaultMappedModel *string
RequireOAuthOnly *bool
RequirePrivacySet *bool
MessagesDispatchModelConfig *OpenAIMessagesDispatchModelConfig
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
CopyAccountsFromGroupIDs []int64
}
@@ -908,7 +910,9 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
RequireOAuthOnly: input.RequireOAuthOnly,
RequirePrivacySet: input.RequirePrivacySet,
DefaultMappedModel: input.DefaultMappedModel,
MessagesDispatchModelConfig: normalizeOpenAIMessagesDispatchModelConfig(input.MessagesDispatchModelConfig),
}
sanitizeGroupMessagesDispatchFields(group)
if err := s.groupRepo.Create(ctx, group); err != nil {
return nil, err
}
@@ -1135,6 +1139,10 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
if input.DefaultMappedModel != nil {
group.DefaultMappedModel = *input.DefaultMappedModel
}
if input.MessagesDispatchModelConfig != nil {
group.MessagesDispatchModelConfig = normalizeOpenAIMessagesDispatchModelConfig(*input.MessagesDispatchModelConfig)
}
sanitizeGroupMessagesDispatchFields(group)
if err := s.groupRepo.Update(ctx, group); err != nil {
return nil, err

View File

@@ -10,6 +10,11 @@ import (
"github.com/stretchr/testify/require"
)
func ptrString[T ~string](v T) *string {
s := string(v)
return &s
}
// groupRepoStubForAdmin 用于测试 AdminService 的 GroupRepository Stub
type groupRepoStubForAdmin struct {
created *Group // 记录 Create 调用的参数
@@ -261,6 +266,116 @@ func TestAdminService_UpdateGroup_PartialImagePricing(t *testing.T) {
require.Nil(t, repo.updated.ImagePrice4K)
}
func TestAdminService_CreateGroup_NormalizesMessagesDispatchModelConfig(t *testing.T) {
repo := &groupRepoStubForAdmin{}
svc := &adminServiceImpl{groupRepo: repo}
group, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
Name: "dispatch-group",
Description: "dispatch config",
Platform: PlatformOpenAI,
RateMultiplier: 1.0,
MessagesDispatchModelConfig: OpenAIMessagesDispatchModelConfig{
OpusMappedModel: " gpt-5.4-high ",
SonnetMappedModel: " gpt-5.3-codex ",
HaikuMappedModel: " gpt-5.4-mini-medium ",
ExactModelMappings: map[string]string{
" claude-sonnet-4-5-20250929 ": " gpt-5.2-high ",
},
},
})
require.NoError(t, err)
require.NotNil(t, group)
require.NotNil(t, repo.created)
require.Equal(t, OpenAIMessagesDispatchModelConfig{
OpusMappedModel: "gpt-5.4",
SonnetMappedModel: "gpt-5.3-codex",
HaikuMappedModel: "gpt-5.4-mini",
ExactModelMappings: map[string]string{
"claude-sonnet-4-5-20250929": "gpt-5.2",
},
}, repo.created.MessagesDispatchModelConfig)
}
func TestAdminService_UpdateGroup_NormalizesMessagesDispatchModelConfig(t *testing.T) {
existingGroup := &Group{
ID: 1,
Name: "existing-group",
Platform: PlatformOpenAI,
Status: StatusActive,
}
repo := &groupRepoStubForAdmin{getByID: existingGroup}
svc := &adminServiceImpl{groupRepo: repo}
group, err := svc.UpdateGroup(context.Background(), 1, &UpdateGroupInput{
MessagesDispatchModelConfig: &OpenAIMessagesDispatchModelConfig{
SonnetMappedModel: " gpt-5.4-medium ",
ExactModelMappings: map[string]string{
" claude-haiku-4-5-20251001 ": " gpt-5.4-mini-high ",
},
},
})
require.NoError(t, err)
require.NotNil(t, group)
require.NotNil(t, repo.updated)
require.Equal(t, OpenAIMessagesDispatchModelConfig{
SonnetMappedModel: "gpt-5.4",
ExactModelMappings: map[string]string{
"claude-haiku-4-5-20251001": "gpt-5.4-mini",
},
}, repo.updated.MessagesDispatchModelConfig)
}
func TestAdminService_CreateGroup_ClearsMessagesDispatchFieldsForNonOpenAIPlatform(t *testing.T) {
repo := &groupRepoStubForAdmin{}
svc := &adminServiceImpl{groupRepo: repo}
group, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
Name: "anthropic-group",
Description: "non-openai",
Platform: PlatformAnthropic,
RateMultiplier: 1.0,
AllowMessagesDispatch: true,
DefaultMappedModel: "gpt-5.4",
MessagesDispatchModelConfig: OpenAIMessagesDispatchModelConfig{
OpusMappedModel: "gpt-5.4",
},
})
require.NoError(t, err)
require.NotNil(t, group)
require.NotNil(t, repo.created)
require.False(t, repo.created.AllowMessagesDispatch)
require.Empty(t, repo.created.DefaultMappedModel)
require.Equal(t, OpenAIMessagesDispatchModelConfig{}, repo.created.MessagesDispatchModelConfig)
}
func TestAdminService_UpdateGroup_ClearsMessagesDispatchFieldsWhenPlatformChangesAwayFromOpenAI(t *testing.T) {
existingGroup := &Group{
ID: 1,
Name: "existing-openai-group",
Platform: PlatformOpenAI,
Status: StatusActive,
AllowMessagesDispatch: true,
DefaultMappedModel: "gpt-5.4",
MessagesDispatchModelConfig: OpenAIMessagesDispatchModelConfig{
SonnetMappedModel: "gpt-5.3-codex",
},
}
repo := &groupRepoStubForAdmin{getByID: existingGroup}
svc := &adminServiceImpl{groupRepo: repo}
group, err := svc.UpdateGroup(context.Background(), 1, &UpdateGroupInput{
Platform: PlatformAnthropic,
})
require.NoError(t, err)
require.NotNil(t, group)
require.NotNil(t, repo.updated)
require.Equal(t, PlatformAnthropic, repo.updated.Platform)
require.False(t, repo.updated.AllowMessagesDispatch)
require.Empty(t, repo.updated.DefaultMappedModel)
require.Equal(t, OpenAIMessagesDispatchModelConfig{}, repo.updated.MessagesDispatchModelConfig)
}
func TestAdminService_ListGroups_WithSearch(t *testing.T) {
// 测试:
// 1. search 参数正常传递到 repository 层

View File

@@ -833,7 +833,8 @@ func randomHexString(byteLength int) (string, error) {
func isReservedEmail(email string) bool {
normalized := strings.ToLower(strings.TrimSpace(email))
return strings.HasSuffix(normalized, LinuxDoConnectSyntheticEmailDomain)
return strings.HasSuffix(normalized, LinuxDoConnectSyntheticEmailDomain) ||
strings.HasSuffix(normalized, OIDCConnectSyntheticEmailDomain)
}
// GenerateToken 生成JWT access token

View File

@@ -71,6 +71,9 @@ const (
// LinuxDoConnectSyntheticEmailDomain 是 LinuxDo Connect 用户的合成邮箱后缀RFC 保留域名)。
const LinuxDoConnectSyntheticEmailDomain = "@linuxdo-connect.invalid"
// OIDCConnectSyntheticEmailDomain 是 OIDC 用户的合成邮箱后缀RFC 保留域名)。
const OIDCConnectSyntheticEmailDomain = "@oidc-connect.invalid"
// Setting keys
const (
// 注册设置
@@ -105,6 +108,30 @@ const (
SettingKeyLinuxDoConnectClientSecret = "linuxdo_connect_client_secret"
SettingKeyLinuxDoConnectRedirectURL = "linuxdo_connect_redirect_url"
// Generic OIDC OAuth 登录设置
SettingKeyOIDCConnectEnabled = "oidc_connect_enabled"
SettingKeyOIDCConnectProviderName = "oidc_connect_provider_name"
SettingKeyOIDCConnectClientID = "oidc_connect_client_id"
SettingKeyOIDCConnectClientSecret = "oidc_connect_client_secret"
SettingKeyOIDCConnectIssuerURL = "oidc_connect_issuer_url"
SettingKeyOIDCConnectDiscoveryURL = "oidc_connect_discovery_url"
SettingKeyOIDCConnectAuthorizeURL = "oidc_connect_authorize_url"
SettingKeyOIDCConnectTokenURL = "oidc_connect_token_url"
SettingKeyOIDCConnectUserInfoURL = "oidc_connect_userinfo_url"
SettingKeyOIDCConnectJWKSURL = "oidc_connect_jwks_url"
SettingKeyOIDCConnectScopes = "oidc_connect_scopes"
SettingKeyOIDCConnectRedirectURL = "oidc_connect_redirect_url"
SettingKeyOIDCConnectFrontendRedirectURL = "oidc_connect_frontend_redirect_url"
SettingKeyOIDCConnectTokenAuthMethod = "oidc_connect_token_auth_method"
SettingKeyOIDCConnectUsePKCE = "oidc_connect_use_pkce"
SettingKeyOIDCConnectValidateIDToken = "oidc_connect_validate_id_token"
SettingKeyOIDCConnectAllowedSigningAlgs = "oidc_connect_allowed_signing_algs"
SettingKeyOIDCConnectClockSkewSeconds = "oidc_connect_clock_skew_seconds"
SettingKeyOIDCConnectRequireEmailVerified = "oidc_connect_require_email_verified"
SettingKeyOIDCConnectUserInfoEmailPath = "oidc_connect_userinfo_email_path"
SettingKeyOIDCConnectUserInfoIDPath = "oidc_connect_userinfo_id_path"
SettingKeyOIDCConnectUserInfoUsernamePath = "oidc_connect_userinfo_username_path"
// OEM设置
SettingKeySiteName = "site_name" // 网站名称
SettingKeySiteLogo = "site_logo" // 网站Logo (base64)

View File

@@ -3,8 +3,12 @@ package service
import (
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/domain"
)
type OpenAIMessagesDispatchModelConfig = domain.OpenAIMessagesDispatchModelConfig
type Group struct {
ID int64
Name string
@@ -49,10 +53,11 @@ type Group struct {
SortOrder int
// OpenAI Messages 调度配置(仅 openai 平台使用)
AllowMessagesDispatch bool
RequireOAuthOnly bool // 仅允许非 apikey 类型账号关联OpenAI/Antigravity/Anthropic/Gemini
RequirePrivacySet bool // 调度时仅允许 privacy 已成功设置的账号OpenAI/Antigravity/Anthropic/Gemini
DefaultMappedModel string
AllowMessagesDispatch bool
RequireOAuthOnly bool // 仅允许非 apikey 类型账号关联OpenAI/Antigravity/Anthropic/Gemini
RequirePrivacySet bool // 调度时仅允许 privacy 已成功设置的账号OpenAI/Antigravity/Anthropic/Gemini
DefaultMappedModel string
MessagesDispatchModelConfig OpenAIMessagesDispatchModelConfig
CreatedAt time.Time
UpdatedAt time.Time

View File

@@ -0,0 +1,55 @@
package service
import (
"bytes"
"fmt"
"strings"
"text/template"
)
type forcedCodexInstructionsTemplateData struct {
ExistingInstructions string
OriginalModel string
NormalizedModel string
BillingModel string
UpstreamModel string
}
func applyForcedCodexInstructionsTemplate(
reqBody map[string]any,
templateText string,
data forcedCodexInstructionsTemplateData,
) (bool, error) {
rendered, err := renderForcedCodexInstructionsTemplate(templateText, data)
if err != nil {
return false, err
}
if rendered == "" {
return false, nil
}
existing, _ := reqBody["instructions"].(string)
if strings.TrimSpace(existing) == rendered {
return false, nil
}
reqBody["instructions"] = rendered
return true, nil
}
func renderForcedCodexInstructionsTemplate(
templateText string,
data forcedCodexInstructionsTemplateData,
) (string, error) {
tmpl, err := template.New("forced_codex_instructions").Option("missingkey=zero").Parse(templateText)
if err != nil {
return "", fmt.Errorf("parse forced codex instructions template: %w", err)
}
var buf bytes.Buffer
if err := tmpl.Execute(&buf, data); err != nil {
return "", fmt.Errorf("render forced codex instructions template: %w", err)
}
return strings.TrimSpace(buf.String()), nil
}

View File

@@ -6,9 +6,12 @@ import (
"io"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
@@ -127,3 +130,101 @@ func TestForwardAsAnthropic_NormalizesRoutingAndEffortForGpt54XHigh(t *testing.T
t.Logf("upstream body: %s", string(upstream.lastBody))
t.Logf("response body: %s", rec.Body.String())
}
func TestForwardAsAnthropic_ForcedCodexInstructionsTemplatePrependsRenderedInstructions(t *testing.T) {
t.Parallel()
gin.SetMode(gin.TestMode)
templateDir := t.TempDir()
templatePath := filepath.Join(templateDir, "codex-instructions.md.tmpl")
require.NoError(t, os.WriteFile(templatePath, []byte("server-prefix\n\n{{ .ExistingInstructions }}"), 0o644))
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
body := []byte(`{"model":"gpt-5.4","max_tokens":16,"system":"client-system","messages":[{"role":"user","content":"hello"}],"stream":false}`)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
upstreamBody := strings.Join([]string{
`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":5,"output_tokens":2,"total_tokens":7}}}`,
"",
"data: [DONE]",
"",
}, "\n")
upstream := &httpUpstreamRecorder{resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_forced"}},
Body: io.NopCloser(strings.NewReader(upstreamBody)),
}}
svc := &OpenAIGatewayService{
cfg: &config.Config{Gateway: config.GatewayConfig{
ForcedCodexInstructionsTemplateFile: templatePath,
ForcedCodexInstructionsTemplate: "server-prefix\n\n{{ .ExistingInstructions }}",
}},
httpUpstream: upstream,
}
account := &Account{
ID: 1,
Name: "openai-oauth",
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "oauth-token",
"chatgpt_account_id": "chatgpt-acc",
},
}
result, err := svc.ForwardAsAnthropic(context.Background(), c, account, body, "", "gpt-5.1")
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, "server-prefix\n\nclient-system", gjson.GetBytes(upstream.lastBody, "instructions").String())
}
func TestForwardAsAnthropic_ForcedCodexInstructionsTemplateUsesCachedTemplateContent(t *testing.T) {
t.Parallel()
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
body := []byte(`{"model":"gpt-5.4","max_tokens":16,"system":"client-system","messages":[{"role":"user","content":"hello"}],"stream":false}`)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
upstreamBody := strings.Join([]string{
`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":5,"output_tokens":2,"total_tokens":7}}}`,
"",
"data: [DONE]",
"",
}, "\n")
upstream := &httpUpstreamRecorder{resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_forced_cached"}},
Body: io.NopCloser(strings.NewReader(upstreamBody)),
}}
svc := &OpenAIGatewayService{
cfg: &config.Config{Gateway: config.GatewayConfig{
ForcedCodexInstructionsTemplateFile: "/path/that/should/not/be/read.tmpl",
ForcedCodexInstructionsTemplate: "cached-prefix\n\n{{ .ExistingInstructions }}",
}},
httpUpstream: upstream,
}
account := &Account{
ID: 1,
Name: "openai-oauth",
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "oauth-token",
"chatgpt_account_id": "chatgpt-acc",
},
}
result, err := svc.ForwardAsAnthropic(context.Background(), c, account, body, "", "gpt-5.1")
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, "cached-prefix\n\nclient-system", gjson.GetBytes(upstream.lastBody, "instructions").String())
}

View File

@@ -86,6 +86,24 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
return nil, fmt.Errorf("unmarshal for codex transform: %w", err)
}
codexResult := applyCodexOAuthTransform(reqBody, false, false)
forcedTemplateText := ""
if s.cfg != nil {
forcedTemplateText = s.cfg.Gateway.ForcedCodexInstructionsTemplate
}
templateUpstreamModel := upstreamModel
if codexResult.NormalizedModel != "" {
templateUpstreamModel = codexResult.NormalizedModel
}
existingInstructions, _ := reqBody["instructions"].(string)
if _, err := applyForcedCodexInstructionsTemplate(reqBody, forcedTemplateText, forcedCodexInstructionsTemplateData{
ExistingInstructions: strings.TrimSpace(existingInstructions),
OriginalModel: originalModel,
NormalizedModel: normalizedModel,
BillingModel: billingModel,
UpstreamModel: templateUpstreamModel,
}); err != nil {
return nil, err
}
if codexResult.NormalizedModel != "" {
upstreamModel = codexResult.NormalizedModel
}

View File

@@ -0,0 +1,100 @@
package service
import "strings"
const (
defaultOpenAIMessagesDispatchOpusMappedModel = "gpt-5.4"
defaultOpenAIMessagesDispatchSonnetMappedModel = "gpt-5.3-codex"
defaultOpenAIMessagesDispatchHaikuMappedModel = "gpt-5.4-mini"
)
func normalizeOpenAIMessagesDispatchMappedModel(model string) string {
model = NormalizeOpenAICompatRequestedModel(strings.TrimSpace(model))
return strings.TrimSpace(model)
}
func normalizeOpenAIMessagesDispatchModelConfig(cfg OpenAIMessagesDispatchModelConfig) OpenAIMessagesDispatchModelConfig {
out := OpenAIMessagesDispatchModelConfig{
OpusMappedModel: normalizeOpenAIMessagesDispatchMappedModel(cfg.OpusMappedModel),
SonnetMappedModel: normalizeOpenAIMessagesDispatchMappedModel(cfg.SonnetMappedModel),
HaikuMappedModel: normalizeOpenAIMessagesDispatchMappedModel(cfg.HaikuMappedModel),
}
if len(cfg.ExactModelMappings) > 0 {
out.ExactModelMappings = make(map[string]string, len(cfg.ExactModelMappings))
for requestedModel, mappedModel := range cfg.ExactModelMappings {
requestedModel = strings.TrimSpace(requestedModel)
mappedModel = normalizeOpenAIMessagesDispatchMappedModel(mappedModel)
if requestedModel == "" || mappedModel == "" {
continue
}
out.ExactModelMappings[requestedModel] = mappedModel
}
if len(out.ExactModelMappings) == 0 {
out.ExactModelMappings = nil
}
}
return out
}
func claudeMessagesDispatchFamily(model string) string {
normalized := strings.ToLower(strings.TrimSpace(model))
if !strings.HasPrefix(normalized, "claude") {
return ""
}
switch {
case strings.Contains(normalized, "opus"):
return "opus"
case strings.Contains(normalized, "sonnet"):
return "sonnet"
case strings.Contains(normalized, "haiku"):
return "haiku"
default:
return ""
}
}
func (g *Group) ResolveMessagesDispatchModel(requestedModel string) string {
if g == nil {
return ""
}
requestedModel = strings.TrimSpace(requestedModel)
if requestedModel == "" {
return ""
}
cfg := normalizeOpenAIMessagesDispatchModelConfig(g.MessagesDispatchModelConfig)
if mappedModel := strings.TrimSpace(cfg.ExactModelMappings[requestedModel]); mappedModel != "" {
return mappedModel
}
switch claudeMessagesDispatchFamily(requestedModel) {
case "opus":
if mappedModel := strings.TrimSpace(cfg.OpusMappedModel); mappedModel != "" {
return mappedModel
}
return defaultOpenAIMessagesDispatchOpusMappedModel
case "sonnet":
if mappedModel := strings.TrimSpace(cfg.SonnetMappedModel); mappedModel != "" {
return mappedModel
}
return defaultOpenAIMessagesDispatchSonnetMappedModel
case "haiku":
if mappedModel := strings.TrimSpace(cfg.HaikuMappedModel); mappedModel != "" {
return mappedModel
}
return defaultOpenAIMessagesDispatchHaikuMappedModel
default:
return ""
}
}
func sanitizeGroupMessagesDispatchFields(g *Group) {
if g == nil || g.Platform == PlatformOpenAI {
return
}
g.AllowMessagesDispatch = false
g.DefaultMappedModel = ""
g.MessagesDispatchModelConfig = OpenAIMessagesDispatchModelConfig{}
}

View File

@@ -0,0 +1,27 @@
package service
import "testing"
import "github.com/stretchr/testify/require"
func TestNormalizeOpenAIMessagesDispatchModelConfig(t *testing.T) {
t.Parallel()
cfg := normalizeOpenAIMessagesDispatchModelConfig(OpenAIMessagesDispatchModelConfig{
OpusMappedModel: " gpt-5.4-high ",
SonnetMappedModel: "gpt-5.3-codex",
HaikuMappedModel: " gpt-5.4-mini-medium ",
ExactModelMappings: map[string]string{
" claude-sonnet-4-5-20250929 ": " gpt-5.2-high ",
"": "gpt-5.4",
"claude-opus-4-6": " ",
},
})
require.Equal(t, "gpt-5.4", cfg.OpusMappedModel)
require.Equal(t, "gpt-5.3-codex", cfg.SonnetMappedModel)
require.Equal(t, "gpt-5.4-mini", cfg.HaikuMappedModel)
require.Equal(t, map[string]string{
"claude-sonnet-4-5-20250929": "gpt-5.2",
}, cfg.ExactModelMappings)
}

View File

@@ -16,7 +16,7 @@ import (
var ErrOpsDisabled = infraerrors.NotFound("OPS_DISABLED", "Ops monitoring is disabled")
const (
opsMaxStoredRequestBodyBytes = 10 * 1024
opsMaxStoredRequestBodyBytes = 256 * 1024
opsMaxStoredErrorBodyBytes = 20 * 1024
)

View File

@@ -17,6 +17,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/imroc/req/v3"
"golang.org/x/sync/singleflight"
)
@@ -167,6 +168,8 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
SettingKeyCustomEndpoints,
SettingKeyLinuxDoConnectEnabled,
SettingKeyBackendModeEnabled,
SettingKeyOIDCConnectEnabled,
SettingKeyOIDCConnectProviderName,
}
settings, err := s.settingRepo.GetMultiple(ctx, keys)
@@ -180,6 +183,19 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
} else {
linuxDoEnabled = s.cfg != nil && s.cfg.LinuxDo.Enabled
}
oidcEnabled := false
if raw, ok := settings[SettingKeyOIDCConnectEnabled]; ok {
oidcEnabled = raw == "true"
} else {
oidcEnabled = s.cfg != nil && s.cfg.OIDC.Enabled
}
oidcProviderName := strings.TrimSpace(settings[SettingKeyOIDCConnectProviderName])
if oidcProviderName == "" && s.cfg != nil {
oidcProviderName = strings.TrimSpace(s.cfg.OIDC.ProviderName)
}
if oidcProviderName == "" {
oidcProviderName = "OIDC"
}
// Password reset requires email verification to be enabled
emailVerifyEnabled := settings[SettingKeyEmailVerifyEnabled] == "true"
@@ -218,6 +234,8 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
CustomEndpoints: settings[SettingKeyCustomEndpoints],
LinuxDoOAuthEnabled: linuxDoEnabled,
BackendModeEnabled: settings[SettingKeyBackendModeEnabled] == "true",
OIDCOAuthEnabled: oidcEnabled,
OIDCOAuthProviderName: oidcProviderName,
}, nil
}
@@ -267,6 +285,8 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
CustomEndpoints json.RawMessage `json:"custom_endpoints"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
BackendModeEnabled bool `json:"backend_mode_enabled"`
OIDCOAuthEnabled bool `json:"oidc_oauth_enabled"`
OIDCOAuthProviderName string `json:"oidc_oauth_provider_name"`
Version string `json:"version,omitempty"`
}{
RegistrationEnabled: settings.RegistrationEnabled,
@@ -294,6 +314,8 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
CustomEndpoints: safeRawJSONArray(settings.CustomEndpoints),
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
BackendModeEnabled: settings.BackendModeEnabled,
OIDCOAuthEnabled: settings.OIDCOAuthEnabled,
OIDCOAuthProviderName: settings.OIDCOAuthProviderName,
Version: s.version,
}, nil
}
@@ -346,8 +368,8 @@ func safeRawJSONArray(raw string) json.RawMessage {
return json.RawMessage("[]")
}
// GetFrameSrcOrigins returns deduplicated http(s) origins from purchase_subscription_url
// and all custom_menu_items URLs. Used by the router layer for CSP frame-src injection.
// GetFrameSrcOrigins returns deduplicated http(s) origins from home_content URL,
// purchase_subscription_url, and all custom_menu_items URLs. Used by the router layer for CSP frame-src injection.
func (s *SettingService) GetFrameSrcOrigins(ctx context.Context) ([]string, error) {
settings, err := s.GetPublicSettings(ctx)
if err != nil {
@@ -366,6 +388,9 @@ func (s *SettingService) GetFrameSrcOrigins(ctx context.Context) ([]string, erro
}
}
// home content URL (when home_content is set to a URL for iframe embedding)
addOrigin(settings.HomeContent)
// purchase subscription URL
if settings.PurchaseSubscriptionEnabled {
addOrigin(settings.PurchaseSubscriptionURL)
@@ -473,6 +498,32 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates[SettingKeyLinuxDoConnectClientSecret] = settings.LinuxDoConnectClientSecret
}
// Generic OIDC OAuth 登录
updates[SettingKeyOIDCConnectEnabled] = strconv.FormatBool(settings.OIDCConnectEnabled)
updates[SettingKeyOIDCConnectProviderName] = settings.OIDCConnectProviderName
updates[SettingKeyOIDCConnectClientID] = settings.OIDCConnectClientID
updates[SettingKeyOIDCConnectIssuerURL] = settings.OIDCConnectIssuerURL
updates[SettingKeyOIDCConnectDiscoveryURL] = settings.OIDCConnectDiscoveryURL
updates[SettingKeyOIDCConnectAuthorizeURL] = settings.OIDCConnectAuthorizeURL
updates[SettingKeyOIDCConnectTokenURL] = settings.OIDCConnectTokenURL
updates[SettingKeyOIDCConnectUserInfoURL] = settings.OIDCConnectUserInfoURL
updates[SettingKeyOIDCConnectJWKSURL] = settings.OIDCConnectJWKSURL
updates[SettingKeyOIDCConnectScopes] = settings.OIDCConnectScopes
updates[SettingKeyOIDCConnectRedirectURL] = settings.OIDCConnectRedirectURL
updates[SettingKeyOIDCConnectFrontendRedirectURL] = settings.OIDCConnectFrontendRedirectURL
updates[SettingKeyOIDCConnectTokenAuthMethod] = settings.OIDCConnectTokenAuthMethod
updates[SettingKeyOIDCConnectUsePKCE] = strconv.FormatBool(settings.OIDCConnectUsePKCE)
updates[SettingKeyOIDCConnectValidateIDToken] = strconv.FormatBool(settings.OIDCConnectValidateIDToken)
updates[SettingKeyOIDCConnectAllowedSigningAlgs] = settings.OIDCConnectAllowedSigningAlgs
updates[SettingKeyOIDCConnectClockSkewSeconds] = strconv.Itoa(settings.OIDCConnectClockSkewSeconds)
updates[SettingKeyOIDCConnectRequireEmailVerified] = strconv.FormatBool(settings.OIDCConnectRequireEmailVerified)
updates[SettingKeyOIDCConnectUserInfoEmailPath] = settings.OIDCConnectUserInfoEmailPath
updates[SettingKeyOIDCConnectUserInfoIDPath] = settings.OIDCConnectUserInfoIDPath
updates[SettingKeyOIDCConnectUserInfoUsernamePath] = settings.OIDCConnectUserInfoUsernamePath
if settings.OIDCConnectClientSecret != "" {
updates[SettingKeyOIDCConnectClientSecret] = settings.OIDCConnectClientSecret
}
// OEM设置
updates[SettingKeySiteName] = settings.SiteName
updates[SettingKeySiteLogo] = settings.SiteLogo
@@ -851,6 +902,8 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
SettingKeyTablePageSizeOptions: "[10,20,50,100]",
SettingKeyCustomMenuItems: "[]",
SettingKeyCustomEndpoints: "[]",
SettingKeyOIDCConnectEnabled: "false",
SettingKeyOIDCConnectProviderName: "OIDC",
SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency),
SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64),
SettingKeyDefaultSubscriptions: "[]",
@@ -980,6 +1033,138 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
}
result.LinuxDoConnectClientSecretConfigured = result.LinuxDoConnectClientSecret != ""
// Generic OIDC 设置:
// - 兼容 config.yaml/env
// - 支持后台系统设置覆盖并持久化(存储于 DB
oidcBase := config.OIDCConnectConfig{}
if s.cfg != nil {
oidcBase = s.cfg.OIDC
}
if raw, ok := settings[SettingKeyOIDCConnectEnabled]; ok {
result.OIDCConnectEnabled = raw == "true"
} else {
result.OIDCConnectEnabled = oidcBase.Enabled
}
if v, ok := settings[SettingKeyOIDCConnectProviderName]; ok && strings.TrimSpace(v) != "" {
result.OIDCConnectProviderName = strings.TrimSpace(v)
} else {
result.OIDCConnectProviderName = strings.TrimSpace(oidcBase.ProviderName)
}
if result.OIDCConnectProviderName == "" {
result.OIDCConnectProviderName = "OIDC"
}
if v, ok := settings[SettingKeyOIDCConnectClientID]; ok && strings.TrimSpace(v) != "" {
result.OIDCConnectClientID = strings.TrimSpace(v)
} else {
result.OIDCConnectClientID = strings.TrimSpace(oidcBase.ClientID)
}
if v, ok := settings[SettingKeyOIDCConnectIssuerURL]; ok && strings.TrimSpace(v) != "" {
result.OIDCConnectIssuerURL = strings.TrimSpace(v)
} else {
result.OIDCConnectIssuerURL = strings.TrimSpace(oidcBase.IssuerURL)
}
if v, ok := settings[SettingKeyOIDCConnectDiscoveryURL]; ok && strings.TrimSpace(v) != "" {
result.OIDCConnectDiscoveryURL = strings.TrimSpace(v)
} else {
result.OIDCConnectDiscoveryURL = strings.TrimSpace(oidcBase.DiscoveryURL)
}
if v, ok := settings[SettingKeyOIDCConnectAuthorizeURL]; ok && strings.TrimSpace(v) != "" {
result.OIDCConnectAuthorizeURL = strings.TrimSpace(v)
} else {
result.OIDCConnectAuthorizeURL = strings.TrimSpace(oidcBase.AuthorizeURL)
}
if v, ok := settings[SettingKeyOIDCConnectTokenURL]; ok && strings.TrimSpace(v) != "" {
result.OIDCConnectTokenURL = strings.TrimSpace(v)
} else {
result.OIDCConnectTokenURL = strings.TrimSpace(oidcBase.TokenURL)
}
if v, ok := settings[SettingKeyOIDCConnectUserInfoURL]; ok && strings.TrimSpace(v) != "" {
result.OIDCConnectUserInfoURL = strings.TrimSpace(v)
} else {
result.OIDCConnectUserInfoURL = strings.TrimSpace(oidcBase.UserInfoURL)
}
if v, ok := settings[SettingKeyOIDCConnectJWKSURL]; ok && strings.TrimSpace(v) != "" {
result.OIDCConnectJWKSURL = strings.TrimSpace(v)
} else {
result.OIDCConnectJWKSURL = strings.TrimSpace(oidcBase.JWKSURL)
}
if v, ok := settings[SettingKeyOIDCConnectScopes]; ok && strings.TrimSpace(v) != "" {
result.OIDCConnectScopes = strings.TrimSpace(v)
} else {
result.OIDCConnectScopes = strings.TrimSpace(oidcBase.Scopes)
}
if v, ok := settings[SettingKeyOIDCConnectRedirectURL]; ok && strings.TrimSpace(v) != "" {
result.OIDCConnectRedirectURL = strings.TrimSpace(v)
} else {
result.OIDCConnectRedirectURL = strings.TrimSpace(oidcBase.RedirectURL)
}
if v, ok := settings[SettingKeyOIDCConnectFrontendRedirectURL]; ok && strings.TrimSpace(v) != "" {
result.OIDCConnectFrontendRedirectURL = strings.TrimSpace(v)
} else {
result.OIDCConnectFrontendRedirectURL = strings.TrimSpace(oidcBase.FrontendRedirectURL)
}
if v, ok := settings[SettingKeyOIDCConnectTokenAuthMethod]; ok && strings.TrimSpace(v) != "" {
result.OIDCConnectTokenAuthMethod = strings.ToLower(strings.TrimSpace(v))
} else {
result.OIDCConnectTokenAuthMethod = strings.ToLower(strings.TrimSpace(oidcBase.TokenAuthMethod))
}
if raw, ok := settings[SettingKeyOIDCConnectUsePKCE]; ok {
result.OIDCConnectUsePKCE = raw == "true"
} else {
result.OIDCConnectUsePKCE = oidcBase.UsePKCE
}
if raw, ok := settings[SettingKeyOIDCConnectValidateIDToken]; ok {
result.OIDCConnectValidateIDToken = raw == "true"
} else {
result.OIDCConnectValidateIDToken = oidcBase.ValidateIDToken
}
if v, ok := settings[SettingKeyOIDCConnectAllowedSigningAlgs]; ok && strings.TrimSpace(v) != "" {
result.OIDCConnectAllowedSigningAlgs = strings.TrimSpace(v)
} else {
result.OIDCConnectAllowedSigningAlgs = strings.TrimSpace(oidcBase.AllowedSigningAlgs)
}
clockSkewSet := false
if raw, ok := settings[SettingKeyOIDCConnectClockSkewSeconds]; ok && strings.TrimSpace(raw) != "" {
if parsed, err := strconv.Atoi(strings.TrimSpace(raw)); err == nil {
result.OIDCConnectClockSkewSeconds = parsed
clockSkewSet = true
}
}
if !clockSkewSet {
result.OIDCConnectClockSkewSeconds = oidcBase.ClockSkewSeconds
}
if !clockSkewSet && result.OIDCConnectClockSkewSeconds == 0 {
result.OIDCConnectClockSkewSeconds = 120
}
if raw, ok := settings[SettingKeyOIDCConnectRequireEmailVerified]; ok {
result.OIDCConnectRequireEmailVerified = raw == "true"
} else {
result.OIDCConnectRequireEmailVerified = oidcBase.RequireEmailVerified
}
if v, ok := settings[SettingKeyOIDCConnectUserInfoEmailPath]; ok {
result.OIDCConnectUserInfoEmailPath = strings.TrimSpace(v)
} else {
result.OIDCConnectUserInfoEmailPath = strings.TrimSpace(oidcBase.UserInfoEmailPath)
}
if v, ok := settings[SettingKeyOIDCConnectUserInfoIDPath]; ok {
result.OIDCConnectUserInfoIDPath = strings.TrimSpace(v)
} else {
result.OIDCConnectUserInfoIDPath = strings.TrimSpace(oidcBase.UserInfoIDPath)
}
if v, ok := settings[SettingKeyOIDCConnectUserInfoUsernamePath]; ok {
result.OIDCConnectUserInfoUsernamePath = strings.TrimSpace(v)
} else {
result.OIDCConnectUserInfoUsernamePath = strings.TrimSpace(oidcBase.UserInfoUsernamePath)
}
result.OIDCConnectClientSecret = strings.TrimSpace(settings[SettingKeyOIDCConnectClientSecret])
if result.OIDCConnectClientSecret == "" {
result.OIDCConnectClientSecret = strings.TrimSpace(oidcBase.ClientSecret)
}
result.OIDCConnectClientSecretConfigured = result.OIDCConnectClientSecret != ""
// Model fallback settings
result.EnableModelFallback = settings[SettingKeyEnableModelFallback] == "true"
result.FallbackModelAnthropic = s.getStringOrDefault(settings, SettingKeyFallbackModelAnthropic, "claude-3-5-sonnet-20241022")
@@ -1396,6 +1581,282 @@ func (s *SettingService) SetOverloadCooldownSettings(ctx context.Context, settin
return s.settingRepo.Set(ctx, SettingKeyOverloadCooldownSettings, string(data))
}
// GetOIDCConnectOAuthConfig 返回用于登录的“最终生效” OIDC 配置。
//
// 优先级:
// - 若对应系统设置键存在,则覆盖 config.yaml/env 的值
// - 否则回退到 config.yaml/env 的值
func (s *SettingService) GetOIDCConnectOAuthConfig(ctx context.Context) (config.OIDCConnectConfig, error) {
if s == nil || s.cfg == nil {
return config.OIDCConnectConfig{}, infraerrors.ServiceUnavailable("CONFIG_NOT_READY", "config not loaded")
}
effective := s.cfg.OIDC
keys := []string{
SettingKeyOIDCConnectEnabled,
SettingKeyOIDCConnectProviderName,
SettingKeyOIDCConnectClientID,
SettingKeyOIDCConnectClientSecret,
SettingKeyOIDCConnectIssuerURL,
SettingKeyOIDCConnectDiscoveryURL,
SettingKeyOIDCConnectAuthorizeURL,
SettingKeyOIDCConnectTokenURL,
SettingKeyOIDCConnectUserInfoURL,
SettingKeyOIDCConnectJWKSURL,
SettingKeyOIDCConnectScopes,
SettingKeyOIDCConnectRedirectURL,
SettingKeyOIDCConnectFrontendRedirectURL,
SettingKeyOIDCConnectTokenAuthMethod,
SettingKeyOIDCConnectUsePKCE,
SettingKeyOIDCConnectValidateIDToken,
SettingKeyOIDCConnectAllowedSigningAlgs,
SettingKeyOIDCConnectClockSkewSeconds,
SettingKeyOIDCConnectRequireEmailVerified,
SettingKeyOIDCConnectUserInfoEmailPath,
SettingKeyOIDCConnectUserInfoIDPath,
SettingKeyOIDCConnectUserInfoUsernamePath,
}
settings, err := s.settingRepo.GetMultiple(ctx, keys)
if err != nil {
return config.OIDCConnectConfig{}, fmt.Errorf("get oidc connect settings: %w", err)
}
if raw, ok := settings[SettingKeyOIDCConnectEnabled]; ok {
effective.Enabled = raw == "true"
}
if v, ok := settings[SettingKeyOIDCConnectProviderName]; ok && strings.TrimSpace(v) != "" {
effective.ProviderName = strings.TrimSpace(v)
}
if v, ok := settings[SettingKeyOIDCConnectClientID]; ok && strings.TrimSpace(v) != "" {
effective.ClientID = strings.TrimSpace(v)
}
if v, ok := settings[SettingKeyOIDCConnectClientSecret]; ok && strings.TrimSpace(v) != "" {
effective.ClientSecret = strings.TrimSpace(v)
}
if v, ok := settings[SettingKeyOIDCConnectIssuerURL]; ok && strings.TrimSpace(v) != "" {
effective.IssuerURL = strings.TrimSpace(v)
}
if v, ok := settings[SettingKeyOIDCConnectDiscoveryURL]; ok && strings.TrimSpace(v) != "" {
effective.DiscoveryURL = strings.TrimSpace(v)
}
if v, ok := settings[SettingKeyOIDCConnectAuthorizeURL]; ok && strings.TrimSpace(v) != "" {
effective.AuthorizeURL = strings.TrimSpace(v)
}
if v, ok := settings[SettingKeyOIDCConnectTokenURL]; ok && strings.TrimSpace(v) != "" {
effective.TokenURL = strings.TrimSpace(v)
}
if v, ok := settings[SettingKeyOIDCConnectUserInfoURL]; ok && strings.TrimSpace(v) != "" {
effective.UserInfoURL = strings.TrimSpace(v)
}
if v, ok := settings[SettingKeyOIDCConnectJWKSURL]; ok && strings.TrimSpace(v) != "" {
effective.JWKSURL = strings.TrimSpace(v)
}
if v, ok := settings[SettingKeyOIDCConnectScopes]; ok && strings.TrimSpace(v) != "" {
effective.Scopes = strings.TrimSpace(v)
}
if v, ok := settings[SettingKeyOIDCConnectRedirectURL]; ok && strings.TrimSpace(v) != "" {
effective.RedirectURL = strings.TrimSpace(v)
}
if v, ok := settings[SettingKeyOIDCConnectFrontendRedirectURL]; ok && strings.TrimSpace(v) != "" {
effective.FrontendRedirectURL = strings.TrimSpace(v)
}
if v, ok := settings[SettingKeyOIDCConnectTokenAuthMethod]; ok && strings.TrimSpace(v) != "" {
effective.TokenAuthMethod = strings.ToLower(strings.TrimSpace(v))
}
if raw, ok := settings[SettingKeyOIDCConnectUsePKCE]; ok {
effective.UsePKCE = raw == "true"
}
if raw, ok := settings[SettingKeyOIDCConnectValidateIDToken]; ok {
effective.ValidateIDToken = raw == "true"
}
if v, ok := settings[SettingKeyOIDCConnectAllowedSigningAlgs]; ok && strings.TrimSpace(v) != "" {
effective.AllowedSigningAlgs = strings.TrimSpace(v)
}
if raw, ok := settings[SettingKeyOIDCConnectClockSkewSeconds]; ok && strings.TrimSpace(raw) != "" {
if parsed, parseErr := strconv.Atoi(strings.TrimSpace(raw)); parseErr == nil {
effective.ClockSkewSeconds = parsed
}
}
if raw, ok := settings[SettingKeyOIDCConnectRequireEmailVerified]; ok {
effective.RequireEmailVerified = raw == "true"
}
if v, ok := settings[SettingKeyOIDCConnectUserInfoEmailPath]; ok {
effective.UserInfoEmailPath = strings.TrimSpace(v)
}
if v, ok := settings[SettingKeyOIDCConnectUserInfoIDPath]; ok {
effective.UserInfoIDPath = strings.TrimSpace(v)
}
if v, ok := settings[SettingKeyOIDCConnectUserInfoUsernamePath]; ok {
effective.UserInfoUsernamePath = strings.TrimSpace(v)
}
if !effective.Enabled {
return config.OIDCConnectConfig{}, infraerrors.NotFound("OAUTH_DISABLED", "oauth login is disabled")
}
if strings.TrimSpace(effective.ProviderName) == "" {
effective.ProviderName = "OIDC"
}
if strings.TrimSpace(effective.ClientID) == "" {
return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth client id not configured")
}
if strings.TrimSpace(effective.IssuerURL) == "" {
return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth issuer url not configured")
}
if strings.TrimSpace(effective.RedirectURL) == "" {
return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth redirect url not configured")
}
if strings.TrimSpace(effective.FrontendRedirectURL) == "" {
return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth frontend redirect url not configured")
}
if !scopesContainOpenID(effective.Scopes) {
return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth scopes must contain openid")
}
if effective.ClockSkewSeconds < 0 || effective.ClockSkewSeconds > 600 {
return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth clock skew must be between 0 and 600")
}
if err := config.ValidateAbsoluteHTTPURL(effective.IssuerURL); err != nil {
return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth issuer url invalid")
}
discoveryURL := strings.TrimSpace(effective.DiscoveryURL)
if discoveryURL == "" {
discoveryURL = oidcDefaultDiscoveryURL(effective.IssuerURL)
effective.DiscoveryURL = discoveryURL
}
if discoveryURL != "" {
if err := config.ValidateAbsoluteHTTPURL(discoveryURL); err != nil {
return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth discovery url invalid")
}
}
needsDiscovery := strings.TrimSpace(effective.AuthorizeURL) == "" ||
strings.TrimSpace(effective.TokenURL) == "" ||
(effective.ValidateIDToken && strings.TrimSpace(effective.JWKSURL) == "")
if needsDiscovery && discoveryURL != "" {
metadata, resolveErr := oidcResolveProviderMetadata(ctx, discoveryURL)
if resolveErr != nil {
return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth discovery resolve failed").WithCause(resolveErr)
}
if strings.TrimSpace(effective.AuthorizeURL) == "" {
effective.AuthorizeURL = strings.TrimSpace(metadata.AuthorizationEndpoint)
}
if strings.TrimSpace(effective.TokenURL) == "" {
effective.TokenURL = strings.TrimSpace(metadata.TokenEndpoint)
}
if strings.TrimSpace(effective.UserInfoURL) == "" {
effective.UserInfoURL = strings.TrimSpace(metadata.UserInfoEndpoint)
}
if strings.TrimSpace(effective.JWKSURL) == "" {
effective.JWKSURL = strings.TrimSpace(metadata.JWKSURI)
}
}
if strings.TrimSpace(effective.AuthorizeURL) == "" {
return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth authorize url not configured")
}
if strings.TrimSpace(effective.TokenURL) == "" {
return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth token url not configured")
}
if err := config.ValidateAbsoluteHTTPURL(effective.AuthorizeURL); err != nil {
return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth authorize url invalid")
}
if err := config.ValidateAbsoluteHTTPURL(effective.TokenURL); err != nil {
return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth token url invalid")
}
if v := strings.TrimSpace(effective.UserInfoURL); v != "" {
if err := config.ValidateAbsoluteHTTPURL(v); err != nil {
return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth userinfo url invalid")
}
}
if effective.ValidateIDToken {
if strings.TrimSpace(effective.JWKSURL) == "" {
return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth jwks url not configured")
}
if strings.TrimSpace(effective.AllowedSigningAlgs) == "" {
return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth signing algs not configured")
}
}
if v := strings.TrimSpace(effective.JWKSURL); v != "" {
if err := config.ValidateAbsoluteHTTPURL(v); err != nil {
return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth jwks url invalid")
}
}
if err := config.ValidateAbsoluteHTTPURL(effective.RedirectURL); err != nil {
return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth redirect url invalid")
}
if err := config.ValidateFrontendRedirectURL(effective.FrontendRedirectURL); err != nil {
return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth frontend redirect url invalid")
}
method := strings.ToLower(strings.TrimSpace(effective.TokenAuthMethod))
switch method {
case "", "client_secret_post", "client_secret_basic":
if strings.TrimSpace(effective.ClientSecret) == "" {
return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth client secret not configured")
}
case "none":
if !effective.UsePKCE {
return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth pkce must be enabled when token_auth_method=none")
}
default:
return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth token_auth_method invalid")
}
return effective, nil
}
func scopesContainOpenID(scopes string) bool {
for _, scope := range strings.Fields(strings.ToLower(strings.TrimSpace(scopes))) {
if scope == "openid" {
return true
}
}
return false
}
type oidcProviderMetadata struct {
AuthorizationEndpoint string `json:"authorization_endpoint"`
TokenEndpoint string `json:"token_endpoint"`
UserInfoEndpoint string `json:"userinfo_endpoint"`
JWKSURI string `json:"jwks_uri"`
}
func oidcDefaultDiscoveryURL(issuerURL string) string {
issuerURL = strings.TrimSpace(issuerURL)
if issuerURL == "" {
return ""
}
return strings.TrimRight(issuerURL, "/") + "/.well-known/openid-configuration"
}
func oidcResolveProviderMetadata(ctx context.Context, discoveryURL string) (*oidcProviderMetadata, error) {
discoveryURL = strings.TrimSpace(discoveryURL)
if discoveryURL == "" {
return nil, fmt.Errorf("discovery url is empty")
}
resp, err := req.C().
SetTimeout(15*time.Second).
R().
SetContext(ctx).
SetHeader("Accept", "application/json").
Get(discoveryURL)
if err != nil {
return nil, fmt.Errorf("request discovery document: %w", err)
}
if !resp.IsSuccessState() {
return nil, fmt.Errorf("discovery request failed: status=%d", resp.StatusCode)
}
metadata := &oidcProviderMetadata{}
if err := json.Unmarshal(resp.Bytes(), metadata); err != nil {
return nil, fmt.Errorf("parse discovery document: %w", err)
}
return metadata, nil
}
// GetStreamTimeoutSettings 获取流超时处理配置
func (s *SettingService) GetStreamTimeoutSettings(ctx context.Context) (*StreamTimeoutSettings, error) {
value, err := s.settingRepo.GetValue(ctx, SettingKeyStreamTimeoutSettings)

View File

@@ -0,0 +1,103 @@
//go:build unit
package service
import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
type settingOIDCRepoStub struct {
values map[string]string
}
func (s *settingOIDCRepoStub) Get(ctx context.Context, key string) (*Setting, error) {
panic("unexpected Get call")
}
func (s *settingOIDCRepoStub) GetValue(ctx context.Context, key string) (string, error) {
panic("unexpected GetValue call")
}
func (s *settingOIDCRepoStub) Set(ctx context.Context, key, value string) error {
panic("unexpected Set call")
}
func (s *settingOIDCRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
out := make(map[string]string, len(keys))
for _, key := range keys {
if value, ok := s.values[key]; ok {
out[key] = value
}
}
return out, nil
}
func (s *settingOIDCRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
panic("unexpected SetMultiple call")
}
func (s *settingOIDCRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
panic("unexpected GetAll call")
}
func (s *settingOIDCRepoStub) Delete(ctx context.Context, key string) error {
panic("unexpected Delete call")
}
func TestGetOIDCConnectOAuthConfig_ResolvesEndpointsFromIssuerDiscovery(t *testing.T) {
var discoveryHits int
var baseURL string
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/issuer/.well-known/openid-configuration" {
http.NotFound(w, r)
return
}
discoveryHits++
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(fmt.Sprintf(`{
"authorization_endpoint":"%s/issuer/protocol/openid-connect/auth",
"token_endpoint":"%s/issuer/protocol/openid-connect/token",
"userinfo_endpoint":"%s/issuer/protocol/openid-connect/userinfo",
"jwks_uri":"%s/issuer/protocol/openid-connect/certs"
}`, baseURL, baseURL, baseURL, baseURL)))
}))
defer srv.Close()
baseURL = srv.URL
cfg := &config.Config{
OIDC: config.OIDCConnectConfig{
Enabled: true,
ProviderName: "OIDC",
ClientID: "oidc-client",
ClientSecret: "oidc-secret",
IssuerURL: srv.URL + "/issuer",
RedirectURL: "https://example.com/api/v1/auth/oauth/oidc/callback",
FrontendRedirectURL: "/auth/oidc/callback",
Scopes: "openid email profile",
TokenAuthMethod: "client_secret_post",
ValidateIDToken: true,
AllowedSigningAlgs: "RS256",
ClockSkewSeconds: 120,
},
}
repo := &settingOIDCRepoStub{values: map[string]string{}}
svc := NewSettingService(repo, cfg)
got, err := svc.GetOIDCConnectOAuthConfig(context.Background())
require.NoError(t, err)
require.Equal(t, 1, discoveryHits)
require.Equal(t, srv.URL+"/issuer/.well-known/openid-configuration", got.DiscoveryURL)
require.Equal(t, srv.URL+"/issuer/protocol/openid-connect/auth", got.AuthorizeURL)
require.Equal(t, srv.URL+"/issuer/protocol/openid-connect/token", got.TokenURL)
require.Equal(t, srv.URL+"/issuer/protocol/openid-connect/userinfo", got.UserInfoURL)
require.Equal(t, srv.URL+"/issuer/protocol/openid-connect/certs", got.JWKSURL)
}

View File

@@ -31,6 +31,31 @@ type SystemSettings struct {
LinuxDoConnectClientSecretConfigured bool
LinuxDoConnectRedirectURL string
// Generic OIDC OAuth 登录
OIDCConnectEnabled bool
OIDCConnectProviderName string
OIDCConnectClientID string
OIDCConnectClientSecret string
OIDCConnectClientSecretConfigured bool
OIDCConnectIssuerURL string
OIDCConnectDiscoveryURL string
OIDCConnectAuthorizeURL string
OIDCConnectTokenURL string
OIDCConnectUserInfoURL string
OIDCConnectJWKSURL string
OIDCConnectScopes string
OIDCConnectRedirectURL string
OIDCConnectFrontendRedirectURL string
OIDCConnectTokenAuthMethod string
OIDCConnectUsePKCE bool
OIDCConnectValidateIDToken bool
OIDCConnectAllowedSigningAlgs string
OIDCConnectClockSkewSeconds int
OIDCConnectRequireEmailVerified bool
OIDCConnectUserInfoEmailPath string
OIDCConnectUserInfoIDPath string
OIDCConnectUserInfoUsernamePath string
SiteName string
SiteLogo string
SiteSubtitle string
@@ -114,9 +139,11 @@ type PublicSettings struct {
CustomMenuItems string // JSON array of custom menu items
CustomEndpoints string // JSON array of custom endpoints
LinuxDoOAuthEnabled bool
BackendModeEnabled bool
Version string
LinuxDoOAuthEnabled bool
BackendModeEnabled bool
OIDCOAuthEnabled bool
OIDCOAuthProviderName string
Version string
}
// StreamTimeoutSettings 流超时处理配置(仅控制超时后的处理方式,超时判定由网关配置控制)