First commit

This commit is contained in:
shaw
2025-12-18 13:50:39 +08:00
parent 569f4882e5
commit 642842c29e
218 changed files with 86902 additions and 0 deletions

View File

@@ -0,0 +1,223 @@
package oauth
import (
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"fmt"
"net/url"
"strings"
"sync"
"time"
)
// Claude OAuth Constants (from CRS project)
const (
// OAuth Client ID for Claude
ClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
// OAuth endpoints
AuthorizeURL = "https://claude.ai/oauth/authorize"
TokenURL = "https://console.anthropic.com/v1/oauth/token"
RedirectURI = "https://console.anthropic.com/oauth/code/callback"
// Scopes
ScopeProfile = "user:profile"
ScopeInference = "user:inference"
// Session TTL
SessionTTL = 30 * time.Minute
)
// OAuthSession stores OAuth flow state
type OAuthSession struct {
State string `json:"state"`
CodeVerifier string `json:"code_verifier"`
Scope string `json:"scope"`
ProxyURL string `json:"proxy_url,omitempty"`
CreatedAt time.Time `json:"created_at"`
}
// SessionStore manages OAuth sessions in memory
type SessionStore struct {
mu sync.RWMutex
sessions map[string]*OAuthSession
}
// NewSessionStore creates a new session store
func NewSessionStore() *SessionStore {
store := &SessionStore{
sessions: make(map[string]*OAuthSession),
}
// Start cleanup goroutine
go store.cleanup()
return store
}
// Set stores a session
func (s *SessionStore) Set(sessionID string, session *OAuthSession) {
s.mu.Lock()
defer s.mu.Unlock()
s.sessions[sessionID] = session
}
// Get retrieves a session
func (s *SessionStore) Get(sessionID string) (*OAuthSession, bool) {
s.mu.RLock()
defer s.mu.RUnlock()
session, ok := s.sessions[sessionID]
if !ok {
return nil, false
}
// Check if expired
if time.Since(session.CreatedAt) > SessionTTL {
return nil, false
}
return session, true
}
// Delete removes a session
func (s *SessionStore) Delete(sessionID string) {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.sessions, sessionID)
}
// cleanup removes expired sessions periodically
func (s *SessionStore) cleanup() {
ticker := time.NewTicker(5 * time.Minute)
for range ticker.C {
s.mu.Lock()
for id, session := range s.sessions {
if time.Since(session.CreatedAt) > SessionTTL {
delete(s.sessions, id)
}
}
s.mu.Unlock()
}
}
// GenerateRandomBytes generates cryptographically secure random bytes
func GenerateRandomBytes(n int) ([]byte, error) {
b := make([]byte, n)
_, err := rand.Read(b)
if err != nil {
return nil, err
}
return b, nil
}
// GenerateState generates a random state string for OAuth
func GenerateState() (string, error) {
bytes, err := GenerateRandomBytes(32)
if err != nil {
return "", err
}
return hex.EncodeToString(bytes), nil
}
// GenerateSessionID generates a unique session ID
func GenerateSessionID() (string, error) {
bytes, err := GenerateRandomBytes(16)
if err != nil {
return "", err
}
return hex.EncodeToString(bytes), nil
}
// GenerateCodeVerifier generates a PKCE code verifier (32 bytes -> base64url)
func GenerateCodeVerifier() (string, error) {
bytes, err := GenerateRandomBytes(32)
if err != nil {
return "", err
}
return base64URLEncode(bytes), nil
}
// GenerateCodeChallenge generates a PKCE code challenge using S256 method
func GenerateCodeChallenge(verifier string) string {
hash := sha256.Sum256([]byte(verifier))
return base64URLEncode(hash[:])
}
// base64URLEncode encodes bytes to base64url without padding
func base64URLEncode(data []byte) string {
encoded := base64.URLEncoding.EncodeToString(data)
// Remove padding
return strings.TrimRight(encoded, "=")
}
// BuildAuthorizationURL builds the OAuth authorization URL
func BuildAuthorizationURL(state, codeChallenge, scope string) string {
params := url.Values{}
params.Set("response_type", "code")
params.Set("client_id", ClientID)
params.Set("redirect_uri", RedirectURI)
params.Set("scope", scope)
params.Set("state", state)
params.Set("code_challenge", codeChallenge)
params.Set("code_challenge_method", "S256")
return fmt.Sprintf("%s?%s", AuthorizeURL, params.Encode())
}
// TokenRequest represents the token exchange request body
type TokenRequest struct {
GrantType string `json:"grant_type"`
ClientID string `json:"client_id"`
Code string `json:"code"`
RedirectURI string `json:"redirect_uri"`
CodeVerifier string `json:"code_verifier"`
State string `json:"state"`
}
// TokenResponse represents the token response from OAuth provider
type TokenResponse struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
ExpiresIn int64 `json:"expires_in"`
RefreshToken string `json:"refresh_token,omitempty"`
Scope string `json:"scope,omitempty"`
// Organization and Account info from OAuth response
Organization *OrgInfo `json:"organization,omitempty"`
Account *AccountInfo `json:"account,omitempty"`
}
// OrgInfo represents organization info from OAuth response
type OrgInfo struct {
UUID string `json:"uuid"`
}
// AccountInfo represents account info from OAuth response
type AccountInfo struct {
UUID string `json:"uuid"`
}
// RefreshTokenRequest represents the refresh token request
type RefreshTokenRequest struct {
GrantType string `json:"grant_type"`
RefreshToken string `json:"refresh_token"`
ClientID string `json:"client_id"`
}
// BuildTokenRequest creates a token exchange request
func BuildTokenRequest(code, codeVerifier, state string) *TokenRequest {
return &TokenRequest{
GrantType: "authorization_code",
ClientID: ClientID,
Code: code,
RedirectURI: RedirectURI,
CodeVerifier: codeVerifier,
State: state,
}
}
// BuildRefreshTokenRequest creates a refresh token request
func BuildRefreshTokenRequest(refreshToken string) *RefreshTokenRequest {
return &RefreshTokenRequest{
GrantType: "refresh_token",
RefreshToken: refreshToken,
ClientID: ClientID,
}
}

