Files
sub2api/backend/internal/handler/auth_linuxdo_oauth.go
admin d1c2a61d19 refactor(auth): 将 Linux DO OAuth 配置迁移到系统设置
- 将 LinuxDo Connect 配置从环境变量迁移到数据库持久化
- 在管理后台系统设置中添加 LinuxDo OAuth 配置项
- 简化部署流程,无需修改 docker-compose.override.yml

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-09 18:26:32 +08:00

666 lines
19 KiB
Go

package handler
import (
"context"
"encoding/base64"
"errors"
"fmt"
"log"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"unicode/utf8"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/gin-gonic/gin"
"github.com/imroc/req/v3"
"github.com/tidwall/gjson"
)
// Cookie names, defaults and size limits shared by the LinuxDo Connect OAuth handlers in this file.
const (
// linuxDoOAuthCookiePath is the Path attribute applied by setCookie/clearCookie to every flow cookie.
linuxDoOAuthCookiePath = "/api/v1/auth/oauth/linuxdo"
// linuxDoOAuthStateCookieName stores the encoded OAuth state that the callback compares with the "state" query parameter.
linuxDoOAuthStateCookieName = "linuxdo_oauth_state"
// linuxDoOAuthVerifierCookie stores the encoded PKCE code verifier; it is only written when PKCE is enabled.
linuxDoOAuthVerifierCookie = "linuxdo_oauth_verifier"
// linuxDoOAuthRedirectCookie stores the encoded frontend path to return to after login.
linuxDoOAuthRedirectCookie = "linuxdo_oauth_redirect"
// linuxDoOAuthCookieMaxAgeSec is the MaxAge (in seconds) given to the flow cookies.
linuxDoOAuthCookieMaxAgeSec = 10 * 60 // 10 minutes
// linuxDoOAuthDefaultRedirectTo is used when no acceptable redirect path was supplied, and as the
// fallback location when the frontend callback URL cannot be used.
linuxDoOAuthDefaultRedirectTo = "/dashboard"
// linuxDoOAuthDefaultFrontendCB is the frontend callback used when the config leaves FrontendRedirectURL empty.
linuxDoOAuthDefaultFrontendCB = "/auth/linuxdo/callback"
// linuxDoOAuthMaxRedirectLen is the maximum byte length of a redirect path accepted by sanitizeFrontendRedirectPath.
linuxDoOAuthMaxRedirectLen = 2048
// linuxDoOAuthMaxFragmentValueLen is the maximum byte length of a single URL-fragment value (see truncateFragmentValue).
linuxDoOAuthMaxFragmentValueLen = 512
// linuxDoOAuthMaxSubjectLen is the maximum byte length of a provider subject accepted by isSafeLinuxDoSubject.
// The subject is embedded as "linuxdo-<subject>" in the synthetic email local part.
// NOTE(review): 64 is presumably the maximum email local-part length — confirm.
linuxDoOAuthMaxSubjectLen = 64 - len("linuxdo-")
)
// linuxDoTokenResponse holds the fields extracted from the provider's token endpoint response
// (see parseLinuxDoTokenResponse, which accepts both JSON and form-encoded bodies).
type linuxDoTokenResponse struct {
// AccessToken is the credential sent in the Authorization header of the userinfo request.
AccessToken string `json:"access_token"`
// TokenType is the token scheme; linuxDoExchangeCode defaults it to "Bearer" when the provider omits it.
TokenType string `json:"token_type"`
// ExpiresIn is the provider's "expires_in" value.
// NOTE(review): presumably a lifetime in seconds, as is usual for OAuth — confirm.
ExpiresIn int64 `json:"expires_in"`
// RefreshToken is the optional "refresh_token" value; it is parsed but not read elsewhere in the visible code.
RefreshToken string `json:"refresh_token,omitempty"`
// Scope is the optional "scope" value; it is parsed but not read elsewhere in the visible code.
Scope string `json:"scope,omitempty"`
}
// linuxDoTokenExchangeError describes a failed (or unparseable) token endpoint response.
type linuxDoTokenExchangeError struct {
	// StatusCode is the HTTP status returned by the token endpoint.
	StatusCode int
	// ProviderError is the error code extracted from the response body, if any.
	ProviderError string
	// ProviderDescription is the human-readable error text extracted from the response body, if any.
	ProviderDescription string
	// Body is the trimmed raw response body; it is kept for logging and is not part of Error().
	Body string
}

