From e0a79e853d012fbad3ac3be2df92b3e0d64d5998 Mon Sep 17 00:00:00 2001 From: CaIon Date: Sat, 13 Dec 2025 01:38:12 +0800 Subject: [PATCH] refactor(auth): replace direct token group setting with context key retrieval --- middleware/auth.go | 2 +- relay/common/relay_info.go | 10 ++++++++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/middleware/auth.go b/middleware/auth.go index b1fca471..cefc4e06 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -307,7 +307,7 @@ func SetupContextForToken(c *gin.Context, token *model.Token, parts ...string) e } else { c.Set("token_model_limit_enabled", false) } - c.Set("token_group", token.Group) + common.SetContextKey(c, constant.ContextKeyTokenGroup, token.Group) c.Set("token_cross_group_retry", token.CrossGroupRetry) if len(parts) > 1 { if model.IsAdmin(token.UserId) { diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index 8bc47bb5..40f79463 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -83,7 +83,7 @@ type RelayInfo struct { TokenKey string TokenGroup string UserId int - UsingGroup string // 使用的分组 + UsingGroup string // 使用的分组,当auto跨分组重试时,会变动 UserGroup string // 用户所在分组 TokenUnlimited bool StartTime time.Time @@ -374,6 +374,12 @@ func genBaseRelayInfo(c *gin.Context, request dto.Request) *RelayInfo { //channelId := common.GetContextKeyInt(c, constant.ContextKeyChannelId) //paramOverride := common.GetContextKeyStringMap(c, constant.ContextKeyChannelParamOverride) + tokenGroup := common.GetContextKeyString(c, constant.ContextKeyTokenGroup) + // 当令牌分组为空时,表示使用用户分组 + if tokenGroup == "" { + tokenGroup = common.GetContextKeyString(c, constant.ContextKeyUserGroup) + } + startTime := common.GetContextKeyTime(c, constant.ContextKeyRequestStartTime) if startTime.IsZero() { startTime = time.Now() @@ -401,7 +407,7 @@ func genBaseRelayInfo(c *gin.Context, request dto.Request) *RelayInfo { TokenId: common.GetContextKeyInt(c, constant.ContextKeyTokenId), TokenKey: common.GetContextKeyString(c, constant.ContextKeyTokenKey), TokenUnlimited: common.GetContextKeyBool(c, constant.ContextKeyTokenUnlimited), - TokenGroup: common.GetContextKeyString(c, constant.ContextKeyTokenGroup), + TokenGroup: tokenGroup, isFirstResponse: true, RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path),