View File

@@ -0,0 +1,157 @@
package response
import (
"math"
"net/http"
"github.com/gin-gonic/gin"
)
// Response 标准API响应格式
type Response struct {
Code int `json:"code"`
Message string `json:"message"`
Data interface{} `json:"data,omitempty"`
}
// PaginatedData 分页数据格式(匹配前端期望)
type PaginatedData struct {
Items interface{} `json:"items"`
Total int64 `json:"total"`
Page int `json:"page"`
PageSize int `json:"page_size"`
Pages int `json:"pages"`
}
// Success 返回成功响应
func Success(c *gin.Context, data interface{}) {
c.JSON(http.StatusOK, Response{
Code: 0,
Message: "success",
Data: data,
})
}
// Created 返回创建成功响应
func Created(c *gin.Context, data interface{}) {
c.JSON(http.StatusCreated, Response{
Code: 0,
Message: "success",
Data: data,
})
}
// Error 返回错误响应
func Error(c *gin.Context, statusCode int, message string) {
c.JSON(statusCode, Response{
Code: statusCode,
Message: message,
})
}
// BadRequest 返回400错误
func BadRequest(c *gin.Context, message string) {
Error(c, http.StatusBadRequest, message)
}
// Unauthorized 返回401错误
func Unauthorized(c *gin.Context, message string) {
Error(c, http.StatusUnauthorized, message)
}
// Forbidden 返回403错误
func Forbidden(c *gin.Context, message string) {
Error(c, http.StatusForbidden, message)
}
// NotFound 返回404错误
func NotFound(c *gin.Context, message string) {
Error(c, http.StatusNotFound, message)
}
// InternalError 返回500错误
func InternalError(c *gin.Context, message string) {
Error(c, http.StatusInternalServerError, message)
}
// Paginated 返回分页数据
func Paginated(c *gin.Context, items interface{}, total int64, page, pageSize int) {
pages := int(math.Ceil(float64(total) / float64(pageSize)))
if pages < 1 {
pages = 1
}
Success(c, PaginatedData{
Items: items,
Total: total,
Page: page,
PageSize: pageSize,
Pages: pages,
})
}
// PaginationResult 分页结果与repository.PaginationResult兼容
type PaginationResult struct {
Total int64
Page int
PageSize int
Pages int
}
// PaginatedWithResult 使用PaginationResult返回分页数据
func PaginatedWithResult(c *gin.Context, items interface{}, pagination *PaginationResult) {
if pagination == nil {
Success(c, PaginatedData{
Items: items,
Total: 0,
Page: 1,
PageSize: 20,
Pages: 1,
})
return
}
Success(c, PaginatedData{
Items: items,
Total: pagination.Total,
Page: pagination.Page,
PageSize: pagination.PageSize,
Pages: pagination.Pages,
})
}
// ParsePagination 解析分页参数
func ParsePagination(c *gin.Context) (page, pageSize int) {
page = 1
pageSize = 20
if p := c.Query("page"); p != "" {
if val, err := parseInt(p); err == nil && val > 0 {
page = val
}
}
// 支持 page_size 和 limit 两种参数名
if ps := c.Query("page_size"); ps != "" {
if val, err := parseInt(ps); err == nil && val > 0 && val <= 100 {
pageSize = val
}
} else if l := c.Query("limit"); l != "" {
if val, err := parseInt(l); err == nil && val > 0 && val <= 100 {
pageSize = val
}
}
return page, pageSize
}
func parseInt(s string) (int, error) {
var result int
for _, c := range s {
if c < '0' || c > '9' {
return 0, nil
}
result = result*10 + int(c-'0')
}
return result, nil
}

