fix: Messages() routing refactor and subscription group test coverage

- Refactor OpenAI Messages() routing: pre-compute dispatch model using
  resolveOpenAIMessagesDispatchMappedModel + NormalizeOpenAICompatRequestedModel
  instead of try-fail-retry pattern with gin context passing
- Remove openai_messages_fallback_model context anti-pattern
- Use effectiveMappedModel directly for forward default mapped model
- Add 3 subscription group tests covering all branch paths:
  _Blocked (no active subscription → SUBSCRIPTION_REQUIRED),
  _RequiresRepo (nil repo → SUBSCRIPTION_REPOSITORY_UNAVAILABLE),
  _AllowsActiveSubscription (valid subscription → success)
This commit is contained in:
erio
2026-04-14 20:34:53 +08:00
parent 3d2027227b
commit 8548a130c7
2 changed files with 69 additions and 30 deletions

View File

@@ -557,6 +557,8 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
return
}
reqModel := modelResult.String()
routingModel := service.NormalizeOpenAICompatRequestedModel(reqModel)
preferredMappedModel := resolveOpenAIMessagesDispatchMappedModel(apiKey, reqModel)
reqStream := gjson.GetBytes(body, "stream").Bool()
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
@@ -615,17 +617,20 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
failedAccountIDs := make(map[int64]struct{})
sameAccountRetryCount := make(map[int64]int)
var lastFailoverErr *service.UpstreamFailoverError
effectiveMappedModel := preferredMappedModel
for {
// 清除上一次迭代的降级模型标记,避免残留影响本次迭代
c.Set("openai_messages_fallback_model", "")
currentRoutingModel := routingModel
if effectiveMappedModel != "" {
currentRoutingModel = effectiveMappedModel
}
reqLog.Debug("openai_messages.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs)))
selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler(
c.Request.Context(),
apiKey.GroupID,
"", // no previous_response_id
sessionHash,
reqModel,
currentRoutingModel,
failedAccountIDs,
service.OpenAIUpstreamTransportAny,
)
@@ -634,29 +639,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
zap.Error(err),
zap.Int("excluded_account_count", len(failedAccountIDs)),
)
// 首次调度失败 + 有默认映射模型 → 用默认模型重试
if len(failedAccountIDs) == 0 {
defaultModel := ""
if apiKey.Group != nil {
defaultModel = apiKey.Group.DefaultMappedModel
}
if defaultModel != "" && defaultModel != reqModel {
reqLog.Info("openai_messages.fallback_to_default_model",
zap.String("default_mapped_model", defaultModel),
)
selection, scheduleDecision, err = h.gatewayService.SelectAccountWithScheduler(
c.Request.Context(),
apiKey.GroupID,
"",
sessionHash,
defaultModel,
failedAccountIDs,
service.OpenAIUpstreamTransportAny,
)
if err == nil && selection != nil {
c.Set("openai_messages_fallback_model", defaultModel)
}
}
if err != nil {
h.anthropicStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
return
@@ -688,9 +671,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
forwardStart := time.Now()
// Forward 层需要始终拿到 group 默认映射模型,这样未命中账号级映射的
// Claude 兼容模型才不会在后续 Codex 规范化中意外退化到 gpt-5.1。
defaultMappedModel := resolveOpenAIForwardDefaultMappedModel(apiKey, c.GetString("openai_messages_fallback_model"))
defaultMappedModel := strings.TrimSpace(effectiveMappedModel)
// 应用渠道模型映射到请求体
forwardBody := body
if channelMappingMsg.Mapped {

View File

@@ -216,6 +216,29 @@ func (s *groupRepoStubForGroupUpdate) UpdateSortOrders(context.Context, []GroupS
panic("unexpected")
}
type userSubRepoStubForGroupUpdate struct {
userSubRepoNoop
getActiveSub *UserSubscription
getActiveErr error
called bool
calledUserID int64
calledGroupID int64
}
func (s *userSubRepoStubForGroupUpdate) GetActiveByUserIDAndGroupID(_ context.Context, userID, groupID int64) (*UserSubscription, error) {
s.called = true
s.calledUserID = userID
s.calledGroupID = groupID
if s.getActiveErr != nil {
return nil, s.getActiveErr
}
if s.getActiveSub == nil {
return nil, ErrSubscriptionNotFound
}
clone := *s.getActiveSub
return &clone, nil
}
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
@@ -408,17 +431,52 @@ func TestAdminService_AdminUpdateAPIKeyGroupID_NonExclusiveGroup_NoAllowedGroupU
func TestAdminService_AdminUpdateAPIKeyGroupID_SubscriptionGroup_Blocked(t *testing.T) {
existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: nil}
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Sub", Status: StatusActive, IsExclusive: true, SubscriptionType: SubscriptionTypeSubscription}}
groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Sub", Status: StatusActive, IsExclusive: false, SubscriptionType: SubscriptionTypeSubscription}}
userRepo := &userRepoStubForGroupUpdate{}
userSubRepo := &userSubRepoStubForGroupUpdate{getActiveErr: ErrSubscriptionNotFound}
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, userRepo: userRepo, userSubRepo: userSubRepo}
// 无有效订阅时应拒绝绑定
_, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10))
require.Error(t, err)
require.Equal(t, "SUBSCRIPTION_REQUIRED", infraerrors.Reason(err))
require.True(t, userSubRepo.called)
require.Equal(t, int64(42), userSubRepo.calledUserID)
require.Equal(t, int64(10), userSubRepo.calledGroupID)
require.False(t, userRepo.addGroupCalled)
}
func TestAdminService_AdminUpdateAPIKeyGroupID_SubscriptionGroup_RequiresRepo(t *testing.T) {
existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: nil}
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Sub", Status: StatusActive, IsExclusive: false, SubscriptionType: SubscriptionTypeSubscription}}
userRepo := &userRepoStubForGroupUpdate{}
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, userRepo: userRepo}
// userSubRepo is nil → SUBSCRIPTION_REPOSITORY_UNAVAILABLE
_, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10))
require.Error(t, err)
require.Equal(t, "SUBSCRIPTION_REPOSITORY_UNAVAILABLE", infraerrors.Reason(err))
require.False(t, userRepo.addGroupCalled)
}
func TestAdminService_AdminUpdateAPIKeyGroupID_SubscriptionGroup_AllowsActiveSubscription(t *testing.T) {
existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: nil}
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Sub", Status: StatusActive, IsExclusive: true, SubscriptionType: SubscriptionTypeSubscription}}
userRepo := &userRepoStubForGroupUpdate{}
userSubRepo := &userSubRepoStubForGroupUpdate{
getActiveSub: &UserSubscription{ID: 99, UserID: 42, GroupID: 10},
}
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, userRepo: userRepo, userSubRepo: userSubRepo}
got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10))
require.NoError(t, err)
require.True(t, userSubRepo.called)
require.NotNil(t, got.APIKey.GroupID)
require.Equal(t, int64(10), *got.APIKey.GroupID)
require.False(t, userRepo.addGroupCalled)
}
func TestAdminService_AdminUpdateAPIKeyGroupID_ExclusiveGroup_AllowedGroupAddFails_ReturnsError(t *testing.T) {
existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: nil}
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}