// Error renders the status code followed by the provider error and description
// when they are non-blank, separated by single spaces. A nil receiver yields "".
func (e *linuxDoTokenExchangeError) Error() string {
	if e == nil {
		return ""
	}
	msg := fmt.Sprintf("token exchange status=%d", e.StatusCode)
	if providerErr := strings.TrimSpace(e.ProviderError); providerErr != "" {
		msg += " error=" + providerErr
	}
	if providerDesc := strings.TrimSpace(e.ProviderDescription); providerDesc != "" {
		msg += " error_description=" + providerDesc
	}
	return msg
}
// LinuxDoOAuthStart begins the LinuxDo Connect OAuth login flow.
// GET /api/v1/auth/oauth/linuxdo/start?redirect=/dashboard
//
// It stores the generated state, the sanitized post-login redirect path and
// (when PKCE is enabled) the code verifier in short-lived cookies, then
// redirects the browser to the provider's authorize URL.
func (h *AuthHandler) LinuxDoOAuthStart(c *gin.Context) {
	cfg, err := h.getLinuxDoOAuthConfig(c.Request.Context())
	if err != nil {
		response.ErrorFrom(c, err)
		return
	}

	state, err := oauth.GenerateState()
	if err != nil {
		response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_STATE_GEN_FAILED", "failed to generate oauth state").WithCause(err))
		return
	}

	// Fall back to the default landing page when the caller supplied no acceptable path.
	target := sanitizeFrontendRedirectPath(c.Query("redirect"))
	if target == "" {
		target = linuxDoOAuthDefaultRedirectTo
	}

	secure := isRequestHTTPS(c)
	setCookie(c, linuxDoOAuthStateCookieName, encodeCookieValue(state), linuxDoOAuthCookieMaxAgeSec, secure)
	setCookie(c, linuxDoOAuthRedirectCookie, encodeCookieValue(target), linuxDoOAuthCookieMaxAgeSec, secure)

	// The challenge stays empty unless PKCE is enabled; the verifier is kept in a cookie for the callback.
	var challenge string
	if cfg.UsePKCE {
		verifier, genErr := oauth.GenerateCodeVerifier()
		if genErr != nil {
			response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_PKCE_GEN_FAILED", "failed to generate pkce verifier").WithCause(genErr))
			return
		}
		challenge = oauth.GenerateCodeChallenge(verifier)
		setCookie(c, linuxDoOAuthVerifierCookie, encodeCookieValue(verifier), linuxDoOAuthCookieMaxAgeSec, secure)
	}

	callbackURL := strings.TrimSpace(cfg.RedirectURL)
	if callbackURL == "" {
		response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth redirect url not configured"))
		return
	}

	location, err := buildLinuxDoAuthorizeURL(cfg, state, challenge, callbackURL)
	if err != nil {
		response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_BUILD_URL_FAILED", "failed to build oauth authorization url").WithCause(err))
		return
	}
	c.Redirect(http.StatusFound, location)
}
// LinuxDoOAuthCallback handles the OAuth callback, creates/logins the user, then redirects to frontend.
// GET /api/v1/auth/oauth/linuxdo/callback?code=...&state=...
//
// Every outcome after the configuration is loaded is reported by redirecting
// to the frontend callback URL with the result in the URL fragment: on success
// "access_token", "token_type" and "redirect"; on failure "error" plus optional
// "error_message" / "error_description".
func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
	cfg, cfgErr := h.getLinuxDoOAuthConfig(c.Request.Context())
	if cfgErr != nil {
		response.ErrorFrom(c, cfgErr)
		return
	}

	frontendCallback := strings.TrimSpace(cfg.FrontendRedirectURL)
	if frontendCallback == "" {
		frontendCallback = linuxDoOAuthDefaultFrontendCB
	}

	// The provider reported an error (e.g. the user denied access).
	if providerErr := strings.TrimSpace(c.Query("error")); providerErr != "" {
		redirectOAuthError(c, frontendCallback, "provider_error", providerErr, c.Query("error_description"))
		return
	}

	code := strings.TrimSpace(c.Query("code"))
	state := strings.TrimSpace(c.Query("state"))
	if code == "" || state == "" {
		redirectOAuthError(c, frontendCallback, "missing_params", "missing code/state", "")
		return
	}

	// Expire the one-shot flow cookies now, before any redirect is written.
	// c.Redirect delegates to net/http.Redirect, which writes a body for GET
	// requests and thereby commits the response headers; Set-Cookie headers
	// added afterwards (for example from a defer) would never reach the
	// browser. The request cookies read below are unaffected by this.
	secureCookie := isRequestHTTPS(c)
	clearCookie(c, linuxDoOAuthStateCookieName, secureCookie)
	clearCookie(c, linuxDoOAuthVerifierCookie, secureCookie)
	clearCookie(c, linuxDoOAuthRedirectCookie, secureCookie)

	// The state echoed by the provider must match the one stored at flow start.
	expectedState, err := readCookieDecoded(c, linuxDoOAuthStateCookieName)
	if err != nil || expectedState == "" || state != expectedState {
		redirectOAuthError(c, frontendCallback, "invalid_state", "invalid oauth state", "")
		return
	}

	// A missing or undecodable redirect cookie is not fatal: the default target is used instead.
	redirectTo, _ := readCookieDecoded(c, linuxDoOAuthRedirectCookie)
	redirectTo = sanitizeFrontendRedirectPath(redirectTo)
	if redirectTo == "" {
		redirectTo = linuxDoOAuthDefaultRedirectTo
	}

	codeVerifier := ""
	if cfg.UsePKCE {
		// A read error leaves the verifier empty, which is rejected just below.
		codeVerifier, _ = readCookieDecoded(c, linuxDoOAuthVerifierCookie)
		if codeVerifier == "" {
			redirectOAuthError(c, frontendCallback, "missing_verifier", "missing pkce verifier", "")
			return
		}
	}

	redirectURI := strings.TrimSpace(cfg.RedirectURL)
	if redirectURI == "" {
		redirectOAuthError(c, frontendCallback, "config_error", "oauth redirect url not configured", "")
		return
	}

	tokenResp, err := linuxDoExchangeCode(c.Request.Context(), cfg, code, redirectURI, codeVerifier)
	if err != nil {
		description := ""
		var exchangeErr *linuxDoTokenExchangeError
		if errors.As(err, &exchangeErr) && exchangeErr != nil {
			// The (truncated) response body is logged only; the client receives the short summary.
			log.Printf(
				"[LinuxDo OAuth] token exchange failed: status=%d provider_error=%q provider_description=%q body=%s",
				exchangeErr.StatusCode,
				exchangeErr.ProviderError,
				exchangeErr.ProviderDescription,
				truncateLogValue(exchangeErr.Body, 2048),
			)
			description = exchangeErr.Error()
		} else {
			log.Printf("[LinuxDo OAuth] token exchange failed: %v", err)
			description = err.Error()
		}
		redirectOAuthError(c, frontendCallback, "token_exchange_failed", "failed to exchange oauth code", singleLine(description))
		return
	}

	email, username, _, err := linuxDoFetchUserInfo(c.Request.Context(), cfg, tokenResp)
	if err != nil {
		log.Printf("[LinuxDo OAuth] userinfo fetch failed: %v", err)
		redirectOAuthError(c, frontendCallback, "userinfo_failed", "failed to fetch user info", "")
		return
	}

	jwtToken, _, err := h.authService.LoginOrRegisterOAuth(c.Request.Context(), email, username)
	if err != nil {
		// Avoid leaking internal details to the client; keep structured reason for frontend.
		redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err))
		return
	}

	// Hand the session token to the frontend through the URL fragment.
	fragment := url.Values{}
	fragment.Set("access_token", jwtToken)
	fragment.Set("token_type", "Bearer")
	fragment.Set("redirect", redirectTo)
	redirectWithFragment(c, frontendCallback, fragment)
}
// getLinuxDoOAuthConfig resolves the LinuxDo Connect OAuth configuration.
//
// The settings service is preferred when it is wired in. Otherwise the static
// config is used: a missing handler or config yields CONFIG_NOT_READY, and a
// config with LinuxDo disabled yields OAUTH_DISABLED.
func (h *AuthHandler) getLinuxDoOAuthConfig(ctx context.Context) (config.LinuxDoConnectConfig, error) {
	if h == nil {
		return config.LinuxDoConnectConfig{}, infraerrors.ServiceUnavailable("CONFIG_NOT_READY", "config not loaded")
	}
	if h.settingSvc != nil {
		return h.settingSvc.GetLinuxDoConnectOAuthConfig(ctx)
	}

	switch {
	case h.cfg == nil:
		return config.LinuxDoConnectConfig{}, infraerrors.ServiceUnavailable("CONFIG_NOT_READY", "config not loaded")
	case !h.cfg.LinuxDo.Enabled:
		return config.LinuxDoConnectConfig{}, infraerrors.NotFound("OAUTH_DISABLED", "oauth login is disabled")
	default:
		return h.cfg.LinuxDo, nil
	}
}
// linuxDoExchangeCode trades an authorization code for tokens at cfg.TokenURL.
//
// Client credentials are sent according to cfg.TokenAuthMethod: in the form
// body ("" or "client_secret_post"), via HTTP Basic auth
// ("client_secret_basic"), or not at all ("none"); any other value is an
// error. A non-success status, or a body without an access token, is returned
// as *linuxDoTokenExchangeError. A blank token type defaults to "Bearer".
func linuxDoExchangeCode(
	ctx context.Context,
	cfg config.LinuxDoConnectConfig,
	code string,
	redirectURI string,
	codeVerifier string,
) (*linuxDoTokenResponse, error) {
	params := url.Values{}
	params.Set("grant_type", "authorization_code")
	params.Set("client_id", cfg.ClientID)
	params.Set("code", code)
	params.Set("redirect_uri", redirectURI)
	if cfg.UsePKCE {
		params.Set("code_verifier", codeVerifier)
	}

	request := req.C().SetTimeout(30 * time.Second).R().
		SetContext(ctx).
		SetHeader("Accept", "application/json")

	authMethod := strings.ToLower(strings.TrimSpace(cfg.TokenAuthMethod))
	switch authMethod {
	case "none":
		// No client secret is sent.
	case "client_secret_basic":
		request.SetBasicAuth(cfg.ClientID, cfg.ClientSecret)
	case "", "client_secret_post":
		params.Set("client_secret", cfg.ClientSecret)
	default:
		return nil, fmt.Errorf("unsupported token_auth_method: %s", cfg.TokenAuthMethod)
	}

	resp, err := request.SetFormDataFromValues(params).Post(cfg.TokenURL)
	if err != nil {
		return nil, fmt.Errorf("request token: %w", err)
	}

	raw := strings.TrimSpace(resp.String())
	if !resp.IsSuccessState() {
		code, desc := parseOAuthProviderError(raw)
		return nil, &linuxDoTokenExchangeError{
			StatusCode:          resp.StatusCode,
			ProviderError:       code,
			ProviderDescription: desc,
			Body:                raw,
		}
	}

	// A success status with an unusable body is still reported as an exchange error.
	parsed, ok := parseLinuxDoTokenResponse(raw)
	if !ok || strings.TrimSpace(parsed.AccessToken) == "" {
		return nil, &linuxDoTokenExchangeError{
			StatusCode: resp.StatusCode,
			Body:       raw,
		}
	}
	if strings.TrimSpace(parsed.TokenType) == "" {
		parsed.TokenType = "Bearer"
	}
	return parsed, nil
}
// linuxDoFetchUserInfo calls cfg.UserInfoURL with the access token as a Bearer
// credential and returns the email, username and subject extracted from the
// response body by linuxDoParseUserInfo. Transport failures and non-success
// statuses are returned as errors.
func linuxDoFetchUserInfo(
	ctx context.Context,
	cfg config.LinuxDoConnectConfig,
	token *linuxDoTokenResponse,
) (email string, username string, subject string, err error) {
	httpClient := req.C().SetTimeout(30 * time.Second)

	authHeader, err := buildBearerAuthorization(token.TokenType, token.AccessToken)
	if err != nil {
		return "", "", "", fmt.Errorf("invalid token for userinfo request: %w", err)
	}

	request := httpClient.R().
		SetContext(ctx).
		SetHeader("Accept", "application/json").
		SetHeader("Authorization", authHeader)

	resp, err := request.Get(cfg.UserInfoURL)
	switch {
	case err != nil:
		return "", "", "", fmt.Errorf("request userinfo: %w", err)
	case !resp.IsSuccessState():
		return "", "", "", fmt.Errorf("userinfo status=%d", resp.StatusCode)
	}
	return linuxDoParseUserInfo(resp.String(), cfg)
}
// linuxDoParseUserInfo extracts email, username and subject from a userinfo
// JSON body. For each field the configured gjson path is tried first, followed
// by a list of common fallback keys.
//
// The subject is mandatory and must pass isSafeLinuxDoSubject. A missing email
// or username is replaced by a value derived from the subject.
func linuxDoParseUserInfo(body string, cfg config.LinuxDoConnectConfig) (email string, username string, subject string, err error) {
	subject = strings.TrimSpace(firstNonEmpty(
		getGJSON(body, cfg.UserInfoIDPath),
		getGJSON(body, "sub"),
		getGJSON(body, "id"),
		getGJSON(body, "user_id"),
		getGJSON(body, "uid"),
		getGJSON(body, "user.id"),
	))
	switch {
	case subject == "":
		return "", "", "", errors.New("userinfo missing id field")
	case !isSafeLinuxDoSubject(subject):
		return "", "", "", errors.New("userinfo returned invalid id field")
	}

	email = strings.TrimSpace(firstNonEmpty(
		getGJSON(body, cfg.UserInfoEmailPath),
		getGJSON(body, "email"),
		getGJSON(body, "user.email"),
		getGJSON(body, "data.email"),
		getGJSON(body, "attributes.email"),
	))
	if email == "" {
		// LinuxDo Connect userinfo does not necessarily provide email. To keep compatibility with the
		// existing user schema (email is required/unique), use a stable synthetic email.
		email = fmt.Sprintf("linuxdo-%s@linuxdo-connect.invalid", subject)
	}

	username = strings.TrimSpace(firstNonEmpty(
		getGJSON(body, cfg.UserInfoUsernamePath),
		getGJSON(body, "username"),
		getGJSON(body, "preferred_username"),
		getGJSON(body, "name"),
		getGJSON(body, "user.username"),
		getGJSON(body, "user.name"),
	))
	if username == "" {
		username = "linuxdo_" + subject
	}

	return email, username, subject, nil
}
// buildLinuxDoAuthorizeURL returns cfg.AuthorizeURL with the OAuth
// authorization-code parameters merged into its query string. The scope is
// added only when cfg.Scopes is non-blank, and the S256 code challenge only
// when PKCE is enabled.
func buildLinuxDoAuthorizeURL(cfg config.LinuxDoConnectConfig, state string, codeChallenge string, redirectURI string) (string, error) {
	authorize, err := url.Parse(cfg.AuthorizeURL)
	if err != nil {
		return "", fmt.Errorf("parse authorize_url: %w", err)
	}

	// Start from any query parameters already present on the configured URL.
	params := authorize.Query()
	params.Set("response_type", "code")
	params.Set("client_id", cfg.ClientID)
	params.Set("redirect_uri", redirectURI)
	params.Set("state", state)
	if scopes := cfg.Scopes; strings.TrimSpace(scopes) != "" {
		params.Set("scope", scopes)
	}
	if cfg.UsePKCE {
		params.Set("code_challenge_method", "S256")
		params.Set("code_challenge", codeChallenge)
	}

	authorize.RawQuery = params.Encode()
	return authorize.String(), nil
}
// redirectOAuthError redirects to the frontend callback with the failure
// encoded in the URL fragment: "error" always, "error_message" and
// "error_description" only when non-blank. Every value is length-limited by
// truncateFragmentValue.
func redirectOAuthError(c *gin.Context, frontendCallback string, code string, message string, description string) {
	params := url.Values{}
	params.Set("error", truncateFragmentValue(code))

	optional := [...]struct{ key, value string }{
		{"error_message", message},
		{"error_description", description},
	}
	for _, field := range optional {
		if strings.TrimSpace(field.value) == "" {
			continue
		}
		params.Set(field.key, truncateFragmentValue(field.value))
	}

	redirectWithFragment(c, frontendCallback, params)
}
// redirectWithFragment issues a 302 to frontendCallback with the given values
// encoded as the URL fragment and caching disabled. If the callback cannot be
// parsed, or carries a scheme other than http/https, it redirects to the
// default landing path instead (without the fragment).
func redirectWithFragment(c *gin.Context, frontendCallback string, fragment url.Values) {
	target, err := url.Parse(frontendCallback)
	usable := err == nil &&
		(target.Scheme == "" ||
			strings.EqualFold(target.Scheme, "http") ||
			strings.EqualFold(target.Scheme, "https"))
	if !usable {
		// Fallback: best-effort redirect.
		c.Redirect(http.StatusFound, linuxDoOAuthDefaultRedirectTo)
		return
	}

	target.Fragment = fragment.Encode()
	// The fragment may carry a token, so the response must not be cached.
	c.Header("Cache-Control", "no-store")
	c.Header("Pragma", "no-cache")
	c.Redirect(http.StatusFound, target.String())
}
// firstNonEmpty returns the first argument that is non-empty after trimming
// surrounding whitespace (the trimmed form is returned), or "" if there is none.
func firstNonEmpty(values ...string) string {
	for i := range values {
		if trimmed := strings.TrimSpace(values[i]); trimmed != "" {
			return trimmed
		}
	}
	return ""
}
// parseOAuthProviderError extracts an error code and description from a token
// endpoint error body. The body is first read as JSON using a few common key
// names; only when that yields nothing is it parsed as a form-encoded query
// string. Both results are empty when nothing could be extracted.
func parseOAuthProviderError(body string) (providerErr string, providerDesc string) {
	payload := strings.TrimSpace(body)
	if payload == "" {
		return "", ""
	}

	providerErr = firstNonEmpty(
		getGJSON(payload, "error"),
		getGJSON(payload, "code"),
		getGJSON(payload, "error.code"),
	)
	providerDesc = firstNonEmpty(
		getGJSON(payload, "error_description"),
		getGJSON(payload, "error.message"),
		getGJSON(payload, "message"),
		getGJSON(payload, "detail"),
	)

	if providerErr == "" && providerDesc == "" {
		// Nothing usable as JSON: fall back to form encoding.
		form, parseErr := url.ParseQuery(payload)
		if parseErr != nil {
			return "", ""
		}
		providerErr = firstNonEmpty(form.Get("error"), form.Get("code"))
		providerDesc = firstNonEmpty(form.Get("error_description"), form.Get("error_message"), form.Get("message"))
	}
	return providerErr, providerDesc
}
// parseLinuxDoTokenResponse reads a token endpoint body as JSON or, when no
// JSON "access_token" is found, as a form-encoded query string. The boolean is
// false when the body is empty, unparseable, or lacks an access token. String
// fields are whitespace-trimmed; an unparseable form "expires_in" is treated as 0.
func parseLinuxDoTokenResponse(body string) (*linuxDoTokenResponse, bool) {
	payload := strings.TrimSpace(body)
	if payload == "" {
		return nil, false
	}

	if jsonToken := strings.TrimSpace(getGJSON(payload, "access_token")); jsonToken != "" {
		return &linuxDoTokenResponse{
			AccessToken:  jsonToken,
			TokenType:    strings.TrimSpace(getGJSON(payload, "token_type")),
			ExpiresIn:    gjson.Get(payload, "expires_in").Int(),
			RefreshToken: strings.TrimSpace(getGJSON(payload, "refresh_token")),
			Scope:        strings.TrimSpace(getGJSON(payload, "scope")),
		}, true
	}

	form, err := url.ParseQuery(payload)
	if err != nil {
		return nil, false
	}
	formToken := strings.TrimSpace(form.Get("access_token"))
	if formToken == "" {
		return nil, false
	}

	// A blank or malformed value fails to parse and leaves the lifetime at zero.
	var lifetime int64
	if parsed, convErr := strconv.ParseInt(strings.TrimSpace(form.Get("expires_in")), 10, 64); convErr == nil {
		lifetime = parsed
	}

	return &linuxDoTokenResponse{
		AccessToken:  formToken,
		TokenType:    strings.TrimSpace(form.Get("token_type")),
		ExpiresIn:    lifetime,
		RefreshToken: strings.TrimSpace(form.Get("refresh_token")),
		Scope:        strings.TrimSpace(form.Get("scope")),
	}, true
}
// getGJSON looks up a gjson path in body and returns the match as a string.
// A blank path or a path with no match yields "".
func getGJSON(body string, path string) string {
	key := strings.TrimSpace(path)
	if key == "" {
		return ""
	}
	if match := gjson.Get(body, key); match.Exists() {
		return match.String()
	}
	return ""
}
// truncateLogValue trims surrounding whitespace and limits the result to at
// most maxLen bytes. After cutting, trailing bytes are dropped until the
// result is valid UTF-8, so a multi-byte character is never split. A
// non-positive maxLen yields "".
func truncateLogValue(value string, maxLen int) string {
	trimmed := strings.TrimSpace(value)
	switch {
	case trimmed == "", maxLen <= 0:
		return ""
	case len(trimmed) <= maxLen:
		return trimmed
	}

	clipped := trimmed[:maxLen]
	for ; !utf8.ValidString(clipped); clipped = clipped[:len(clipped)-1] {
	}
	return clipped
}
// singleLine collapses every run of whitespace (including newlines) into a
// single space and removes leading and trailing whitespace.
func singleLine(value string) string {
	// strings.Fields discards leading/trailing whitespace and returns no
	// words for blank input, so joining them covers the empty case too.
	words := strings.Fields(value)
	return strings.Join(words, " ")
}
// sanitizeFrontendRedirectPath validates a caller-supplied post-login redirect
// target. It returns the whitespace-trimmed path when it is an acceptable
// same-origin relative path and "" otherwise, so callers can substitute their
// default.
//
// A path is rejected when it is empty, longer than linuxDoOAuthMaxRedirectLen
// bytes, does not start with a single "/", contains "://", contains a
// backslash, or contains any ASCII control character.
func sanitizeFrontendRedirectPath(path string) string {
	path = strings.TrimSpace(path)
	if path == "" || len(path) > linuxDoOAuthMaxRedirectLen {
		return ""
	}
	// Only allow same-origin relative paths (avoid open redirect).
	if !strings.HasPrefix(path, "/") || strings.HasPrefix(path, "//") {
		return ""
	}
	if strings.Contains(path, "://") {
		return ""
	}
	// Browsers treat "\" like "/" and strip tab/CR/LF while parsing URLs, so
	// "/\host" or "/<TAB>/host" would be navigated as the protocol-relative
	// "//host". Reject backslashes and every ASCII control character
	// (including DEL) rather than only CR/LF.
	for i := 0; i < len(path); i++ {
		if ch := path[i]; ch == '\\' || ch < 0x20 || ch == 0x7f {
			return ""
		}
	}
	return path
}
// isRequestHTTPS reports whether the request arrived over TLS, either directly
// or as indicated by an X-Forwarded-Proto header equal to "https"
// (case-insensitive, surrounding whitespace ignored).
func isRequestHTTPS(c *gin.Context) bool {
	if c.Request.TLS == nil {
		forwarded := strings.TrimSpace(c.GetHeader("X-Forwarded-Proto"))
		return strings.ToLower(forwarded) == "https"
	}
	return true
}
// encodeCookieValue encodes value as unpadded URL-safe base64 so it can be
// stored in a cookie.
func encodeCookieValue(value string) string {
	raw := []byte(value)
	encoded := make([]byte, base64.RawURLEncoding.EncodedLen(len(raw)))
	base64.RawURLEncoding.Encode(encoded, raw)
	return string(encoded)
}
// decodeCookieValue reverses encodeCookieValue: it decodes unpadded URL-safe
// base64 and returns the result as a string, or "" and the decoding error.
func decodeCookieValue(value string) (string, error) {
	decoded, err := base64.RawURLEncoding.DecodeString(value)
	if err == nil {
		return string(decoded), nil
	}
	return "", err
}
// readCookieDecoded reads the named request cookie and returns its decoded
// value. It fails when the cookie is absent or its value cannot be decoded.
func readCookieDecoded(c *gin.Context, name string) (string, error) {
	cookie, err := c.Request.Cookie(name)
	if err == nil {
		return decodeCookieValue(cookie.Value)
	}
	return "", err
}
// setCookie writes an HttpOnly, SameSite=Lax cookie scoped to the OAuth route
// prefix, with the given value, lifetime in seconds and Secure flag.
func setCookie(c *gin.Context, name string, value string, maxAgeSec int, secure bool) {
	cookie := http.Cookie{
		Path:     linuxDoOAuthCookiePath,
		SameSite: http.SameSiteLaxMode,
		HttpOnly: true,
	}
	cookie.Name = name
	cookie.Value = value
	cookie.MaxAge = maxAgeSec
	cookie.Secure = secure
	http.SetCookie(c.Writer, &cookie)
}
// clearCookie asks the browser to delete the named OAuth cookie by writing an
// empty value with a negative MaxAge. The other attributes mirror setCookie.
func clearCookie(c *gin.Context, name string, secure bool) {
	expired := http.Cookie{
		Path:     linuxDoOAuthCookiePath,
		SameSite: http.SameSiteLaxMode,
		HttpOnly: true,
		MaxAge:   -1,
		Value:    "",
	}
	expired.Name = name
	expired.Secure = secure
	http.SetCookie(c.Writer, &expired)
}
// truncateFragmentValue trims surrounding whitespace and limits the value to
// linuxDoOAuthMaxFragmentValueLen bytes. After cutting, trailing bytes are
// dropped until the result is valid UTF-8.
func truncateFragmentValue(value string) string {
	trimmed := strings.TrimSpace(value)
	if len(trimmed) <= linuxDoOAuthMaxFragmentValueLen {
		// Covers the empty string as well.
		return trimmed
	}

	clipped := trimmed[:linuxDoOAuthMaxFragmentValueLen]
	for ; !utf8.ValidString(clipped); clipped = clipped[:len(clipped)-1] {
	}
	return clipped
}
// buildBearerAuthorization builds the value of an Authorization header.
//
// A blank token type is treated as "Bearer"; any type other than Bearer
// (case-insensitive) is rejected. The access token is trimmed and must be
// non-empty and free of embedded whitespace.
func buildBearerAuthorization(tokenType, accessToken string) (string, error) {
	scheme := strings.TrimSpace(tokenType)
	if scheme != "" && !strings.EqualFold(scheme, "Bearer") {
		return "", fmt.Errorf("unsupported token_type: %s", scheme)
	}

	credential := strings.TrimSpace(accessToken)
	switch {
	case credential == "":
		return "", errors.New("missing access_token")
	case strings.ContainsAny(credential, " \t\r\n"):
		return "", errors.New("access_token contains whitespace")
	}
	return "Bearer " + credential, nil
}
// isSafeLinuxDoSubject reports whether the provider subject, after trimming,
// is non-empty, at most linuxDoOAuthMaxSubjectLen bytes long, and made up only
// of ASCII letters, digits, '_' and '-'.
func isSafeLinuxDoSubject(subject string) bool {
	id := strings.TrimSpace(subject)
	if n := len(id); n == 0 || n > linuxDoOAuthMaxSubjectLen {
		return false
	}
	for _, r := range id {
		isDigit := '0' <= r && r <= '9'
		isLetter := ('a' <= r && r <= 'z') || ('A' <= r && r <= 'Z')
		if !isDigit && !isLetter && r != '_' && r != '-' {
			return false
		}
	}
	return true
}