View File

@@ -0,0 +1,124 @@
// Package timezone provides global timezone management for the application.
// Similar to PHP's date_default_timezone_set, this package allows setting
// a global timezone that affects all time.Now() calls.
package timezone
import (
"fmt"
"log"
"time"
)
var (
// location is the global timezone location
location *time.Location
// tzName stores the timezone name for logging/debugging
tzName string
)
// Init initializes the global timezone setting.
// This should be called once at application startup.
// Example timezone values: "Asia/Shanghai", "America/New_York", "UTC"
func Init(tz string) error {
if tz == "" {
tz = "Asia/Shanghai" // Default timezone
}
loc, err := time.LoadLocation(tz)
if err != nil {
return fmt.Errorf("invalid timezone %q: %w", tz, err)
}
// Set the global Go time.Local to our timezone
// This affects time.Now() throughout the application
time.Local = loc
location = loc
tzName = tz
log.Printf("Timezone initialized: %s (UTC offset: %s)", tz, getUTCOffset(loc))
return nil
}
// getUTCOffset returns the current UTC offset for a location
func getUTCOffset(loc *time.Location) string {
_, offset := time.Now().In(loc).Zone()
hours := offset / 3600
minutes := (offset % 3600) / 60
if minutes < 0 {
minutes = -minutes
}
sign := "+"
if hours < 0 {
sign = "-"
hours = -hours
}
return fmt.Sprintf("%s%02d:%02d", sign, hours, minutes)
}
// Now returns the current time in the configured timezone.
// This is equivalent to time.Now() after Init() is called,
// but provided for explicit timezone-aware code.
func Now() time.Time {
if location == nil {
return time.Now()
}
return time.Now().In(location)
}
// Location returns the configured timezone location.
func Location() *time.Location {
if location == nil {
return time.Local
}
return location
}
// Name returns the configured timezone name.
func Name() string {
if tzName == "" {
return "Local"
}
return tzName
}
// StartOfDay returns the start of the given day (00:00:00) in the configured timezone.
func StartOfDay(t time.Time) time.Time {
loc := Location()
t = t.In(loc)
return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, loc)
}
// Today returns the start of today (00:00:00) in the configured timezone.
func Today() time.Time {
return StartOfDay(Now())
}
// EndOfDay returns the end of the given day (23:59:59.999999999) in the configured timezone.
func EndOfDay(t time.Time) time.Time {
loc := Location()
t = t.In(loc)
return time.Date(t.Year(), t.Month(), t.Day(), 23, 59, 59, 999999999, loc)
}
// StartOfWeek returns the start of the week (Monday 00:00:00) for the given time.
func StartOfWeek(t time.Time) time.Time {
loc := Location()
t = t.In(loc)
weekday := int(t.Weekday())
if weekday == 0 {
weekday = 7 // Sunday is day 7
}
return time.Date(t.Year(), t.Month(), t.Day()-weekday+1, 0, 0, 0, 0, loc)
}
// StartOfMonth returns the start of the month (1st day 00:00:00) for the given time.
func StartOfMonth(t time.Time) time.Time {
loc := Location()
t = t.In(loc)
return time.Date(t.Year(), t.Month(), 1, 0, 0, 0, 0, loc)
}
// ParseInLocation parses a time string in the configured timezone.
func ParseInLocation(layout, value string) (time.Time, error) {
return time.ParseInLocation(layout, value, Location())
}

