212 lines
7.5 KiB
Go
212 lines
7.5 KiB
Go
package service
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
|
|
"github.com/Wei-Shaw/sub2api/internal/model"
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
func TestApplyErrorPassthroughRule_NoBoundService(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
rec := httptest.NewRecorder()
|
|
c, _ := gin.CreateTestContext(rec)
|
|
|
|
status, errType, errMsg, matched := applyErrorPassthroughRule(
|
|
c,
|
|
PlatformAnthropic,
|
|
http.StatusUnprocessableEntity,
|
|
[]byte(`{"error":{"message":"invalid schema"}}`),
|
|
http.StatusBadGateway,
|
|
"upstream_error",
|
|
"Upstream request failed",
|
|
)
|
|
|
|
assert.False(t, matched)
|
|
assert.Equal(t, http.StatusBadGateway, status)
|
|
assert.Equal(t, "upstream_error", errType)
|
|
assert.Equal(t, "Upstream request failed", errMsg)
|
|
}
|
|
|
|
func TestGatewayHandleErrorResponse_NoRuleKeepsDefault(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
rec := httptest.NewRecorder()
|
|
c, _ := gin.CreateTestContext(rec)
|
|
|
|
svc := &GatewayService{}
|
|
respBody := []byte(`{"error":{"message":"Invalid schema for field messages"}}`)
|
|
resp := &http.Response{
|
|
StatusCode: http.StatusUnprocessableEntity,
|
|
Body: io.NopCloser(bytes.NewReader(respBody)),
|
|
Header: http.Header{},
|
|
}
|
|
account := &Account{ID: 11, Platform: PlatformAnthropic, Type: AccountTypeAPIKey}
|
|
|
|
_, err := svc.handleErrorResponse(context.Background(), resp, c, account)
|
|
require.Error(t, err)
|
|
assert.Equal(t, http.StatusBadGateway, rec.Code)
|
|
|
|
var payload map[string]any
|
|
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &payload))
|
|
errField, ok := payload["error"].(map[string]any)
|
|
require.True(t, ok)
|
|
assert.Equal(t, "upstream_error", errField["type"])
|
|
assert.Equal(t, "Upstream request failed", errField["message"])
|
|
}
|
|
|
|
func TestOpenAIHandleErrorResponse_NoRuleKeepsDefault(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
rec := httptest.NewRecorder()
|
|
c, _ := gin.CreateTestContext(rec)
|
|
|
|
svc := &OpenAIGatewayService{}
|
|
respBody := []byte(`{"error":{"message":"Invalid schema for field messages"}}`)
|
|
resp := &http.Response{
|
|
StatusCode: http.StatusUnprocessableEntity,
|
|
Body: io.NopCloser(bytes.NewReader(respBody)),
|
|
Header: http.Header{},
|
|
}
|
|
account := &Account{ID: 12, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
|
|
|
|
_, err := svc.handleErrorResponse(context.Background(), resp, c, account)
|
|
require.Error(t, err)
|
|
assert.Equal(t, http.StatusBadGateway, rec.Code)
|
|
|
|
var payload map[string]any
|
|
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &payload))
|
|
errField, ok := payload["error"].(map[string]any)
|
|
require.True(t, ok)
|
|
assert.Equal(t, "upstream_error", errField["type"])
|
|
assert.Equal(t, "Upstream request failed", errField["message"])
|
|
}
|
|
|
|
func TestGeminiWriteGeminiMappedError_NoRuleKeepsDefault(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
rec := httptest.NewRecorder()
|
|
c, _ := gin.CreateTestContext(rec)
|
|
|
|
svc := &GeminiMessagesCompatService{}
|
|
respBody := []byte(`{"error":{"code":422,"message":"Invalid schema for field messages","status":"INVALID_ARGUMENT"}}`)
|
|
account := &Account{ID: 13, Platform: PlatformGemini, Type: AccountTypeAPIKey}
|
|
|
|
err := svc.writeGeminiMappedError(c, account, http.StatusUnprocessableEntity, "req-2", respBody)
|
|
require.Error(t, err)
|
|
assert.Equal(t, http.StatusBadRequest, rec.Code)
|
|
|
|
var payload map[string]any
|
|
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &payload))
|
|
errField, ok := payload["error"].(map[string]any)
|
|
require.True(t, ok)
|
|
assert.Equal(t, "invalid_request_error", errField["type"])
|
|
assert.Equal(t, "Upstream request failed", errField["message"])
|
|
}
|
|
|
|
func TestGatewayHandleErrorResponse_AppliesRuleFor422(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
rec := httptest.NewRecorder()
|
|
c, _ := gin.CreateTestContext(rec)
|
|
|
|
ruleSvc := &ErrorPassthroughService{}
|
|
ruleSvc.setLocalCache([]*model.ErrorPassthroughRule{newNonFailoverPassthroughRule(http.StatusUnprocessableEntity, "invalid schema", http.StatusTeapot, "上游请求失败")})
|
|
BindErrorPassthroughService(c, ruleSvc)
|
|
|
|
svc := &GatewayService{}
|
|
respBody := []byte(`{"error":{"message":"Invalid schema for field messages"}}`)
|
|
resp := &http.Response{
|
|
StatusCode: http.StatusUnprocessableEntity,
|
|
Body: io.NopCloser(bytes.NewReader(respBody)),
|
|
Header: http.Header{},
|
|
}
|
|
account := &Account{ID: 1, Platform: PlatformAnthropic, Type: AccountTypeAPIKey}
|
|
|
|
_, err := svc.handleErrorResponse(context.Background(), resp, c, account)
|
|
require.Error(t, err)
|
|
assert.Equal(t, http.StatusTeapot, rec.Code)
|
|
|
|
var payload map[string]any
|
|
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &payload))
|
|
errField, ok := payload["error"].(map[string]any)
|
|
require.True(t, ok)
|
|
assert.Equal(t, "upstream_error", errField["type"])
|
|
assert.Equal(t, "上游请求失败", errField["message"])
|
|
}
|
|
|
|
func TestOpenAIHandleErrorResponse_AppliesRuleFor422(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
rec := httptest.NewRecorder()
|
|
c, _ := gin.CreateTestContext(rec)
|
|
|
|
ruleSvc := &ErrorPassthroughService{}
|
|
ruleSvc.setLocalCache([]*model.ErrorPassthroughRule{newNonFailoverPassthroughRule(http.StatusUnprocessableEntity, "invalid schema", http.StatusTeapot, "OpenAI上游失败")})
|
|
BindErrorPassthroughService(c, ruleSvc)
|
|
|
|
svc := &OpenAIGatewayService{}
|
|
respBody := []byte(`{"error":{"message":"Invalid schema for field messages"}}`)
|
|
resp := &http.Response{
|
|
StatusCode: http.StatusUnprocessableEntity,
|
|
Body: io.NopCloser(bytes.NewReader(respBody)),
|
|
Header: http.Header{},
|
|
}
|
|
account := &Account{ID: 2, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
|
|
|
|
_, err := svc.handleErrorResponse(context.Background(), resp, c, account)
|
|
require.Error(t, err)
|
|
assert.Equal(t, http.StatusTeapot, rec.Code)
|
|
|
|
var payload map[string]any
|
|
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &payload))
|
|
errField, ok := payload["error"].(map[string]any)
|
|
require.True(t, ok)
|
|
assert.Equal(t, "upstream_error", errField["type"])
|
|
assert.Equal(t, "OpenAI上游失败", errField["message"])
|
|
}
|
|
|
|
func TestGeminiWriteGeminiMappedError_AppliesRuleFor422(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
rec := httptest.NewRecorder()
|
|
c, _ := gin.CreateTestContext(rec)
|
|
|
|
ruleSvc := &ErrorPassthroughService{}
|
|
ruleSvc.setLocalCache([]*model.ErrorPassthroughRule{newNonFailoverPassthroughRule(http.StatusUnprocessableEntity, "invalid schema", http.StatusTeapot, "Gemini上游失败")})
|
|
BindErrorPassthroughService(c, ruleSvc)
|
|
|
|
svc := &GeminiMessagesCompatService{}
|
|
respBody := []byte(`{"error":{"code":422,"message":"Invalid schema for field messages","status":"INVALID_ARGUMENT"}}`)
|
|
account := &Account{ID: 3, Platform: PlatformGemini, Type: AccountTypeAPIKey}
|
|
|
|
err := svc.writeGeminiMappedError(c, account, http.StatusUnprocessableEntity, "req-1", respBody)
|
|
require.Error(t, err)
|
|
assert.Equal(t, http.StatusTeapot, rec.Code)
|
|
|
|
var payload map[string]any
|
|
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &payload))
|
|
errField, ok := payload["error"].(map[string]any)
|
|
require.True(t, ok)
|
|
assert.Equal(t, "upstream_error", errField["type"])
|
|
assert.Equal(t, "Gemini上游失败", errField["message"])
|
|
}
|
|
|
|
func newNonFailoverPassthroughRule(statusCode int, keyword string, respCode int, customMessage string) *model.ErrorPassthroughRule {
|
|
return &model.ErrorPassthroughRule{
|
|
ID: 1,
|
|
Name: "non-failover-rule",
|
|
Enabled: true,
|
|
Priority: 1,
|
|
ErrorCodes: []int{statusCode},
|
|
Keywords: []string{keyword},
|
|
MatchMode: model.MatchModeAll,
|
|
PassthroughCode: false,
|
|
ResponseCode: &respCode,
|
|
PassthroughBody: false,
|
|
CustomMessage: &customMessage,
|
|
}
|
|
}
|