This commit is contained in:
yangjianbo
2025-12-30 14:15:52 +08:00
7 changed files with 167 additions and 26 deletions

View File

@@ -216,7 +216,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
}
const maxAccountSwitches = 3
const maxAccountSwitches = 10
switchCount := 0
failedAccountIDs := make(map[int64]struct{})
lastFailoverStatus := 0

View File

@@ -61,6 +61,13 @@ func jwtAuth(authService *service.AuthService, userService *service.UserService)
return
}
// Security: Validate TokenVersion to ensure token hasn't been invalidated
// This check ensures tokens issued before a password change are rejected
if claims.TokenVersion != user.TokenVersion {
AbortWithError(c, 401, "TOKEN_REVOKED", "Token has been revoked (password changed)")
return
}
c.Set(string(ContextKeyUser), AuthSubject{
UserID: user.ID,
Concurrency: user.Concurrency,

View File

@@ -20,6 +20,7 @@ var (
ErrEmailExists = infraerrors.Conflict("EMAIL_EXISTS", "email already exists")
ErrInvalidToken = infraerrors.Unauthorized("INVALID_TOKEN", "invalid token")
ErrTokenExpired = infraerrors.Unauthorized("TOKEN_EXPIRED", "token has expired")
ErrTokenRevoked = infraerrors.Unauthorized("TOKEN_REVOKED", "token has been revoked")
ErrEmailVerifyRequired = infraerrors.BadRequest("EMAIL_VERIFY_REQUIRED", "email verification is required")
ErrRegDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled")
ErrServiceUnavailable = infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "service temporarily unavailable")
@@ -27,9 +28,10 @@ var (
// JWTClaims JWT载荷数据
type JWTClaims struct {
UserID int64 `json:"user_id"`
Email string `json:"email"`
Role string `json:"role"`
UserID int64 `json:"user_id"`
Email string `json:"email"`
Role string `json:"role"`
TokenVersion int64 `json:"token_version"` // Used to invalidate tokens on password change
jwt.RegisteredClaims
}
@@ -311,9 +313,10 @@ func (s *AuthService) GenerateToken(user *User) (string, error) {
expiresAt := now.Add(time.Duration(s.cfg.JWT.ExpireHour) * time.Hour)
claims := &JWTClaims{
UserID: user.ID,
Email: user.Email,
Role: user.Role,
UserID: user.ID,
Email: user.Email,
Role: user.Role,
TokenVersion: user.TokenVersion,
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(expiresAt),
IssuedAt: jwt.NewNumericDate(now),
@@ -368,6 +371,12 @@ func (s *AuthService) RefreshToken(ctx context.Context, oldTokenString string) (
return "", ErrUserNotActive
}
// Security: Check TokenVersion to prevent refreshing revoked tokens
// This ensures tokens issued before a password change cannot be refreshed
if claims.TokenVersion != user.TokenVersion {
return "", ErrTokenRevoked
}
// 生成新token
return s.GenerateToken(user)
}

View File

@@ -695,6 +695,11 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
if req.Stream {
streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, req.Model)
if err != nil {
if err.Error() == "have error in stream" {
return nil, &UpstreamFailoverError{
StatusCode: 403,
}
}
return nil, err
}
usage = streamResult.usage
@@ -969,6 +974,9 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
for scanner.Scan() {
line := scanner.Text()
if line == "event: error" {
return nil, errors.New("have error in stream")
}
// Extract data from SSE line (supports both "data: " and "data:" formats)
if sseDataRe.MatchString(line) {

View File

@@ -18,6 +18,7 @@ type User struct {
Concurrency int
Status string
AllowedGroups []int64
TokenVersion int64 // Incremented on password change to invalidate existing tokens
CreatedAt time.Time
UpdatedAt time.Time

View File

@@ -116,6 +116,7 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat
}
// ChangePassword 修改密码
// Security: Increments TokenVersion to invalidate all existing JWT tokens
func (s *UserService) ChangePassword(ctx context.Context, userID int64, req ChangePasswordRequest) error {
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
@@ -131,6 +132,10 @@ func (s *UserService) ChangePassword(ctx context.Context, userID int64, req Chan
return fmt.Errorf("set password: %w", err)
}
// Increment TokenVersion to invalidate all existing tokens
// This ensures that any tokens issued before the password change become invalid
user.TokenVersion++
if err := s.userRepo.Update(ctx, user); err != nil {
return fmt.Errorf("update user: %w", err)
}