View File

@@ -0,0 +1,127 @@
package timezone
import (
"testing"
"time"
)
func TestInit(t *testing.T) {
// Test with valid timezone
err := Init("Asia/Shanghai")
if err != nil {
t.Fatalf("Init failed with valid timezone: %v", err)
}
// Verify time.Local was set
if time.Local.String() != "Asia/Shanghai" {
t.Errorf("time.Local not set correctly, got %s", time.Local.String())
}
// Verify our location variable
if Location().String() != "Asia/Shanghai" {
t.Errorf("Location() not set correctly, got %s", Location().String())
}
// Test Name()
if Name() != "Asia/Shanghai" {
t.Errorf("Name() not set correctly, got %s", Name())
}
}
func TestInitInvalidTimezone(t *testing.T) {
err := Init("Invalid/Timezone")
if err == nil {
t.Error("Init should fail with invalid timezone")
}
}
func TestTimeNowAffected(t *testing.T) {
// Reset to UTC first
Init("UTC")
utcNow := time.Now()
// Switch to Shanghai (UTC+8)
Init("Asia/Shanghai")
shanghaiNow := time.Now()
// The times should be the same instant, but different timezone representation
// Shanghai should be 8 hours ahead in display
_, utcOffset := utcNow.Zone()
_, shanghaiOffset := shanghaiNow.Zone()
expectedDiff := 8 * 3600 // 8 hours in seconds
actualDiff := shanghaiOffset - utcOffset
if actualDiff != expectedDiff {
t.Errorf("Timezone offset difference incorrect: expected %d, got %d", expectedDiff, actualDiff)
}
}
func TestToday(t *testing.T) {
Init("Asia/Shanghai")
today := Today()
now := Now()
// Today should be at 00:00:00
if today.Hour() != 0 || today.Minute() != 0 || today.Second() != 0 {
t.Errorf("Today() not at start of day: %v", today)
}
// Today should be same date as now
if today.Year() != now.Year() || today.Month() != now.Month() || today.Day() != now.Day() {
t.Errorf("Today() date mismatch: today=%v, now=%v", today, now)
}
}
func TestStartOfDay(t *testing.T) {
Init("Asia/Shanghai")
// Create a time at 15:30:45
testTime := time.Date(2024, 6, 15, 15, 30, 45, 123456789, Location())
startOfDay := StartOfDay(testTime)
expected := time.Date(2024, 6, 15, 0, 0, 0, 0, Location())
if !startOfDay.Equal(expected) {
t.Errorf("StartOfDay incorrect: expected %v, got %v", expected, startOfDay)
}
}
func TestTruncateVsStartOfDay(t *testing.T) {
// This test demonstrates why Truncate(24*time.Hour) can be problematic
// and why StartOfDay is more reliable for timezone-aware code
Init("Asia/Shanghai")
now := Now()
// Truncate operates on UTC, not local time
truncated := now.Truncate(24 * time.Hour)
// StartOfDay operates on local time
startOfDay := StartOfDay(now)
// These will likely be different for non-UTC timezones
t.Logf("Now: %v", now)
t.Logf("Truncate(24h): %v", truncated)
t.Logf("StartOfDay: %v", startOfDay)
// The truncated time may not be at local midnight
// StartOfDay is always at local midnight
if startOfDay.Hour() != 0 {
t.Errorf("StartOfDay should be at hour 0, got %d", startOfDay.Hour())
}
}
func TestDSTAwareness(t *testing.T) {
// Test with a timezone that has DST (America/New_York)
err := Init("America/New_York")
if err != nil {
t.Skipf("America/New_York timezone not available: %v", err)
}
// Just verify it doesn't crash
_ = Today()
_ = Now()
_ = StartOfDay(Now())
}