package handler import ( "context" "encoding/base64" "errors" "fmt" "log" "net/http" "net/url" "strings" "time" "unicode/utf8" "github.com/Wei-Shaw/sub2api/internal/config" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/oauth" "github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/gin-gonic/gin" "github.com/imroc/req/v3" "github.com/tidwall/gjson" ) const ( linuxDoOAuthCookiePath = "/api/v1/auth/oauth/linuxdo" linuxDoOAuthStateCookieName = "linuxdo_oauth_state" linuxDoOAuthVerifierCookie = "linuxdo_oauth_verifier" linuxDoOAuthRedirectCookie = "linuxdo_oauth_redirect" linuxDoOAuthCookieMaxAgeSec = 10 * 60 // 10 minutes linuxDoOAuthDefaultRedirectTo = "/dashboard" linuxDoOAuthDefaultFrontendCB = "/auth/linuxdo/callback" linuxDoOAuthMaxRedirectLen = 2048 linuxDoOAuthMaxFragmentValueLen = 512 linuxDoOAuthMaxSubjectLen = 64 - len("linuxdo-") ) type linuxDoTokenResponse struct { AccessToken string `json:"access_token"` TokenType string `json:"token_type"` ExpiresIn int64 `json:"expires_in"` RefreshToken string `json:"refresh_token,omitempty"` Scope string `json:"scope,omitempty"` } // LinuxDoOAuthStart starts the LinuxDo Connect OAuth login flow. // GET /api/v1/auth/oauth/linuxdo/start?redirect=/dashboard func (h *AuthHandler) LinuxDoOAuthStart(c *gin.Context) { cfg, err := linuxDoOAuthConfig(h.cfg) if err != nil { response.ErrorFrom(c, err) return } state, err := oauth.GenerateState() if err != nil { response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_STATE_GEN_FAILED", "failed to generate oauth state").WithCause(err)) return } redirectTo := sanitizeFrontendRedirectPath(c.Query("redirect")) if redirectTo == "" { redirectTo = linuxDoOAuthDefaultRedirectTo } secureCookie := isRequestHTTPS(c) setCookie(c, linuxDoOAuthStateCookieName, encodeCookieValue(state), linuxDoOAuthCookieMaxAgeSec, secureCookie) setCookie(c, linuxDoOAuthRedirectCookie, encodeCookieValue(redirectTo), linuxDoOAuthCookieMaxAgeSec, secureCookie) codeChallenge := "" if cfg.UsePKCE { verifier, err := oauth.GenerateCodeVerifier() if err != nil { response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_PKCE_GEN_FAILED", "failed to generate pkce verifier").WithCause(err)) return } codeChallenge = oauth.GenerateCodeChallenge(verifier) setCookie(c, linuxDoOAuthVerifierCookie, encodeCookieValue(verifier), linuxDoOAuthCookieMaxAgeSec, secureCookie) } redirectURI := strings.TrimSpace(cfg.RedirectURL) if redirectURI == "" { response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth redirect url not configured")) return } authURL, err := buildLinuxDoAuthorizeURL(cfg, state, codeChallenge, redirectURI) if err != nil { response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_BUILD_URL_FAILED", "failed to build oauth authorization url").WithCause(err)) return } c.Redirect(http.StatusFound, authURL) } // LinuxDoOAuthCallback handles the OAuth callback, creates/logins the user, then redirects to frontend. // GET /api/v1/auth/oauth/linuxdo/callback?code=...&state=... func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) { cfg, cfgErr := linuxDoOAuthConfig(h.cfg) if cfgErr != nil { response.ErrorFrom(c, cfgErr) return } frontendCallback := strings.TrimSpace(cfg.FrontendRedirectURL) if frontendCallback == "" { frontendCallback = linuxDoOAuthDefaultFrontendCB } if providerErr := strings.TrimSpace(c.Query("error")); providerErr != "" { redirectOAuthError(c, frontendCallback, "provider_error", providerErr, c.Query("error_description")) return } code := strings.TrimSpace(c.Query("code")) state := strings.TrimSpace(c.Query("state")) if code == "" || state == "" { redirectOAuthError(c, frontendCallback, "missing_params", "missing code/state", "") return } secureCookie := isRequestHTTPS(c) defer func() { clearCookie(c, linuxDoOAuthStateCookieName, secureCookie) clearCookie(c, linuxDoOAuthVerifierCookie, secureCookie) clearCookie(c, linuxDoOAuthRedirectCookie, secureCookie) }() expectedState, err := readCookieDecoded(c, linuxDoOAuthStateCookieName) if err != nil || expectedState == "" || state != expectedState { redirectOAuthError(c, frontendCallback, "invalid_state", "invalid oauth state", "") return } redirectTo, _ := readCookieDecoded(c, linuxDoOAuthRedirectCookie) redirectTo = sanitizeFrontendRedirectPath(redirectTo) if redirectTo == "" { redirectTo = linuxDoOAuthDefaultRedirectTo } codeVerifier := "" if cfg.UsePKCE { codeVerifier, _ = readCookieDecoded(c, linuxDoOAuthVerifierCookie) if codeVerifier == "" { redirectOAuthError(c, frontendCallback, "missing_verifier", "missing pkce verifier", "") return } } redirectURI := strings.TrimSpace(cfg.RedirectURL) if redirectURI == "" { redirectOAuthError(c, frontendCallback, "config_error", "oauth redirect url not configured", "") return } tokenResp, err := linuxDoExchangeCode(c.Request.Context(), cfg, code, redirectURI, codeVerifier) if err != nil { log.Printf("[LinuxDo OAuth] token exchange failed: %v", err) redirectOAuthError(c, frontendCallback, "token_exchange_failed", "failed to exchange oauth code", "") return } email, username, _, err := linuxDoFetchUserInfo(c.Request.Context(), cfg, tokenResp) if err != nil { log.Printf("[LinuxDo OAuth] userinfo fetch failed: %v", err) redirectOAuthError(c, frontendCallback, "userinfo_failed", "failed to fetch user info", "") return } jwtToken, _, err := h.authService.LoginOrRegisterOAuth(c.Request.Context(), email, username) if err != nil { // Avoid leaking internal details to the client; keep structured reason for frontend. redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err)) return } fragment := url.Values{} fragment.Set("access_token", jwtToken) fragment.Set("token_type", "Bearer") fragment.Set("redirect", redirectTo) redirectWithFragment(c, frontendCallback, fragment) } func linuxDoOAuthConfig(cfg *config.Config) (config.LinuxDoConnectConfig, error) { if cfg == nil { return config.LinuxDoConnectConfig{}, infraerrors.ServiceUnavailable("CONFIG_NOT_READY", "config not loaded") } if !cfg.LinuxDo.Enabled { return config.LinuxDoConnectConfig{}, infraerrors.NotFound("OAUTH_DISABLED", "oauth login is disabled") } return cfg.LinuxDo, nil } func linuxDoExchangeCode( ctx context.Context, cfg config.LinuxDoConnectConfig, code string, redirectURI string, codeVerifier string, ) (*linuxDoTokenResponse, error) { client := req.C().SetTimeout(30 * time.Second) form := url.Values{} form.Set("grant_type", "authorization_code") form.Set("client_id", cfg.ClientID) form.Set("code", code) form.Set("redirect_uri", redirectURI) if cfg.UsePKCE { form.Set("code_verifier", codeVerifier) } r := client.R(). SetContext(ctx). SetHeader("Accept", "application/json") switch strings.ToLower(strings.TrimSpace(cfg.TokenAuthMethod)) { case "", "client_secret_post": form.Set("client_secret", cfg.ClientSecret) case "client_secret_basic": r.SetBasicAuth(cfg.ClientID, cfg.ClientSecret) case "none": default: return nil, fmt.Errorf("unsupported token_auth_method: %s", cfg.TokenAuthMethod) } var tokenResp linuxDoTokenResponse resp, err := r.SetFormDataFromValues(form).SetSuccessResult(&tokenResp).Post(cfg.TokenURL) if err != nil { return nil, fmt.Errorf("request token: %w", err) } if !resp.IsSuccessState() { return nil, fmt.Errorf("token exchange status=%d", resp.StatusCode) } if strings.TrimSpace(tokenResp.AccessToken) == "" { return nil, errors.New("token response missing access_token") } if strings.TrimSpace(tokenResp.TokenType) == "" { tokenResp.TokenType = "Bearer" } return &tokenResp, nil } func linuxDoFetchUserInfo( ctx context.Context, cfg config.LinuxDoConnectConfig, token *linuxDoTokenResponse, ) (email string, username string, subject string, err error) { client := req.C().SetTimeout(30 * time.Second) authorization, err := buildBearerAuthorization(token.TokenType, token.AccessToken) if err != nil { return "", "", "", fmt.Errorf("invalid token for userinfo request: %w", err) } resp, err := client.R(). SetContext(ctx). SetHeader("Accept", "application/json"). SetHeader("Authorization", authorization). Get(cfg.UserInfoURL) if err != nil { return "", "", "", fmt.Errorf("request userinfo: %w", err) } if !resp.IsSuccessState() { return "", "", "", fmt.Errorf("userinfo status=%d", resp.StatusCode) } return linuxDoParseUserInfo(resp.String(), cfg) } func linuxDoParseUserInfo(body string, cfg config.LinuxDoConnectConfig) (email string, username string, subject string, err error) { email = firstNonEmpty( getGJSON(body, cfg.UserInfoEmailPath), getGJSON(body, "email"), getGJSON(body, "user.email"), getGJSON(body, "data.email"), getGJSON(body, "attributes.email"), ) username = firstNonEmpty( getGJSON(body, cfg.UserInfoUsernamePath), getGJSON(body, "username"), getGJSON(body, "preferred_username"), getGJSON(body, "name"), getGJSON(body, "user.username"), getGJSON(body, "user.name"), ) subject = firstNonEmpty( getGJSON(body, cfg.UserInfoIDPath), getGJSON(body, "sub"), getGJSON(body, "id"), getGJSON(body, "user_id"), getGJSON(body, "uid"), getGJSON(body, "user.id"), ) subject = strings.TrimSpace(subject) if subject == "" { return "", "", "", errors.New("userinfo missing id field") } if !isSafeLinuxDoSubject(subject) { return "", "", "", errors.New("userinfo returned invalid id field") } email = strings.TrimSpace(email) if email == "" { // LinuxDo Connect userinfo does not necessarily provide email. To keep compatibility with the // existing user schema (email is required/unique), use a stable synthetic email. email = fmt.Sprintf("linuxdo-%s@linuxdo-connect.invalid", subject) } username = strings.TrimSpace(username) if username == "" { username = "linuxdo_" + subject } return email, username, subject, nil } func buildLinuxDoAuthorizeURL(cfg config.LinuxDoConnectConfig, state string, codeChallenge string, redirectURI string) (string, error) { u, err := url.Parse(cfg.AuthorizeURL) if err != nil { return "", fmt.Errorf("parse authorize_url: %w", err) } q := u.Query() q.Set("response_type", "code") q.Set("client_id", cfg.ClientID) q.Set("redirect_uri", redirectURI) if strings.TrimSpace(cfg.Scopes) != "" { q.Set("scope", cfg.Scopes) } q.Set("state", state) if cfg.UsePKCE { q.Set("code_challenge", codeChallenge) q.Set("code_challenge_method", "S256") } u.RawQuery = q.Encode() return u.String(), nil } func redirectOAuthError(c *gin.Context, frontendCallback string, code string, message string, description string) { fragment := url.Values{} fragment.Set("error", truncateFragmentValue(code)) if strings.TrimSpace(message) != "" { fragment.Set("error_message", truncateFragmentValue(message)) } if strings.TrimSpace(description) != "" { fragment.Set("error_description", truncateFragmentValue(description)) } redirectWithFragment(c, frontendCallback, fragment) } func redirectWithFragment(c *gin.Context, frontendCallback string, fragment url.Values) { u, err := url.Parse(frontendCallback) if err != nil { // Fallback: best-effort redirect. c.Redirect(http.StatusFound, linuxDoOAuthDefaultRedirectTo) return } if u.Scheme != "" && !strings.EqualFold(u.Scheme, "http") && !strings.EqualFold(u.Scheme, "https") { c.Redirect(http.StatusFound, linuxDoOAuthDefaultRedirectTo) return } u.Fragment = fragment.Encode() c.Header("Cache-Control", "no-store") c.Header("Pragma", "no-cache") c.Redirect(http.StatusFound, u.String()) } func firstNonEmpty(values ...string) string { for _, v := range values { v = strings.TrimSpace(v) if v != "" { return v } } return "" } func getGJSON(body string, path string) string { path = strings.TrimSpace(path) if path == "" { return "" } res := gjson.Get(body, path) if !res.Exists() { return "" } return res.String() } func sanitizeFrontendRedirectPath(path string) string { path = strings.TrimSpace(path) if path == "" { return "" } if len(path) > linuxDoOAuthMaxRedirectLen { return "" } // Only allow same-origin relative paths (avoid open redirect). if !strings.HasPrefix(path, "/") { return "" } if strings.HasPrefix(path, "//") { return "" } if strings.Contains(path, "://") { return "" } if strings.ContainsAny(path, "\r\n") { return "" } return path } func isRequestHTTPS(c *gin.Context) bool { if c.Request.TLS != nil { return true } proto := strings.ToLower(strings.TrimSpace(c.GetHeader("X-Forwarded-Proto"))) return proto == "https" } func encodeCookieValue(value string) string { return base64.RawURLEncoding.EncodeToString([]byte(value)) } func decodeCookieValue(value string) (string, error) { raw, err := base64.RawURLEncoding.DecodeString(value) if err != nil { return "", err } return string(raw), nil } func readCookieDecoded(c *gin.Context, name string) (string, error) { ck, err := c.Request.Cookie(name) if err != nil { return "", err } return decodeCookieValue(ck.Value) } func setCookie(c *gin.Context, name string, value string, maxAgeSec int, secure bool) { http.SetCookie(c.Writer, &http.Cookie{ Name: name, Value: value, Path: linuxDoOAuthCookiePath, MaxAge: maxAgeSec, HttpOnly: true, Secure: secure, SameSite: http.SameSiteLaxMode, }) } func clearCookie(c *gin.Context, name string, secure bool) { http.SetCookie(c.Writer, &http.Cookie{ Name: name, Value: "", Path: linuxDoOAuthCookiePath, MaxAge: -1, HttpOnly: true, Secure: secure, SameSite: http.SameSiteLaxMode, }) } func truncateFragmentValue(value string) string { value = strings.TrimSpace(value) if value == "" { return "" } if len(value) > linuxDoOAuthMaxFragmentValueLen { value = value[:linuxDoOAuthMaxFragmentValueLen] for !utf8.ValidString(value) { value = value[:len(value)-1] } } return value } func buildBearerAuthorization(tokenType, accessToken string) (string, error) { tokenType = strings.TrimSpace(tokenType) if tokenType == "" { tokenType = "Bearer" } if !strings.EqualFold(tokenType, "Bearer") { return "", fmt.Errorf("unsupported token_type: %s", tokenType) } accessToken = strings.TrimSpace(accessToken) if accessToken == "" { return "", errors.New("missing access_token") } if strings.ContainsAny(accessToken, " \t\r\n") { return "", errors.New("access_token contains whitespace") } return "Bearer " + accessToken, nil } func isSafeLinuxDoSubject(subject string) bool { subject = strings.TrimSpace(subject) if subject == "" || len(subject) > linuxDoOAuthMaxSubjectLen { return false } for _, r := range subject { switch { case r >= '0' && r <= '9': case r >= 'a' && r <= 'z': case r >= 'A' && r <= 'Z': case r == '_' || r == '-': default: return false } } return true }