feat(sync): full code sync from release
This commit is contained in:
@@ -1337,6 +1337,34 @@ func (h *AccountHandler) GetTodayStats(c *gin.Context) {
|
||||
response.Success(c, stats)
|
||||
}
|
||||
|
||||
// BatchTodayStatsRequest 批量今日统计请求体。
|
||||
type BatchTodayStatsRequest struct {
|
||||
AccountIDs []int64 `json:"account_ids" binding:"required"`
|
||||
}
|
||||
|
||||
// GetBatchTodayStats 批量获取多个账号的今日统计。
|
||||
// POST /api/v1/admin/accounts/today-stats/batch
|
||||
func (h *AccountHandler) GetBatchTodayStats(c *gin.Context) {
|
||||
var req BatchTodayStatsRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.AccountIDs) == 0 {
|
||||
response.Success(c, gin.H{"stats": map[string]any{}})
|
||||
return
|
||||
}
|
||||
|
||||
stats, err := h.accountUsageService.GetTodayStatsBatch(c.Request.Context(), req.AccountIDs)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"stats": stats})
|
||||
}
|
||||
|
||||
// SetSchedulableRequest represents the request body for setting schedulable status
|
||||
type SetSchedulableRequest struct {
|
||||
Schedulable bool `json:"schedulable"`
|
||||
|
||||
@@ -3,6 +3,7 @@ package admin
|
||||
import (
|
||||
"errors"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
@@ -186,7 +187,7 @@ func (h *DashboardHandler) GetRealtimeMetrics(c *gin.Context) {
|
||||
|
||||
// GetUsageTrend handles getting usage trend data
|
||||
// GET /api/v1/admin/dashboard/trend
|
||||
// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), user_id, api_key_id, model, account_id, group_id, stream, billing_type
|
||||
// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), user_id, api_key_id, model, account_id, group_id, request_type, stream, billing_type
|
||||
func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
|
||||
startTime, endTime := parseTimeRange(c)
|
||||
granularity := c.DefaultQuery("granularity", "day")
|
||||
@@ -194,6 +195,7 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
|
||||
// Parse optional filter params
|
||||
var userID, apiKeyID, accountID, groupID int64
|
||||
var model string
|
||||
var requestType *int16
|
||||
var stream *bool
|
||||
var billingType *int8
|
||||
|
||||
@@ -220,9 +222,20 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
|
||||
if modelStr := c.Query("model"); modelStr != "" {
|
||||
model = modelStr
|
||||
}
|
||||
if streamStr := c.Query("stream"); streamStr != "" {
|
||||
if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" {
|
||||
parsed, err := service.ParseUsageRequestType(requestTypeStr)
|
||||
if err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
value := int16(parsed)
|
||||
requestType = &value
|
||||
} else if streamStr := c.Query("stream"); streamStr != "" {
|
||||
if streamVal, err := strconv.ParseBool(streamStr); err == nil {
|
||||
stream = &streamVal
|
||||
} else {
|
||||
response.BadRequest(c, "Invalid stream value, use true or false")
|
||||
return
|
||||
}
|
||||
}
|
||||
if billingTypeStr := c.Query("billing_type"); billingTypeStr != "" {
|
||||
@@ -235,7 +248,7 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
trend, err := h.dashboardService.GetUsageTrendWithFilters(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, stream, billingType)
|
||||
trend, err := h.dashboardService.GetUsageTrendWithFilters(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get usage trend")
|
||||
return
|
||||
@@ -251,12 +264,13 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
|
||||
|
||||
// GetModelStats handles getting model usage statistics
|
||||
// GET /api/v1/admin/dashboard/models
|
||||
// Query params: start_date, end_date (YYYY-MM-DD), user_id, api_key_id, account_id, group_id, stream, billing_type
|
||||
// Query params: start_date, end_date (YYYY-MM-DD), user_id, api_key_id, account_id, group_id, request_type, stream, billing_type
|
||||
func (h *DashboardHandler) GetModelStats(c *gin.Context) {
|
||||
startTime, endTime := parseTimeRange(c)
|
||||
|
||||
// Parse optional filter params
|
||||
var userID, apiKeyID, accountID, groupID int64
|
||||
var requestType *int16
|
||||
var stream *bool
|
||||
var billingType *int8
|
||||
|
||||
@@ -280,9 +294,20 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) {
|
||||
groupID = id
|
||||
}
|
||||
}
|
||||
if streamStr := c.Query("stream"); streamStr != "" {
|
||||
if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" {
|
||||
parsed, err := service.ParseUsageRequestType(requestTypeStr)
|
||||
if err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
value := int16(parsed)
|
||||
requestType = &value
|
||||
} else if streamStr := c.Query("stream"); streamStr != "" {
|
||||
if streamVal, err := strconv.ParseBool(streamStr); err == nil {
|
||||
stream = &streamVal
|
||||
} else {
|
||||
response.BadRequest(c, "Invalid stream value, use true or false")
|
||||
return
|
||||
}
|
||||
}
|
||||
if billingTypeStr := c.Query("billing_type"); billingTypeStr != "" {
|
||||
@@ -295,7 +320,7 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
stats, err := h.dashboardService.GetModelStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, stream, billingType)
|
||||
stats, err := h.dashboardService.GetModelStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get model statistics")
|
||||
return
|
||||
|
||||
@@ -0,0 +1,132 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type dashboardUsageRepoCapture struct {
|
||||
service.UsageLogRepository
|
||||
trendRequestType *int16
|
||||
trendStream *bool
|
||||
modelRequestType *int16
|
||||
modelStream *bool
|
||||
}
|
||||
|
||||
func (s *dashboardUsageRepoCapture) GetUsageTrendWithFilters(
|
||||
ctx context.Context,
|
||||
startTime, endTime time.Time,
|
||||
granularity string,
|
||||
userID, apiKeyID, accountID, groupID int64,
|
||||
model string,
|
||||
requestType *int16,
|
||||
stream *bool,
|
||||
billingType *int8,
|
||||
) ([]usagestats.TrendDataPoint, error) {
|
||||
s.trendRequestType = requestType
|
||||
s.trendStream = stream
|
||||
return []usagestats.TrendDataPoint{}, nil
|
||||
}
|
||||
|
||||
func (s *dashboardUsageRepoCapture) GetModelStatsWithFilters(
|
||||
ctx context.Context,
|
||||
startTime, endTime time.Time,
|
||||
userID, apiKeyID, accountID, groupID int64,
|
||||
requestType *int16,
|
||||
stream *bool,
|
||||
billingType *int8,
|
||||
) ([]usagestats.ModelStat, error) {
|
||||
s.modelRequestType = requestType
|
||||
s.modelStream = stream
|
||||
return []usagestats.ModelStat{}, nil
|
||||
}
|
||||
|
||||
func newDashboardRequestTypeTestRouter(repo *dashboardUsageRepoCapture) *gin.Engine {
|
||||
gin.SetMode(gin.TestMode)
|
||||
dashboardSvc := service.NewDashboardService(repo, nil, nil, nil)
|
||||
handler := NewDashboardHandler(dashboardSvc, nil)
|
||||
router := gin.New()
|
||||
router.GET("/admin/dashboard/trend", handler.GetUsageTrend)
|
||||
router.GET("/admin/dashboard/models", handler.GetModelStats)
|
||||
return router
|
||||
}
|
||||
|
||||
func TestDashboardTrendRequestTypePriority(t *testing.T) {
|
||||
repo := &dashboardUsageRepoCapture{}
|
||||
router := newDashboardRequestTypeTestRouter(repo)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/trend?request_type=ws_v2&stream=bad", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
require.NotNil(t, repo.trendRequestType)
|
||||
require.Equal(t, int16(service.RequestTypeWSV2), *repo.trendRequestType)
|
||||
require.Nil(t, repo.trendStream)
|
||||
}
|
||||
|
||||
func TestDashboardTrendInvalidRequestType(t *testing.T) {
|
||||
repo := &dashboardUsageRepoCapture{}
|
||||
router := newDashboardRequestTypeTestRouter(repo)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/trend?request_type=bad", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
}
|
||||
|
||||
func TestDashboardTrendInvalidStream(t *testing.T) {
|
||||
repo := &dashboardUsageRepoCapture{}
|
||||
router := newDashboardRequestTypeTestRouter(repo)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/trend?stream=bad", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
}
|
||||
|
||||
func TestDashboardModelStatsRequestTypePriority(t *testing.T) {
|
||||
repo := &dashboardUsageRepoCapture{}
|
||||
router := newDashboardRequestTypeTestRouter(repo)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/models?request_type=sync&stream=bad", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
require.NotNil(t, repo.modelRequestType)
|
||||
require.Equal(t, int16(service.RequestTypeSync), *repo.modelRequestType)
|
||||
require.Nil(t, repo.modelStream)
|
||||
}
|
||||
|
||||
func TestDashboardModelStatsInvalidRequestType(t *testing.T) {
|
||||
repo := &dashboardUsageRepoCapture{}
|
||||
router := newDashboardRequestTypeTestRouter(repo)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/models?request_type=bad", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
}
|
||||
|
||||
func TestDashboardModelStatsInvalidStream(t *testing.T) {
|
||||
repo := &dashboardUsageRepoCapture{}
|
||||
router := newDashboardRequestTypeTestRouter(repo)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/models?stream=bad", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
}
|
||||
523
backend/internal/handler/admin/data_management_handler.go
Normal file
523
backend/internal/handler/admin/data_management_handler.go
Normal file
@@ -0,0 +1,523 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type DataManagementHandler struct {
|
||||
dataManagementService *service.DataManagementService
|
||||
}
|
||||
|
||||
func NewDataManagementHandler(dataManagementService *service.DataManagementService) *DataManagementHandler {
|
||||
return &DataManagementHandler{dataManagementService: dataManagementService}
|
||||
}
|
||||
|
||||
type TestS3ConnectionRequest struct {
|
||||
Endpoint string `json:"endpoint"`
|
||||
Region string `json:"region" binding:"required"`
|
||||
Bucket string `json:"bucket" binding:"required"`
|
||||
AccessKeyID string `json:"access_key_id"`
|
||||
SecretAccessKey string `json:"secret_access_key"`
|
||||
Prefix string `json:"prefix"`
|
||||
ForcePathStyle bool `json:"force_path_style"`
|
||||
UseSSL bool `json:"use_ssl"`
|
||||
}
|
||||
|
||||
type CreateBackupJobRequest struct {
|
||||
BackupType string `json:"backup_type" binding:"required,oneof=postgres redis full"`
|
||||
UploadToS3 bool `json:"upload_to_s3"`
|
||||
S3ProfileID string `json:"s3_profile_id"`
|
||||
PostgresID string `json:"postgres_profile_id"`
|
||||
RedisID string `json:"redis_profile_id"`
|
||||
IdempotencyKey string `json:"idempotency_key"`
|
||||
}
|
||||
|
||||
type CreateSourceProfileRequest struct {
|
||||
ProfileID string `json:"profile_id" binding:"required"`
|
||||
Name string `json:"name" binding:"required"`
|
||||
Config service.DataManagementSourceConfig `json:"config" binding:"required"`
|
||||
SetActive bool `json:"set_active"`
|
||||
}
|
||||
|
||||
type UpdateSourceProfileRequest struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
Config service.DataManagementSourceConfig `json:"config" binding:"required"`
|
||||
}
|
||||
|
||||
type CreateS3ProfileRequest struct {
|
||||
ProfileID string `json:"profile_id" binding:"required"`
|
||||
Name string `json:"name" binding:"required"`
|
||||
Enabled bool `json:"enabled"`
|
||||
Endpoint string `json:"endpoint"`
|
||||
Region string `json:"region"`
|
||||
Bucket string `json:"bucket"`
|
||||
AccessKeyID string `json:"access_key_id"`
|
||||
SecretAccessKey string `json:"secret_access_key"`
|
||||
Prefix string `json:"prefix"`
|
||||
ForcePathStyle bool `json:"force_path_style"`
|
||||
UseSSL bool `json:"use_ssl"`
|
||||
SetActive bool `json:"set_active"`
|
||||
}
|
||||
|
||||
type UpdateS3ProfileRequest struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
Enabled bool `json:"enabled"`
|
||||
Endpoint string `json:"endpoint"`
|
||||
Region string `json:"region"`
|
||||
Bucket string `json:"bucket"`
|
||||
AccessKeyID string `json:"access_key_id"`
|
||||
SecretAccessKey string `json:"secret_access_key"`
|
||||
Prefix string `json:"prefix"`
|
||||
ForcePathStyle bool `json:"force_path_style"`
|
||||
UseSSL bool `json:"use_ssl"`
|
||||
}
|
||||
|
||||
func (h *DataManagementHandler) GetAgentHealth(c *gin.Context) {
|
||||
health := h.getAgentHealth(c)
|
||||
payload := gin.H{
|
||||
"enabled": health.Enabled,
|
||||
"reason": health.Reason,
|
||||
"socket_path": health.SocketPath,
|
||||
}
|
||||
if health.Agent != nil {
|
||||
payload["agent"] = gin.H{
|
||||
"status": health.Agent.Status,
|
||||
"version": health.Agent.Version,
|
||||
"uptime_seconds": health.Agent.UptimeSeconds,
|
||||
}
|
||||
}
|
||||
response.Success(c, payload)
|
||||
}
|
||||
|
||||
func (h *DataManagementHandler) GetConfig(c *gin.Context) {
|
||||
if !h.requireAgentEnabled(c) {
|
||||
return
|
||||
}
|
||||
cfg, err := h.dataManagementService.GetConfig(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, cfg)
|
||||
}
|
||||
|
||||
func (h *DataManagementHandler) UpdateConfig(c *gin.Context) {
|
||||
var req service.DataManagementConfig
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if !h.requireAgentEnabled(c) {
|
||||
return
|
||||
}
|
||||
cfg, err := h.dataManagementService.UpdateConfig(c.Request.Context(), req)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, cfg)
|
||||
}
|
||||
|
||||
func (h *DataManagementHandler) TestS3(c *gin.Context) {
|
||||
var req TestS3ConnectionRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if !h.requireAgentEnabled(c) {
|
||||
return
|
||||
}
|
||||
result, err := h.dataManagementService.ValidateS3(c.Request.Context(), service.DataManagementS3Config{
|
||||
Enabled: true,
|
||||
Endpoint: req.Endpoint,
|
||||
Region: req.Region,
|
||||
Bucket: req.Bucket,
|
||||
AccessKeyID: req.AccessKeyID,
|
||||
SecretAccessKey: req.SecretAccessKey,
|
||||
Prefix: req.Prefix,
|
||||
ForcePathStyle: req.ForcePathStyle,
|
||||
UseSSL: req.UseSSL,
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"ok": result.OK, "message": result.Message})
|
||||
}
|
||||
|
||||
func (h *DataManagementHandler) CreateBackupJob(c *gin.Context) {
|
||||
var req CreateBackupJobRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
req.IdempotencyKey = normalizeBackupIdempotencyKey(c.GetHeader("X-Idempotency-Key"), req.IdempotencyKey)
|
||||
if !h.requireAgentEnabled(c) {
|
||||
return
|
||||
}
|
||||
|
||||
triggeredBy := "admin:unknown"
|
||||
if subject, ok := middleware2.GetAuthSubjectFromContext(c); ok {
|
||||
triggeredBy = "admin:" + strconv.FormatInt(subject.UserID, 10)
|
||||
}
|
||||
job, err := h.dataManagementService.CreateBackupJob(c.Request.Context(), service.DataManagementCreateBackupJobInput{
|
||||
BackupType: req.BackupType,
|
||||
UploadToS3: req.UploadToS3,
|
||||
S3ProfileID: req.S3ProfileID,
|
||||
PostgresID: req.PostgresID,
|
||||
RedisID: req.RedisID,
|
||||
TriggeredBy: triggeredBy,
|
||||
IdempotencyKey: req.IdempotencyKey,
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"job_id": job.JobID, "status": job.Status})
|
||||
}
|
||||
|
||||
func (h *DataManagementHandler) ListSourceProfiles(c *gin.Context) {
|
||||
sourceType := strings.TrimSpace(c.Param("source_type"))
|
||||
if sourceType == "" {
|
||||
response.BadRequest(c, "Invalid source_type")
|
||||
return
|
||||
}
|
||||
if sourceType != "postgres" && sourceType != "redis" {
|
||||
response.BadRequest(c, "source_type must be postgres or redis")
|
||||
return
|
||||
}
|
||||
|
||||
if !h.requireAgentEnabled(c) {
|
||||
return
|
||||
}
|
||||
items, err := h.dataManagementService.ListSourceProfiles(c.Request.Context(), sourceType)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"items": items})
|
||||
}
|
||||
|
||||
func (h *DataManagementHandler) CreateSourceProfile(c *gin.Context) {
|
||||
sourceType := strings.TrimSpace(c.Param("source_type"))
|
||||
if sourceType != "postgres" && sourceType != "redis" {
|
||||
response.BadRequest(c, "source_type must be postgres or redis")
|
||||
return
|
||||
}
|
||||
|
||||
var req CreateSourceProfileRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if !h.requireAgentEnabled(c) {
|
||||
return
|
||||
}
|
||||
profile, err := h.dataManagementService.CreateSourceProfile(c.Request.Context(), service.DataManagementCreateSourceProfileInput{
|
||||
SourceType: sourceType,
|
||||
ProfileID: req.ProfileID,
|
||||
Name: req.Name,
|
||||
Config: req.Config,
|
||||
SetActive: req.SetActive,
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, profile)
|
||||
}
|
||||
|
||||
func (h *DataManagementHandler) UpdateSourceProfile(c *gin.Context) {
|
||||
sourceType := strings.TrimSpace(c.Param("source_type"))
|
||||
if sourceType != "postgres" && sourceType != "redis" {
|
||||
response.BadRequest(c, "source_type must be postgres or redis")
|
||||
return
|
||||
}
|
||||
profileID := strings.TrimSpace(c.Param("profile_id"))
|
||||
if profileID == "" {
|
||||
response.BadRequest(c, "Invalid profile_id")
|
||||
return
|
||||
}
|
||||
|
||||
var req UpdateSourceProfileRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if !h.requireAgentEnabled(c) {
|
||||
return
|
||||
}
|
||||
profile, err := h.dataManagementService.UpdateSourceProfile(c.Request.Context(), service.DataManagementUpdateSourceProfileInput{
|
||||
SourceType: sourceType,
|
||||
ProfileID: profileID,
|
||||
Name: req.Name,
|
||||
Config: req.Config,
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, profile)
|
||||
}
|
||||
|
||||
func (h *DataManagementHandler) DeleteSourceProfile(c *gin.Context) {
|
||||
sourceType := strings.TrimSpace(c.Param("source_type"))
|
||||
if sourceType != "postgres" && sourceType != "redis" {
|
||||
response.BadRequest(c, "source_type must be postgres or redis")
|
||||
return
|
||||
}
|
||||
profileID := strings.TrimSpace(c.Param("profile_id"))
|
||||
if profileID == "" {
|
||||
response.BadRequest(c, "Invalid profile_id")
|
||||
return
|
||||
}
|
||||
|
||||
if !h.requireAgentEnabled(c) {
|
||||
return
|
||||
}
|
||||
if err := h.dataManagementService.DeleteSourceProfile(c.Request.Context(), sourceType, profileID); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"deleted": true})
|
||||
}
|
||||
|
||||
func (h *DataManagementHandler) SetActiveSourceProfile(c *gin.Context) {
|
||||
sourceType := strings.TrimSpace(c.Param("source_type"))
|
||||
if sourceType != "postgres" && sourceType != "redis" {
|
||||
response.BadRequest(c, "source_type must be postgres or redis")
|
||||
return
|
||||
}
|
||||
profileID := strings.TrimSpace(c.Param("profile_id"))
|
||||
if profileID == "" {
|
||||
response.BadRequest(c, "Invalid profile_id")
|
||||
return
|
||||
}
|
||||
|
||||
if !h.requireAgentEnabled(c) {
|
||||
return
|
||||
}
|
||||
profile, err := h.dataManagementService.SetActiveSourceProfile(c.Request.Context(), sourceType, profileID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, profile)
|
||||
}
|
||||
|
||||
func (h *DataManagementHandler) ListS3Profiles(c *gin.Context) {
|
||||
if !h.requireAgentEnabled(c) {
|
||||
return
|
||||
}
|
||||
|
||||
items, err := h.dataManagementService.ListS3Profiles(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"items": items})
|
||||
}
|
||||
|
||||
func (h *DataManagementHandler) CreateS3Profile(c *gin.Context) {
|
||||
var req CreateS3ProfileRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if !h.requireAgentEnabled(c) {
|
||||
return
|
||||
}
|
||||
|
||||
profile, err := h.dataManagementService.CreateS3Profile(c.Request.Context(), service.DataManagementCreateS3ProfileInput{
|
||||
ProfileID: req.ProfileID,
|
||||
Name: req.Name,
|
||||
SetActive: req.SetActive,
|
||||
S3: service.DataManagementS3Config{
|
||||
Enabled: req.Enabled,
|
||||
Endpoint: req.Endpoint,
|
||||
Region: req.Region,
|
||||
Bucket: req.Bucket,
|
||||
AccessKeyID: req.AccessKeyID,
|
||||
SecretAccessKey: req.SecretAccessKey,
|
||||
Prefix: req.Prefix,
|
||||
ForcePathStyle: req.ForcePathStyle,
|
||||
UseSSL: req.UseSSL,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, profile)
|
||||
}
|
||||
|
||||
func (h *DataManagementHandler) UpdateS3Profile(c *gin.Context) {
|
||||
var req UpdateS3ProfileRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
profileID := strings.TrimSpace(c.Param("profile_id"))
|
||||
if profileID == "" {
|
||||
response.BadRequest(c, "Invalid profile_id")
|
||||
return
|
||||
}
|
||||
|
||||
if !h.requireAgentEnabled(c) {
|
||||
return
|
||||
}
|
||||
|
||||
profile, err := h.dataManagementService.UpdateS3Profile(c.Request.Context(), service.DataManagementUpdateS3ProfileInput{
|
||||
ProfileID: profileID,
|
||||
Name: req.Name,
|
||||
S3: service.DataManagementS3Config{
|
||||
Enabled: req.Enabled,
|
||||
Endpoint: req.Endpoint,
|
||||
Region: req.Region,
|
||||
Bucket: req.Bucket,
|
||||
AccessKeyID: req.AccessKeyID,
|
||||
SecretAccessKey: req.SecretAccessKey,
|
||||
Prefix: req.Prefix,
|
||||
ForcePathStyle: req.ForcePathStyle,
|
||||
UseSSL: req.UseSSL,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, profile)
|
||||
}
|
||||
|
||||
func (h *DataManagementHandler) DeleteS3Profile(c *gin.Context) {
|
||||
profileID := strings.TrimSpace(c.Param("profile_id"))
|
||||
if profileID == "" {
|
||||
response.BadRequest(c, "Invalid profile_id")
|
||||
return
|
||||
}
|
||||
|
||||
if !h.requireAgentEnabled(c) {
|
||||
return
|
||||
}
|
||||
if err := h.dataManagementService.DeleteS3Profile(c.Request.Context(), profileID); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"deleted": true})
|
||||
}
|
||||
|
||||
func (h *DataManagementHandler) SetActiveS3Profile(c *gin.Context) {
|
||||
profileID := strings.TrimSpace(c.Param("profile_id"))
|
||||
if profileID == "" {
|
||||
response.BadRequest(c, "Invalid profile_id")
|
||||
return
|
||||
}
|
||||
|
||||
if !h.requireAgentEnabled(c) {
|
||||
return
|
||||
}
|
||||
profile, err := h.dataManagementService.SetActiveS3Profile(c.Request.Context(), profileID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, profile)
|
||||
}
|
||||
|
||||
func (h *DataManagementHandler) ListBackupJobs(c *gin.Context) {
|
||||
if !h.requireAgentEnabled(c) {
|
||||
return
|
||||
}
|
||||
|
||||
pageSize := int32(20)
|
||||
if raw := strings.TrimSpace(c.Query("page_size")); raw != "" {
|
||||
v, err := strconv.Atoi(raw)
|
||||
if err != nil || v <= 0 {
|
||||
response.BadRequest(c, "Invalid page_size")
|
||||
return
|
||||
}
|
||||
pageSize = int32(v)
|
||||
}
|
||||
|
||||
result, err := h.dataManagementService.ListBackupJobs(c.Request.Context(), service.DataManagementListBackupJobsInput{
|
||||
PageSize: pageSize,
|
||||
PageToken: c.Query("page_token"),
|
||||
Status: c.Query("status"),
|
||||
BackupType: c.Query("backup_type"),
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
func (h *DataManagementHandler) GetBackupJob(c *gin.Context) {
|
||||
jobID := strings.TrimSpace(c.Param("job_id"))
|
||||
if jobID == "" {
|
||||
response.BadRequest(c, "Invalid backup job ID")
|
||||
return
|
||||
}
|
||||
|
||||
if !h.requireAgentEnabled(c) {
|
||||
return
|
||||
}
|
||||
job, err := h.dataManagementService.GetBackupJob(c.Request.Context(), jobID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, job)
|
||||
}
|
||||
|
||||
func (h *DataManagementHandler) requireAgentEnabled(c *gin.Context) bool {
|
||||
if h.dataManagementService == nil {
|
||||
err := infraerrors.ServiceUnavailable(
|
||||
service.DataManagementAgentUnavailableReason,
|
||||
"data management agent service is not configured",
|
||||
).WithMetadata(map[string]string{"socket_path": service.DefaultDataManagementAgentSocketPath})
|
||||
response.ErrorFrom(c, err)
|
||||
return false
|
||||
}
|
||||
|
||||
if err := h.dataManagementService.EnsureAgentEnabled(c.Request.Context()); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (h *DataManagementHandler) getAgentHealth(c *gin.Context) service.DataManagementAgentHealth {
|
||||
if h.dataManagementService == nil {
|
||||
return service.DataManagementAgentHealth{
|
||||
Enabled: false,
|
||||
Reason: service.DataManagementAgentUnavailableReason,
|
||||
SocketPath: service.DefaultDataManagementAgentSocketPath,
|
||||
}
|
||||
}
|
||||
return h.dataManagementService.GetAgentHealth(c.Request.Context())
|
||||
}
|
||||
|
||||
func normalizeBackupIdempotencyKey(headerValue, bodyValue string) string {
|
||||
headerKey := strings.TrimSpace(headerValue)
|
||||
if headerKey != "" {
|
||||
return headerKey
|
||||
}
|
||||
return strings.TrimSpace(bodyValue)
|
||||
}
|
||||
@@ -0,0 +1,78 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type apiEnvelope struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Reason string `json:"reason"`
|
||||
Data json.RawMessage `json:"data"`
|
||||
}
|
||||
|
||||
func TestDataManagementHandler_AgentHealthAlways200(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
svc := service.NewDataManagementServiceWithOptions(filepath.Join(t.TempDir(), "missing.sock"), 50*time.Millisecond)
|
||||
h := NewDataManagementHandler(svc)
|
||||
|
||||
r := gin.New()
|
||||
r.GET("/api/v1/admin/data-management/agent/health", h.GetAgentHealth)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/data-management/agent/health", nil)
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var envelope apiEnvelope
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &envelope))
|
||||
require.Equal(t, 0, envelope.Code)
|
||||
|
||||
var data struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Reason string `json:"reason"`
|
||||
SocketPath string `json:"socket_path"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal(envelope.Data, &data))
|
||||
require.False(t, data.Enabled)
|
||||
require.Equal(t, service.DataManagementDeprecatedReason, data.Reason)
|
||||
require.Equal(t, svc.SocketPath(), data.SocketPath)
|
||||
}
|
||||
|
||||
func TestDataManagementHandler_NonHealthRouteReturns503WhenDisabled(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
svc := service.NewDataManagementServiceWithOptions(filepath.Join(t.TempDir(), "missing.sock"), 50*time.Millisecond)
|
||||
h := NewDataManagementHandler(svc)
|
||||
|
||||
r := gin.New()
|
||||
r.GET("/api/v1/admin/data-management/config", h.GetConfig)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/data-management/config", nil)
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusServiceUnavailable, rec.Code)
|
||||
|
||||
var envelope apiEnvelope
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &envelope))
|
||||
require.Equal(t, http.StatusServiceUnavailable, envelope.Code)
|
||||
require.Equal(t, service.DataManagementDeprecatedReason, envelope.Reason)
|
||||
}
|
||||
|
||||
func TestNormalizeBackupIdempotencyKey(t *testing.T) {
|
||||
require.Equal(t, "from-header", normalizeBackupIdempotencyKey("from-header", "from-body"))
|
||||
require.Equal(t, "from-body", normalizeBackupIdempotencyKey(" ", " from-body "))
|
||||
require.Equal(t, "", normalizeBackupIdempotencyKey("", ""))
|
||||
}
|
||||
@@ -51,6 +51,8 @@ type CreateGroupRequest struct {
|
||||
MCPXMLInject *bool `json:"mcp_xml_inject"`
|
||||
// 支持的模型系列(仅 antigravity 平台使用)
|
||||
SupportedModelScopes []string `json:"supported_model_scopes"`
|
||||
// Sora 存储配额
|
||||
SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"`
|
||||
// 从指定分组复制账号(创建后自动绑定)
|
||||
CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"`
|
||||
}
|
||||
@@ -84,6 +86,8 @@ type UpdateGroupRequest struct {
|
||||
MCPXMLInject *bool `json:"mcp_xml_inject"`
|
||||
// 支持的模型系列(仅 antigravity 平台使用)
|
||||
SupportedModelScopes *[]string `json:"supported_model_scopes"`
|
||||
// Sora 存储配额
|
||||
SoraStorageQuotaBytes *int64 `json:"sora_storage_quota_bytes"`
|
||||
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
|
||||
CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"`
|
||||
}
|
||||
@@ -198,6 +202,7 @@ func (h *GroupHandler) Create(c *gin.Context) {
|
||||
ModelRoutingEnabled: req.ModelRoutingEnabled,
|
||||
MCPXMLInject: req.MCPXMLInject,
|
||||
SupportedModelScopes: req.SupportedModelScopes,
|
||||
SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
|
||||
CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs,
|
||||
})
|
||||
if err != nil {
|
||||
@@ -248,6 +253,7 @@ func (h *GroupHandler) Update(c *gin.Context) {
|
||||
ModelRoutingEnabled: req.ModelRoutingEnabled,
|
||||
MCPXMLInject: req.MCPXMLInject,
|
||||
SupportedModelScopes: req.SupportedModelScopes,
|
||||
SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
|
||||
CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs,
|
||||
})
|
||||
if err != nil {
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
@@ -47,7 +48,12 @@ func (h *OpenAIOAuthHandler) GenerateAuthURL(c *gin.Context) {
|
||||
req = OpenAIGenerateAuthURLRequest{}
|
||||
}
|
||||
|
||||
result, err := h.openaiOAuthService.GenerateAuthURL(c.Request.Context(), req.ProxyID, req.RedirectURI)
|
||||
result, err := h.openaiOAuthService.GenerateAuthURL(
|
||||
c.Request.Context(),
|
||||
req.ProxyID,
|
||||
req.RedirectURI,
|
||||
oauthPlatformFromPath(c),
|
||||
)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
@@ -123,7 +129,14 @@ func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
tokenInfo, err := h.openaiOAuthService.RefreshTokenWithClientID(c.Request.Context(), refreshToken, proxyURL, strings.TrimSpace(req.ClientID))
|
||||
// 未指定 client_id 时,根据请求路径平台自动设置默认值,避免 repository 层盲猜
|
||||
clientID := strings.TrimSpace(req.ClientID)
|
||||
if clientID == "" {
|
||||
platform := oauthPlatformFromPath(c)
|
||||
clientID, _ = openai.OAuthClientConfigByPlatform(platform)
|
||||
}
|
||||
|
||||
tokenInfo, err := h.openaiOAuthService.RefreshTokenWithClientID(c.Request.Context(), refreshToken, proxyURL, clientID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
|
||||
@@ -62,7 +62,8 @@ const (
|
||||
)
|
||||
|
||||
var wsConnCount atomic.Int32
|
||||
var wsConnCountByIP sync.Map // map[string]*atomic.Int32
|
||||
var wsConnCountByIPMu sync.Mutex
|
||||
var wsConnCountByIP = make(map[string]int32)
|
||||
|
||||
const qpsWSIdleStopDelay = 30 * time.Second
|
||||
|
||||
@@ -389,42 +390,31 @@ func tryAcquireOpsWSIPSlot(clientIP string, limit int32) bool {
|
||||
if strings.TrimSpace(clientIP) == "" || limit <= 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
v, _ := wsConnCountByIP.LoadOrStore(clientIP, &atomic.Int32{})
|
||||
counter, ok := v.(*atomic.Int32)
|
||||
if !ok {
|
||||
wsConnCountByIPMu.Lock()
|
||||
defer wsConnCountByIPMu.Unlock()
|
||||
current := wsConnCountByIP[clientIP]
|
||||
if current >= limit {
|
||||
return false
|
||||
}
|
||||
|
||||
for {
|
||||
current := counter.Load()
|
||||
if current >= limit {
|
||||
return false
|
||||
}
|
||||
if counter.CompareAndSwap(current, current+1) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
wsConnCountByIP[clientIP] = current + 1
|
||||
return true
|
||||
}
|
||||
|
||||
func releaseOpsWSIPSlot(clientIP string) {
|
||||
if strings.TrimSpace(clientIP) == "" {
|
||||
return
|
||||
}
|
||||
|
||||
v, ok := wsConnCountByIP.Load(clientIP)
|
||||
wsConnCountByIPMu.Lock()
|
||||
defer wsConnCountByIPMu.Unlock()
|
||||
current, ok := wsConnCountByIP[clientIP]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
counter, ok := v.(*atomic.Int32)
|
||||
if !ok {
|
||||
if current <= 1 {
|
||||
delete(wsConnCountByIP, clientIP)
|
||||
return
|
||||
}
|
||||
next := counter.Add(-1)
|
||||
if next <= 0 {
|
||||
// Best-effort cleanup; safe even if a new slot was acquired concurrently.
|
||||
wsConnCountByIP.Delete(clientIP)
|
||||
}
|
||||
wsConnCountByIP[clientIP] = current - 1
|
||||
}
|
||||
|
||||
func handleQPSWebSocket(parentCtx context.Context, conn *websocket.Conn) {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -20,15 +21,17 @@ type SettingHandler struct {
|
||||
emailService *service.EmailService
|
||||
turnstileService *service.TurnstileService
|
||||
opsService *service.OpsService
|
||||
soraS3Storage *service.SoraS3Storage
|
||||
}
|
||||
|
||||
// NewSettingHandler 创建系统设置处理器
|
||||
func NewSettingHandler(settingService *service.SettingService, emailService *service.EmailService, turnstileService *service.TurnstileService, opsService *service.OpsService) *SettingHandler {
|
||||
func NewSettingHandler(settingService *service.SettingService, emailService *service.EmailService, turnstileService *service.TurnstileService, opsService *service.OpsService, soraS3Storage *service.SoraS3Storage) *SettingHandler {
|
||||
return &SettingHandler{
|
||||
settingService: settingService,
|
||||
emailService: emailService,
|
||||
turnstileService: turnstileService,
|
||||
opsService: opsService,
|
||||
soraS3Storage: soraS3Storage,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -76,6 +79,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
HideCcsImportButton: settings.HideCcsImportButton,
|
||||
PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
|
||||
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
|
||||
SoraClientEnabled: settings.SoraClientEnabled,
|
||||
DefaultConcurrency: settings.DefaultConcurrency,
|
||||
DefaultBalance: settings.DefaultBalance,
|
||||
EnableModelFallback: settings.EnableModelFallback,
|
||||
@@ -133,6 +137,7 @@ type UpdateSettingsRequest struct {
|
||||
HideCcsImportButton bool `json:"hide_ccs_import_button"`
|
||||
PurchaseSubscriptionEnabled *bool `json:"purchase_subscription_enabled"`
|
||||
PurchaseSubscriptionURL *string `json:"purchase_subscription_url"`
|
||||
SoraClientEnabled bool `json:"sora_client_enabled"`
|
||||
|
||||
// 默认配置
|
||||
DefaultConcurrency int `json:"default_concurrency"`
|
||||
@@ -319,6 +324,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
HideCcsImportButton: req.HideCcsImportButton,
|
||||
PurchaseSubscriptionEnabled: purchaseEnabled,
|
||||
PurchaseSubscriptionURL: purchaseURL,
|
||||
SoraClientEnabled: req.SoraClientEnabled,
|
||||
DefaultConcurrency: req.DefaultConcurrency,
|
||||
DefaultBalance: req.DefaultBalance,
|
||||
EnableModelFallback: req.EnableModelFallback,
|
||||
@@ -400,6 +406,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
HideCcsImportButton: updatedSettings.HideCcsImportButton,
|
||||
PurchaseSubscriptionEnabled: updatedSettings.PurchaseSubscriptionEnabled,
|
||||
PurchaseSubscriptionURL: updatedSettings.PurchaseSubscriptionURL,
|
||||
SoraClientEnabled: updatedSettings.SoraClientEnabled,
|
||||
DefaultConcurrency: updatedSettings.DefaultConcurrency,
|
||||
DefaultBalance: updatedSettings.DefaultBalance,
|
||||
EnableModelFallback: updatedSettings.EnableModelFallback,
|
||||
@@ -750,6 +757,384 @@ func (h *SettingHandler) GetStreamTimeoutSettings(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
func toSoraS3SettingsDTO(settings *service.SoraS3Settings) dto.SoraS3Settings {
|
||||
if settings == nil {
|
||||
return dto.SoraS3Settings{}
|
||||
}
|
||||
return dto.SoraS3Settings{
|
||||
Enabled: settings.Enabled,
|
||||
Endpoint: settings.Endpoint,
|
||||
Region: settings.Region,
|
||||
Bucket: settings.Bucket,
|
||||
AccessKeyID: settings.AccessKeyID,
|
||||
SecretAccessKeyConfigured: settings.SecretAccessKeyConfigured,
|
||||
Prefix: settings.Prefix,
|
||||
ForcePathStyle: settings.ForcePathStyle,
|
||||
CDNURL: settings.CDNURL,
|
||||
DefaultStorageQuotaBytes: settings.DefaultStorageQuotaBytes,
|
||||
}
|
||||
}
|
||||
|
||||
func toSoraS3ProfileDTO(profile service.SoraS3Profile) dto.SoraS3Profile {
|
||||
return dto.SoraS3Profile{
|
||||
ProfileID: profile.ProfileID,
|
||||
Name: profile.Name,
|
||||
IsActive: profile.IsActive,
|
||||
Enabled: profile.Enabled,
|
||||
Endpoint: profile.Endpoint,
|
||||
Region: profile.Region,
|
||||
Bucket: profile.Bucket,
|
||||
AccessKeyID: profile.AccessKeyID,
|
||||
SecretAccessKeyConfigured: profile.SecretAccessKeyConfigured,
|
||||
Prefix: profile.Prefix,
|
||||
ForcePathStyle: profile.ForcePathStyle,
|
||||
CDNURL: profile.CDNURL,
|
||||
DefaultStorageQuotaBytes: profile.DefaultStorageQuotaBytes,
|
||||
UpdatedAt: profile.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func validateSoraS3RequiredWhenEnabled(enabled bool, endpoint, bucket, accessKeyID, secretAccessKey string, hasStoredSecret bool) error {
|
||||
if !enabled {
|
||||
return nil
|
||||
}
|
||||
if strings.TrimSpace(endpoint) == "" {
|
||||
return fmt.Errorf("S3 Endpoint is required when enabled")
|
||||
}
|
||||
if strings.TrimSpace(bucket) == "" {
|
||||
return fmt.Errorf("S3 Bucket is required when enabled")
|
||||
}
|
||||
if strings.TrimSpace(accessKeyID) == "" {
|
||||
return fmt.Errorf("S3 Access Key ID is required when enabled")
|
||||
}
|
||||
if strings.TrimSpace(secretAccessKey) != "" || hasStoredSecret {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("S3 Secret Access Key is required when enabled")
|
||||
}
|
||||
|
||||
func findSoraS3ProfileByID(items []service.SoraS3Profile, profileID string) *service.SoraS3Profile {
|
||||
for idx := range items {
|
||||
if items[idx].ProfileID == profileID {
|
||||
return &items[idx]
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetSoraS3Settings 获取 Sora S3 存储配置(兼容旧单配置接口)
|
||||
// GET /api/v1/admin/settings/sora-s3
|
||||
func (h *SettingHandler) GetSoraS3Settings(c *gin.Context) {
|
||||
settings, err := h.settingService.GetSoraS3Settings(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, toSoraS3SettingsDTO(settings))
|
||||
}
|
||||
|
||||
// ListSoraS3Profiles 获取 Sora S3 多配置
|
||||
// GET /api/v1/admin/settings/sora-s3/profiles
|
||||
func (h *SettingHandler) ListSoraS3Profiles(c *gin.Context) {
|
||||
result, err := h.settingService.ListSoraS3Profiles(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
items := make([]dto.SoraS3Profile, 0, len(result.Items))
|
||||
for idx := range result.Items {
|
||||
items = append(items, toSoraS3ProfileDTO(result.Items[idx]))
|
||||
}
|
||||
response.Success(c, dto.ListSoraS3ProfilesResponse{
|
||||
ActiveProfileID: result.ActiveProfileID,
|
||||
Items: items,
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateSoraS3SettingsRequest 更新/测试 Sora S3 配置请求(兼容旧接口)
|
||||
type UpdateSoraS3SettingsRequest struct {
|
||||
ProfileID string `json:"profile_id"`
|
||||
Enabled bool `json:"enabled"`
|
||||
Endpoint string `json:"endpoint"`
|
||||
Region string `json:"region"`
|
||||
Bucket string `json:"bucket"`
|
||||
AccessKeyID string `json:"access_key_id"`
|
||||
SecretAccessKey string `json:"secret_access_key"`
|
||||
Prefix string `json:"prefix"`
|
||||
ForcePathStyle bool `json:"force_path_style"`
|
||||
CDNURL string `json:"cdn_url"`
|
||||
DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"`
|
||||
}
|
||||
|
||||
type CreateSoraS3ProfileRequest struct {
|
||||
ProfileID string `json:"profile_id"`
|
||||
Name string `json:"name"`
|
||||
SetActive bool `json:"set_active"`
|
||||
Enabled bool `json:"enabled"`
|
||||
Endpoint string `json:"endpoint"`
|
||||
Region string `json:"region"`
|
||||
Bucket string `json:"bucket"`
|
||||
AccessKeyID string `json:"access_key_id"`
|
||||
SecretAccessKey string `json:"secret_access_key"`
|
||||
Prefix string `json:"prefix"`
|
||||
ForcePathStyle bool `json:"force_path_style"`
|
||||
CDNURL string `json:"cdn_url"`
|
||||
DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"`
|
||||
}
|
||||
|
||||
type UpdateSoraS3ProfileRequest struct {
|
||||
Name string `json:"name"`
|
||||
Enabled bool `json:"enabled"`
|
||||
Endpoint string `json:"endpoint"`
|
||||
Region string `json:"region"`
|
||||
Bucket string `json:"bucket"`
|
||||
AccessKeyID string `json:"access_key_id"`
|
||||
SecretAccessKey string `json:"secret_access_key"`
|
||||
Prefix string `json:"prefix"`
|
||||
ForcePathStyle bool `json:"force_path_style"`
|
||||
CDNURL string `json:"cdn_url"`
|
||||
DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"`
|
||||
}
|
||||
|
||||
// CreateSoraS3Profile 创建 Sora S3 配置
|
||||
// POST /api/v1/admin/settings/sora-s3/profiles
|
||||
func (h *SettingHandler) CreateSoraS3Profile(c *gin.Context) {
|
||||
var req CreateSoraS3ProfileRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if req.DefaultStorageQuotaBytes < 0 {
|
||||
req.DefaultStorageQuotaBytes = 0
|
||||
}
|
||||
if strings.TrimSpace(req.Name) == "" {
|
||||
response.BadRequest(c, "Name is required")
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(req.ProfileID) == "" {
|
||||
response.BadRequest(c, "Profile ID is required")
|
||||
return
|
||||
}
|
||||
if err := validateSoraS3RequiredWhenEnabled(req.Enabled, req.Endpoint, req.Bucket, req.AccessKeyID, req.SecretAccessKey, false); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
created, err := h.settingService.CreateSoraS3Profile(c.Request.Context(), &service.SoraS3Profile{
|
||||
ProfileID: req.ProfileID,
|
||||
Name: req.Name,
|
||||
Enabled: req.Enabled,
|
||||
Endpoint: req.Endpoint,
|
||||
Region: req.Region,
|
||||
Bucket: req.Bucket,
|
||||
AccessKeyID: req.AccessKeyID,
|
||||
SecretAccessKey: req.SecretAccessKey,
|
||||
Prefix: req.Prefix,
|
||||
ForcePathStyle: req.ForcePathStyle,
|
||||
CDNURL: req.CDNURL,
|
||||
DefaultStorageQuotaBytes: req.DefaultStorageQuotaBytes,
|
||||
}, req.SetActive)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, toSoraS3ProfileDTO(*created))
|
||||
}
|
||||
|
||||
// UpdateSoraS3Profile 更新 Sora S3 配置
|
||||
// PUT /api/v1/admin/settings/sora-s3/profiles/:profile_id
|
||||
func (h *SettingHandler) UpdateSoraS3Profile(c *gin.Context) {
|
||||
profileID := strings.TrimSpace(c.Param("profile_id"))
|
||||
if profileID == "" {
|
||||
response.BadRequest(c, "Profile ID is required")
|
||||
return
|
||||
}
|
||||
|
||||
var req UpdateSoraS3ProfileRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if req.DefaultStorageQuotaBytes < 0 {
|
||||
req.DefaultStorageQuotaBytes = 0
|
||||
}
|
||||
if strings.TrimSpace(req.Name) == "" {
|
||||
response.BadRequest(c, "Name is required")
|
||||
return
|
||||
}
|
||||
|
||||
existingList, err := h.settingService.ListSoraS3Profiles(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
existing := findSoraS3ProfileByID(existingList.Items, profileID)
|
||||
if existing == nil {
|
||||
response.ErrorFrom(c, service.ErrSoraS3ProfileNotFound)
|
||||
return
|
||||
}
|
||||
if err := validateSoraS3RequiredWhenEnabled(req.Enabled, req.Endpoint, req.Bucket, req.AccessKeyID, req.SecretAccessKey, existing.SecretAccessKeyConfigured); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
updated, updateErr := h.settingService.UpdateSoraS3Profile(c.Request.Context(), profileID, &service.SoraS3Profile{
|
||||
Name: req.Name,
|
||||
Enabled: req.Enabled,
|
||||
Endpoint: req.Endpoint,
|
||||
Region: req.Region,
|
||||
Bucket: req.Bucket,
|
||||
AccessKeyID: req.AccessKeyID,
|
||||
SecretAccessKey: req.SecretAccessKey,
|
||||
Prefix: req.Prefix,
|
||||
ForcePathStyle: req.ForcePathStyle,
|
||||
CDNURL: req.CDNURL,
|
||||
DefaultStorageQuotaBytes: req.DefaultStorageQuotaBytes,
|
||||
})
|
||||
if updateErr != nil {
|
||||
response.ErrorFrom(c, updateErr)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, toSoraS3ProfileDTO(*updated))
|
||||
}
|
||||
|
||||
// DeleteSoraS3Profile 删除 Sora S3 配置
|
||||
// DELETE /api/v1/admin/settings/sora-s3/profiles/:profile_id
|
||||
func (h *SettingHandler) DeleteSoraS3Profile(c *gin.Context) {
|
||||
profileID := strings.TrimSpace(c.Param("profile_id"))
|
||||
if profileID == "" {
|
||||
response.BadRequest(c, "Profile ID is required")
|
||||
return
|
||||
}
|
||||
if err := h.settingService.DeleteSoraS3Profile(c.Request.Context(), profileID); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"deleted": true})
|
||||
}
|
||||
|
||||
// SetActiveSoraS3Profile 切换激活 Sora S3 配置
|
||||
// POST /api/v1/admin/settings/sora-s3/profiles/:profile_id/activate
|
||||
func (h *SettingHandler) SetActiveSoraS3Profile(c *gin.Context) {
|
||||
profileID := strings.TrimSpace(c.Param("profile_id"))
|
||||
if profileID == "" {
|
||||
response.BadRequest(c, "Profile ID is required")
|
||||
return
|
||||
}
|
||||
active, err := h.settingService.SetActiveSoraS3Profile(c.Request.Context(), profileID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, toSoraS3ProfileDTO(*active))
|
||||
}
|
||||
|
||||
// UpdateSoraS3Settings 更新 Sora S3 存储配置(兼容旧单配置接口)
|
||||
// PUT /api/v1/admin/settings/sora-s3
|
||||
func (h *SettingHandler) UpdateSoraS3Settings(c *gin.Context) {
|
||||
var req UpdateSoraS3SettingsRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
existing, err := h.settingService.GetSoraS3Settings(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
if req.DefaultStorageQuotaBytes < 0 {
|
||||
req.DefaultStorageQuotaBytes = 0
|
||||
}
|
||||
if err := validateSoraS3RequiredWhenEnabled(req.Enabled, req.Endpoint, req.Bucket, req.AccessKeyID, req.SecretAccessKey, existing.SecretAccessKeyConfigured); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
settings := &service.SoraS3Settings{
|
||||
Enabled: req.Enabled,
|
||||
Endpoint: req.Endpoint,
|
||||
Region: req.Region,
|
||||
Bucket: req.Bucket,
|
||||
AccessKeyID: req.AccessKeyID,
|
||||
SecretAccessKey: req.SecretAccessKey,
|
||||
Prefix: req.Prefix,
|
||||
ForcePathStyle: req.ForcePathStyle,
|
||||
CDNURL: req.CDNURL,
|
||||
DefaultStorageQuotaBytes: req.DefaultStorageQuotaBytes,
|
||||
}
|
||||
if err := h.settingService.SetSoraS3Settings(c.Request.Context(), settings); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
updatedSettings, err := h.settingService.GetSoraS3Settings(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, toSoraS3SettingsDTO(updatedSettings))
|
||||
}
|
||||
|
||||
// TestSoraS3Connection 测试 Sora S3 连接(HeadBucket)
|
||||
// POST /api/v1/admin/settings/sora-s3/test
|
||||
func (h *SettingHandler) TestSoraS3Connection(c *gin.Context) {
|
||||
if h.soraS3Storage == nil {
|
||||
response.Error(c, 500, "S3 存储服务未初始化")
|
||||
return
|
||||
}
|
||||
|
||||
var req UpdateSoraS3SettingsRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
if !req.Enabled {
|
||||
response.BadRequest(c, "S3 未启用,无法测试连接")
|
||||
return
|
||||
}
|
||||
|
||||
if req.SecretAccessKey == "" {
|
||||
if req.ProfileID != "" {
|
||||
profiles, err := h.settingService.ListSoraS3Profiles(c.Request.Context())
|
||||
if err == nil {
|
||||
profile := findSoraS3ProfileByID(profiles.Items, req.ProfileID)
|
||||
if profile != nil {
|
||||
req.SecretAccessKey = profile.SecretAccessKey
|
||||
}
|
||||
}
|
||||
}
|
||||
if req.SecretAccessKey == "" {
|
||||
existing, err := h.settingService.GetSoraS3Settings(c.Request.Context())
|
||||
if err == nil {
|
||||
req.SecretAccessKey = existing.SecretAccessKey
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
testCfg := &service.SoraS3Settings{
|
||||
Enabled: true,
|
||||
Endpoint: req.Endpoint,
|
||||
Region: req.Region,
|
||||
Bucket: req.Bucket,
|
||||
AccessKeyID: req.AccessKeyID,
|
||||
SecretAccessKey: req.SecretAccessKey,
|
||||
Prefix: req.Prefix,
|
||||
ForcePathStyle: req.ForcePathStyle,
|
||||
CDNURL: req.CDNURL,
|
||||
}
|
||||
if err := h.soraS3Storage.TestConnectionWithSettings(c.Request.Context(), testCfg); err != nil {
|
||||
response.Error(c, 400, "S3 连接测试失败: "+err.Error())
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"message": "S3 连接成功"})
|
||||
}
|
||||
|
||||
// UpdateStreamTimeoutSettingsRequest 更新流超时配置请求
|
||||
type UpdateStreamTimeoutSettingsRequest struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
|
||||
@@ -225,6 +225,92 @@ func TestUsageHandlerCreateCleanupTaskInvalidEndDate(t *testing.T) {
|
||||
require.Equal(t, http.StatusBadRequest, recorder.Code)
|
||||
}
|
||||
|
||||
func TestUsageHandlerCreateCleanupTaskInvalidRequestType(t *testing.T) {
|
||||
repo := &cleanupRepoStub{}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
|
||||
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
router := setupCleanupRouter(cleanupService, 88)
|
||||
|
||||
payload := map[string]any{
|
||||
"start_date": "2024-01-01",
|
||||
"end_date": "2024-01-02",
|
||||
"timezone": "UTC",
|
||||
"request_type": "invalid",
|
||||
}
|
||||
body, err := json.Marshal(payload)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
recorder := httptest.NewRecorder()
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, recorder.Code)
|
||||
}
|
||||
|
||||
func TestUsageHandlerCreateCleanupTaskRequestTypePriority(t *testing.T) {
|
||||
repo := &cleanupRepoStub{}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
|
||||
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
router := setupCleanupRouter(cleanupService, 99)
|
||||
|
||||
payload := map[string]any{
|
||||
"start_date": "2024-01-01",
|
||||
"end_date": "2024-01-02",
|
||||
"timezone": "UTC",
|
||||
"request_type": "ws_v2",
|
||||
"stream": false,
|
||||
}
|
||||
body, err := json.Marshal(payload)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
recorder := httptest.NewRecorder()
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, recorder.Code)
|
||||
|
||||
repo.mu.Lock()
|
||||
defer repo.mu.Unlock()
|
||||
require.Len(t, repo.created, 1)
|
||||
created := repo.created[0]
|
||||
require.NotNil(t, created.Filters.RequestType)
|
||||
require.Equal(t, int16(service.RequestTypeWSV2), *created.Filters.RequestType)
|
||||
require.Nil(t, created.Filters.Stream)
|
||||
}
|
||||
|
||||
func TestUsageHandlerCreateCleanupTaskWithLegacyStream(t *testing.T) {
|
||||
repo := &cleanupRepoStub{}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
|
||||
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
router := setupCleanupRouter(cleanupService, 99)
|
||||
|
||||
payload := map[string]any{
|
||||
"start_date": "2024-01-01",
|
||||
"end_date": "2024-01-02",
|
||||
"timezone": "UTC",
|
||||
"stream": true,
|
||||
}
|
||||
body, err := json.Marshal(payload)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
recorder := httptest.NewRecorder()
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, recorder.Code)
|
||||
|
||||
repo.mu.Lock()
|
||||
defer repo.mu.Unlock()
|
||||
require.Len(t, repo.created, 1)
|
||||
created := repo.created[0]
|
||||
require.Nil(t, created.Filters.RequestType)
|
||||
require.NotNil(t, created.Filters.Stream)
|
||||
require.True(t, *created.Filters.Stream)
|
||||
}
|
||||
|
||||
func TestUsageHandlerCreateCleanupTaskSuccess(t *testing.T) {
|
||||
repo := &cleanupRepoStub{}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
|
||||
|
||||
@@ -51,6 +51,7 @@ type CreateUsageCleanupTaskRequest struct {
|
||||
AccountID *int64 `json:"account_id"`
|
||||
GroupID *int64 `json:"group_id"`
|
||||
Model *string `json:"model"`
|
||||
RequestType *string `json:"request_type"`
|
||||
Stream *bool `json:"stream"`
|
||||
BillingType *int8 `json:"billing_type"`
|
||||
Timezone string `json:"timezone"`
|
||||
@@ -101,8 +102,17 @@ func (h *UsageHandler) List(c *gin.Context) {
|
||||
|
||||
model := c.Query("model")
|
||||
|
||||
var requestType *int16
|
||||
var stream *bool
|
||||
if streamStr := c.Query("stream"); streamStr != "" {
|
||||
if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" {
|
||||
parsed, err := service.ParseUsageRequestType(requestTypeStr)
|
||||
if err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
value := int16(parsed)
|
||||
requestType = &value
|
||||
} else if streamStr := c.Query("stream"); streamStr != "" {
|
||||
val, err := strconv.ParseBool(streamStr)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid stream value, use true or false")
|
||||
@@ -152,6 +162,7 @@ func (h *UsageHandler) List(c *gin.Context) {
|
||||
AccountID: accountID,
|
||||
GroupID: groupID,
|
||||
Model: model,
|
||||
RequestType: requestType,
|
||||
Stream: stream,
|
||||
BillingType: billingType,
|
||||
StartTime: startTime,
|
||||
@@ -214,8 +225,17 @@ func (h *UsageHandler) Stats(c *gin.Context) {
|
||||
|
||||
model := c.Query("model")
|
||||
|
||||
var requestType *int16
|
||||
var stream *bool
|
||||
if streamStr := c.Query("stream"); streamStr != "" {
|
||||
if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" {
|
||||
parsed, err := service.ParseUsageRequestType(requestTypeStr)
|
||||
if err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
value := int16(parsed)
|
||||
requestType = &value
|
||||
} else if streamStr := c.Query("stream"); streamStr != "" {
|
||||
val, err := strconv.ParseBool(streamStr)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid stream value, use true or false")
|
||||
@@ -278,6 +298,7 @@ func (h *UsageHandler) Stats(c *gin.Context) {
|
||||
AccountID: accountID,
|
||||
GroupID: groupID,
|
||||
Model: model,
|
||||
RequestType: requestType,
|
||||
Stream: stream,
|
||||
BillingType: billingType,
|
||||
StartTime: &startTime,
|
||||
@@ -432,6 +453,19 @@ func (h *UsageHandler) CreateCleanupTask(c *gin.Context) {
|
||||
}
|
||||
endTime = endTime.Add(24*time.Hour - time.Nanosecond)
|
||||
|
||||
var requestType *int16
|
||||
stream := req.Stream
|
||||
if req.RequestType != nil {
|
||||
parsed, err := service.ParseUsageRequestType(*req.RequestType)
|
||||
if err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
value := int16(parsed)
|
||||
requestType = &value
|
||||
stream = nil
|
||||
}
|
||||
|
||||
filters := service.UsageCleanupFilters{
|
||||
StartTime: startTime,
|
||||
EndTime: endTime,
|
||||
@@ -440,7 +474,8 @@ func (h *UsageHandler) CreateCleanupTask(c *gin.Context) {
|
||||
AccountID: req.AccountID,
|
||||
GroupID: req.GroupID,
|
||||
Model: req.Model,
|
||||
Stream: req.Stream,
|
||||
RequestType: requestType,
|
||||
Stream: stream,
|
||||
BillingType: req.BillingType,
|
||||
}
|
||||
|
||||
@@ -464,9 +499,13 @@ func (h *UsageHandler) CreateCleanupTask(c *gin.Context) {
|
||||
if filters.Model != nil {
|
||||
model = *filters.Model
|
||||
}
|
||||
var stream any
|
||||
var streamValue any
|
||||
if filters.Stream != nil {
|
||||
stream = *filters.Stream
|
||||
streamValue = *filters.Stream
|
||||
}
|
||||
var requestTypeName any
|
||||
if filters.RequestType != nil {
|
||||
requestTypeName = service.RequestTypeFromInt16(*filters.RequestType).String()
|
||||
}
|
||||
var billingType any
|
||||
if filters.BillingType != nil {
|
||||
@@ -481,7 +520,7 @@ func (h *UsageHandler) CreateCleanupTask(c *gin.Context) {
|
||||
Body: req,
|
||||
}
|
||||
executeAdminIdempotentJSON(c, "admin.usage.cleanup_tasks.create", idempotencyPayload, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
|
||||
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 请求创建清理任务: operator=%d start=%s end=%s user_id=%v api_key_id=%v account_id=%v group_id=%v model=%v stream=%v billing_type=%v tz=%q",
|
||||
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 请求创建清理任务: operator=%d start=%s end=%s user_id=%v api_key_id=%v account_id=%v group_id=%v model=%v request_type=%v stream=%v billing_type=%v tz=%q",
|
||||
subject.UserID,
|
||||
filters.StartTime.Format(time.RFC3339),
|
||||
filters.EndTime.Format(time.RFC3339),
|
||||
@@ -490,7 +529,8 @@ func (h *UsageHandler) CreateCleanupTask(c *gin.Context) {
|
||||
accountID,
|
||||
groupID,
|
||||
model,
|
||||
stream,
|
||||
requestTypeName,
|
||||
streamValue,
|
||||
billingType,
|
||||
req.Timezone,
|
||||
)
|
||||
|
||||
@@ -0,0 +1,117 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type adminUsageRepoCapture struct {
|
||||
service.UsageLogRepository
|
||||
listFilters usagestats.UsageLogFilters
|
||||
statsFilters usagestats.UsageLogFilters
|
||||
}
|
||||
|
||||
func (s *adminUsageRepoCapture) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters usagestats.UsageLogFilters) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||
s.listFilters = filters
|
||||
return []service.UsageLog{}, &pagination.PaginationResult{
|
||||
Total: 0,
|
||||
Page: params.Page,
|
||||
PageSize: params.PageSize,
|
||||
Pages: 0,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *adminUsageRepoCapture) GetStatsWithFilters(ctx context.Context, filters usagestats.UsageLogFilters) (*usagestats.UsageStats, error) {
|
||||
s.statsFilters = filters
|
||||
return &usagestats.UsageStats{}, nil
|
||||
}
|
||||
|
||||
func newAdminUsageRequestTypeTestRouter(repo *adminUsageRepoCapture) *gin.Engine {
|
||||
gin.SetMode(gin.TestMode)
|
||||
usageSvc := service.NewUsageService(repo, nil, nil, nil)
|
||||
handler := NewUsageHandler(usageSvc, nil, nil, nil)
|
||||
router := gin.New()
|
||||
router.GET("/admin/usage", handler.List)
|
||||
router.GET("/admin/usage/stats", handler.Stats)
|
||||
return router
|
||||
}
|
||||
|
||||
func TestAdminUsageListRequestTypePriority(t *testing.T) {
|
||||
repo := &adminUsageRepoCapture{}
|
||||
router := newAdminUsageRequestTypeTestRouter(repo)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/admin/usage?request_type=ws_v2&stream=false", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
require.NotNil(t, repo.listFilters.RequestType)
|
||||
require.Equal(t, int16(service.RequestTypeWSV2), *repo.listFilters.RequestType)
|
||||
require.Nil(t, repo.listFilters.Stream)
|
||||
}
|
||||
|
||||
func TestAdminUsageListInvalidRequestType(t *testing.T) {
|
||||
repo := &adminUsageRepoCapture{}
|
||||
router := newAdminUsageRequestTypeTestRouter(repo)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/admin/usage?request_type=bad", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
}
|
||||
|
||||
func TestAdminUsageListInvalidStream(t *testing.T) {
|
||||
repo := &adminUsageRepoCapture{}
|
||||
router := newAdminUsageRequestTypeTestRouter(repo)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/admin/usage?stream=bad", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
}
|
||||
|
||||
func TestAdminUsageStatsRequestTypePriority(t *testing.T) {
|
||||
repo := &adminUsageRepoCapture{}
|
||||
router := newAdminUsageRequestTypeTestRouter(repo)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/admin/usage/stats?request_type=stream&stream=bad", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
require.NotNil(t, repo.statsFilters.RequestType)
|
||||
require.Equal(t, int16(service.RequestTypeStream), *repo.statsFilters.RequestType)
|
||||
require.Nil(t, repo.statsFilters.Stream)
|
||||
}
|
||||
|
||||
func TestAdminUsageStatsInvalidRequestType(t *testing.T) {
|
||||
repo := &adminUsageRepoCapture{}
|
||||
router := newAdminUsageRequestTypeTestRouter(repo)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/admin/usage/stats?request_type=oops", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
}
|
||||
|
||||
func TestAdminUsageStatsInvalidStream(t *testing.T) {
|
||||
repo := &adminUsageRepoCapture{}
|
||||
router := newAdminUsageRequestTypeTestRouter(repo)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/admin/usage/stats?stream=oops", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
}
|
||||
@@ -34,13 +34,14 @@ func NewUserHandler(adminService service.AdminService, concurrencyService *servi
|
||||
|
||||
// CreateUserRequest represents admin create user request
|
||||
type CreateUserRequest struct {
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
Password string `json:"password" binding:"required,min=6"`
|
||||
Username string `json:"username"`
|
||||
Notes string `json:"notes"`
|
||||
Balance float64 `json:"balance"`
|
||||
Concurrency int `json:"concurrency"`
|
||||
AllowedGroups []int64 `json:"allowed_groups"`
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
Password string `json:"password" binding:"required,min=6"`
|
||||
Username string `json:"username"`
|
||||
Notes string `json:"notes"`
|
||||
Balance float64 `json:"balance"`
|
||||
Concurrency int `json:"concurrency"`
|
||||
AllowedGroups []int64 `json:"allowed_groups"`
|
||||
SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"`
|
||||
}
|
||||
|
||||
// UpdateUserRequest represents admin update user request
|
||||
@@ -56,7 +57,8 @@ type UpdateUserRequest struct {
|
||||
AllowedGroups *[]int64 `json:"allowed_groups"`
|
||||
// GroupRates 用户专属分组倍率配置
|
||||
// map[groupID]*rate,nil 表示删除该分组的专属倍率
|
||||
GroupRates map[int64]*float64 `json:"group_rates"`
|
||||
GroupRates map[int64]*float64 `json:"group_rates"`
|
||||
SoraStorageQuotaBytes *int64 `json:"sora_storage_quota_bytes"`
|
||||
}
|
||||
|
||||
// UpdateBalanceRequest represents balance update request
|
||||
@@ -174,13 +176,14 @@ func (h *UserHandler) Create(c *gin.Context) {
|
||||
}
|
||||
|
||||
user, err := h.adminService.CreateUser(c.Request.Context(), &service.CreateUserInput{
|
||||
Email: req.Email,
|
||||
Password: req.Password,
|
||||
Username: req.Username,
|
||||
Notes: req.Notes,
|
||||
Balance: req.Balance,
|
||||
Concurrency: req.Concurrency,
|
||||
AllowedGroups: req.AllowedGroups,
|
||||
Email: req.Email,
|
||||
Password: req.Password,
|
||||
Username: req.Username,
|
||||
Notes: req.Notes,
|
||||
Balance: req.Balance,
|
||||
Concurrency: req.Concurrency,
|
||||
AllowedGroups: req.AllowedGroups,
|
||||
SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
@@ -207,15 +210,16 @@ func (h *UserHandler) Update(c *gin.Context) {
|
||||
|
||||
// 使用指针类型直接传递,nil 表示未提供该字段
|
||||
user, err := h.adminService.UpdateUser(c.Request.Context(), userID, &service.UpdateUserInput{
|
||||
Email: req.Email,
|
||||
Password: req.Password,
|
||||
Username: req.Username,
|
||||
Notes: req.Notes,
|
||||
Balance: req.Balance,
|
||||
Concurrency: req.Concurrency,
|
||||
Status: req.Status,
|
||||
AllowedGroups: req.AllowedGroups,
|
||||
GroupRates: req.GroupRates,
|
||||
Email: req.Email,
|
||||
Password: req.Password,
|
||||
Username: req.Username,
|
||||
Notes: req.Notes,
|
||||
Balance: req.Balance,
|
||||
Concurrency: req.Concurrency,
|
||||
Status: req.Status,
|
||||
AllowedGroups: req.AllowedGroups,
|
||||
GroupRates: req.GroupRates,
|
||||
SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
|
||||
@@ -113,9 +113,8 @@ func (h *AuthHandler) Register(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Turnstile 验证 — 始终执行,防止绕过
|
||||
// TODO: 确认前端在提交邮箱验证码注册时也传递了 turnstile_token
|
||||
if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, ip.GetClientIP(c)); err != nil {
|
||||
// Turnstile 验证(邮箱验证码注册场景避免重复校验一次性 token)
|
||||
if err := h.authService.VerifyTurnstileForRegister(c.Request.Context(), req.TurnstileToken, ip.GetClientIP(c), req.VerifyCode); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -59,9 +59,11 @@ func UserFromServiceAdmin(u *service.User) *AdminUser {
|
||||
return nil
|
||||
}
|
||||
return &AdminUser{
|
||||
User: *base,
|
||||
Notes: u.Notes,
|
||||
GroupRates: u.GroupRates,
|
||||
User: *base,
|
||||
Notes: u.Notes,
|
||||
GroupRates: u.GroupRates,
|
||||
SoraStorageQuotaBytes: u.SoraStorageQuotaBytes,
|
||||
SoraStorageUsedBytes: u.SoraStorageUsedBytes,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -152,6 +154,7 @@ func groupFromServiceBase(g *service.Group) Group {
|
||||
ClaudeCodeOnly: g.ClaudeCodeOnly,
|
||||
FallbackGroupID: g.FallbackGroupID,
|
||||
FallbackGroupIDOnInvalidRequest: g.FallbackGroupIDOnInvalidRequest,
|
||||
SoraStorageQuotaBytes: g.SoraStorageQuotaBytes,
|
||||
CreatedAt: g.CreatedAt,
|
||||
UpdatedAt: g.UpdatedAt,
|
||||
}
|
||||
@@ -385,6 +388,8 @@ func AccountSummaryFromService(a *service.Account) *AccountSummary {
|
||||
|
||||
func usageLogFromServiceUser(l *service.UsageLog) UsageLog {
|
||||
// 普通用户 DTO:严禁包含管理员字段(例如 account_rate_multiplier、ip_address、account)。
|
||||
requestType := l.EffectiveRequestType()
|
||||
stream, openAIWSMode := service.ApplyLegacyRequestFields(requestType, l.Stream, l.OpenAIWSMode)
|
||||
return UsageLog{
|
||||
ID: l.ID,
|
||||
UserID: l.UserID,
|
||||
@@ -409,7 +414,9 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog {
|
||||
ActualCost: l.ActualCost,
|
||||
RateMultiplier: l.RateMultiplier,
|
||||
BillingType: l.BillingType,
|
||||
Stream: l.Stream,
|
||||
RequestType: requestType.String(),
|
||||
Stream: stream,
|
||||
OpenAIWSMode: openAIWSMode,
|
||||
DurationMs: l.DurationMs,
|
||||
FirstTokenMs: l.FirstTokenMs,
|
||||
ImageCount: l.ImageCount,
|
||||
@@ -464,6 +471,7 @@ func UsageCleanupTaskFromService(task *service.UsageCleanupTask) *UsageCleanupTa
|
||||
AccountID: task.Filters.AccountID,
|
||||
GroupID: task.Filters.GroupID,
|
||||
Model: task.Filters.Model,
|
||||
RequestType: requestTypeStringPtr(task.Filters.RequestType),
|
||||
Stream: task.Filters.Stream,
|
||||
BillingType: task.Filters.BillingType,
|
||||
},
|
||||
@@ -479,6 +487,14 @@ func UsageCleanupTaskFromService(task *service.UsageCleanupTask) *UsageCleanupTa
|
||||
}
|
||||
}
|
||||
|
||||
func requestTypeStringPtr(requestType *int16) *string {
|
||||
if requestType == nil {
|
||||
return nil
|
||||
}
|
||||
value := service.RequestTypeFromInt16(*requestType).String()
|
||||
return &value
|
||||
}
|
||||
|
||||
func SettingFromService(s *service.Setting) *Setting {
|
||||
if s == nil {
|
||||
return nil
|
||||
|
||||
73
backend/internal/handler/dto/mappers_usage_test.go
Normal file
73
backend/internal/handler/dto/mappers_usage_test.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package dto
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestUsageLogFromService_IncludesOpenAIWSMode(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
wsLog := &service.UsageLog{
|
||||
RequestID: "req_1",
|
||||
Model: "gpt-5.3-codex",
|
||||
OpenAIWSMode: true,
|
||||
}
|
||||
httpLog := &service.UsageLog{
|
||||
RequestID: "resp_1",
|
||||
Model: "gpt-5.3-codex",
|
||||
OpenAIWSMode: false,
|
||||
}
|
||||
|
||||
require.True(t, UsageLogFromService(wsLog).OpenAIWSMode)
|
||||
require.False(t, UsageLogFromService(httpLog).OpenAIWSMode)
|
||||
require.True(t, UsageLogFromServiceAdmin(wsLog).OpenAIWSMode)
|
||||
require.False(t, UsageLogFromServiceAdmin(httpLog).OpenAIWSMode)
|
||||
}
|
||||
|
||||
func TestUsageLogFromService_PrefersRequestTypeForLegacyFields(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
log := &service.UsageLog{
|
||||
RequestID: "req_2",
|
||||
Model: "gpt-5.3-codex",
|
||||
RequestType: service.RequestTypeWSV2,
|
||||
Stream: false,
|
||||
OpenAIWSMode: false,
|
||||
}
|
||||
|
||||
userDTO := UsageLogFromService(log)
|
||||
adminDTO := UsageLogFromServiceAdmin(log)
|
||||
|
||||
require.Equal(t, "ws_v2", userDTO.RequestType)
|
||||
require.True(t, userDTO.Stream)
|
||||
require.True(t, userDTO.OpenAIWSMode)
|
||||
require.Equal(t, "ws_v2", adminDTO.RequestType)
|
||||
require.True(t, adminDTO.Stream)
|
||||
require.True(t, adminDTO.OpenAIWSMode)
|
||||
}
|
||||
|
||||
func TestUsageCleanupTaskFromService_RequestTypeMapping(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
requestType := int16(service.RequestTypeStream)
|
||||
task := &service.UsageCleanupTask{
|
||||
ID: 1,
|
||||
Status: service.UsageCleanupStatusPending,
|
||||
Filters: service.UsageCleanupFilters{
|
||||
RequestType: &requestType,
|
||||
},
|
||||
}
|
||||
|
||||
dtoTask := UsageCleanupTaskFromService(task)
|
||||
require.NotNil(t, dtoTask)
|
||||
require.NotNil(t, dtoTask.Filters.RequestType)
|
||||
require.Equal(t, "stream", *dtoTask.Filters.RequestType)
|
||||
}
|
||||
|
||||
func TestRequestTypeStringPtrNil(t *testing.T) {
|
||||
t.Parallel()
|
||||
require.Nil(t, requestTypeStringPtr(nil))
|
||||
}
|
||||
@@ -37,6 +37,7 @@ type SystemSettings struct {
|
||||
HideCcsImportButton bool `json:"hide_ccs_import_button"`
|
||||
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
|
||||
PurchaseSubscriptionURL string `json:"purchase_subscription_url"`
|
||||
SoraClientEnabled bool `json:"sora_client_enabled"`
|
||||
|
||||
DefaultConcurrency int `json:"default_concurrency"`
|
||||
DefaultBalance float64 `json:"default_balance"`
|
||||
@@ -79,9 +80,48 @@ type PublicSettings struct {
|
||||
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
|
||||
PurchaseSubscriptionURL string `json:"purchase_subscription_url"`
|
||||
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
|
||||
SoraClientEnabled bool `json:"sora_client_enabled"`
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
||||
// SoraS3Settings Sora S3 存储配置 DTO(响应用,不含敏感字段)
|
||||
type SoraS3Settings struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Endpoint string `json:"endpoint"`
|
||||
Region string `json:"region"`
|
||||
Bucket string `json:"bucket"`
|
||||
AccessKeyID string `json:"access_key_id"`
|
||||
SecretAccessKeyConfigured bool `json:"secret_access_key_configured"`
|
||||
Prefix string `json:"prefix"`
|
||||
ForcePathStyle bool `json:"force_path_style"`
|
||||
CDNURL string `json:"cdn_url"`
|
||||
DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"`
|
||||
}
|
||||
|
||||
// SoraS3Profile Sora S3 存储配置项 DTO(响应用,不含敏感字段)
|
||||
type SoraS3Profile struct {
|
||||
ProfileID string `json:"profile_id"`
|
||||
Name string `json:"name"`
|
||||
IsActive bool `json:"is_active"`
|
||||
Enabled bool `json:"enabled"`
|
||||
Endpoint string `json:"endpoint"`
|
||||
Region string `json:"region"`
|
||||
Bucket string `json:"bucket"`
|
||||
AccessKeyID string `json:"access_key_id"`
|
||||
SecretAccessKeyConfigured bool `json:"secret_access_key_configured"`
|
||||
Prefix string `json:"prefix"`
|
||||
ForcePathStyle bool `json:"force_path_style"`
|
||||
CDNURL string `json:"cdn_url"`
|
||||
DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"`
|
||||
UpdatedAt string `json:"updated_at"`
|
||||
}
|
||||
|
||||
// ListSoraS3ProfilesResponse Sora S3 配置列表响应
|
||||
type ListSoraS3ProfilesResponse struct {
|
||||
ActiveProfileID string `json:"active_profile_id"`
|
||||
Items []SoraS3Profile `json:"items"`
|
||||
}
|
||||
|
||||
// StreamTimeoutSettings 流超时处理配置 DTO
|
||||
type StreamTimeoutSettings struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
|
||||
@@ -26,7 +26,9 @@ type AdminUser struct {
|
||||
Notes string `json:"notes"`
|
||||
// GroupRates 用户专属分组倍率配置
|
||||
// map[groupID]rateMultiplier
|
||||
GroupRates map[int64]float64 `json:"group_rates,omitempty"`
|
||||
GroupRates map[int64]float64 `json:"group_rates,omitempty"`
|
||||
SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"`
|
||||
SoraStorageUsedBytes int64 `json:"sora_storage_used_bytes"`
|
||||
}
|
||||
|
||||
type APIKey struct {
|
||||
@@ -80,6 +82,9 @@ type Group struct {
|
||||
// 无效请求兜底分组
|
||||
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"`
|
||||
|
||||
// Sora 存储配额
|
||||
SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"`
|
||||
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
@@ -278,10 +283,12 @@ type UsageLog struct {
|
||||
ActualCost float64 `json:"actual_cost"`
|
||||
RateMultiplier float64 `json:"rate_multiplier"`
|
||||
|
||||
BillingType int8 `json:"billing_type"`
|
||||
Stream bool `json:"stream"`
|
||||
DurationMs *int `json:"duration_ms"`
|
||||
FirstTokenMs *int `json:"first_token_ms"`
|
||||
BillingType int8 `json:"billing_type"`
|
||||
RequestType string `json:"request_type"`
|
||||
Stream bool `json:"stream"`
|
||||
OpenAIWSMode bool `json:"openai_ws_mode"`
|
||||
DurationMs *int `json:"duration_ms"`
|
||||
FirstTokenMs *int `json:"first_token_ms"`
|
||||
|
||||
// 图片生成字段
|
||||
ImageCount int `json:"image_count"`
|
||||
@@ -324,6 +331,7 @@ type UsageCleanupFilters struct {
|
||||
AccountID *int64 `json:"account_id,omitempty"`
|
||||
GroupID *int64 `json:"group_id,omitempty"`
|
||||
Model *string `json:"model,omitempty"`
|
||||
RequestType *string `json:"request_type,omitempty"`
|
||||
Stream *bool `json:"stream,omitempty"`
|
||||
BillingType *int8 `json:"billing_type,omitempty"`
|
||||
}
|
||||
|
||||
@@ -2,11 +2,12 @@ package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// TempUnscheduler 用于 HandleFailoverError 中同账号重试耗尽后的临时封禁。
|
||||
@@ -78,8 +79,12 @@ func (s *FailoverState) HandleFailoverError(
|
||||
// 同账号重试:对 RetryableOnSameAccount 的临时性错误,先在同一账号上重试
|
||||
if failoverErr.RetryableOnSameAccount && s.SameAccountRetryCount[accountID] < maxSameAccountRetries {
|
||||
s.SameAccountRetryCount[accountID]++
|
||||
log.Printf("Account %d: retryable error %d, same-account retry %d/%d",
|
||||
accountID, failoverErr.StatusCode, s.SameAccountRetryCount[accountID], maxSameAccountRetries)
|
||||
logger.FromContext(ctx).Warn("gateway.failover_same_account_retry",
|
||||
zap.Int64("account_id", accountID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("same_account_retry_count", s.SameAccountRetryCount[accountID]),
|
||||
zap.Int("same_account_retry_max", maxSameAccountRetries),
|
||||
)
|
||||
if !sleepWithContext(ctx, sameAccountRetryDelay) {
|
||||
return FailoverCanceled
|
||||
}
|
||||
@@ -101,8 +106,12 @@ func (s *FailoverState) HandleFailoverError(
|
||||
|
||||
// 递增切换计数
|
||||
s.SwitchCount++
|
||||
log.Printf("Account %d: upstream error %d, switching account %d/%d",
|
||||
accountID, failoverErr.StatusCode, s.SwitchCount, s.MaxSwitches)
|
||||
logger.FromContext(ctx).Warn("gateway.failover_switch_account",
|
||||
zap.Int64("account_id", accountID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("switch_count", s.SwitchCount),
|
||||
zap.Int("max_switches", s.MaxSwitches),
|
||||
)
|
||||
|
||||
// Antigravity 平台换号线性递增延时
|
||||
if platform == service.PlatformAntigravity {
|
||||
@@ -127,13 +136,18 @@ func (s *FailoverState) HandleSelectionExhausted(ctx context.Context) FailoverAc
|
||||
s.LastFailoverErr.StatusCode == http.StatusServiceUnavailable &&
|
||||
s.SwitchCount <= s.MaxSwitches {
|
||||
|
||||
log.Printf("Antigravity single-account 503 backoff: waiting %v before retry (attempt %d)",
|
||||
singleAccountBackoffDelay, s.SwitchCount)
|
||||
logger.FromContext(ctx).Warn("gateway.failover_single_account_backoff",
|
||||
zap.Duration("backoff_delay", singleAccountBackoffDelay),
|
||||
zap.Int("switch_count", s.SwitchCount),
|
||||
zap.Int("max_switches", s.MaxSwitches),
|
||||
)
|
||||
if !sleepWithContext(ctx, singleAccountBackoffDelay) {
|
||||
return FailoverCanceled
|
||||
}
|
||||
log.Printf("Antigravity single-account 503 retry: clearing failed accounts, retry %d/%d",
|
||||
s.SwitchCount, s.MaxSwitches)
|
||||
logger.FromContext(ctx).Warn("gateway.failover_single_account_retry",
|
||||
zap.Int("switch_count", s.SwitchCount),
|
||||
zap.Int("max_switches", s.MaxSwitches),
|
||||
)
|
||||
s.FailedAccountIDs = make(map[int64]struct{})
|
||||
return FailoverContinue
|
||||
}
|
||||
|
||||
@@ -6,9 +6,10 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
@@ -17,6 +18,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
pkgerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
@@ -27,6 +29,10 @@ import (
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
const gatewayCompatibilityMetricsLogInterval = 1024
|
||||
|
||||
var gatewayCompatibilityMetricsLogCounter atomic.Uint64
|
||||
|
||||
// GatewayHandler handles API gateway requests
|
||||
type GatewayHandler struct {
|
||||
gatewayService *service.GatewayService
|
||||
@@ -109,9 +115,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
zap.Int64("api_key_id", apiKey.ID),
|
||||
zap.Any("group_id", apiKey.GroupID),
|
||||
)
|
||||
defer h.maybeLogCompatibilityFallbackMetrics(reqLog)
|
||||
|
||||
// 读取请求体
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
|
||||
if err != nil {
|
||||
if maxErr, ok := extractMaxBytesError(err); ok {
|
||||
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
|
||||
@@ -140,16 +147,16 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
// 设置 max_tokens=1 + haiku 探测请求标识到 context 中
|
||||
// 必须在 SetClaudeCodeClientContext 之前设置,因为 ClaudeCodeValidator 需要读取此标识进行绕过判断
|
||||
if isMaxTokensOneHaikuRequest(reqModel, parsedReq.MaxTokens, reqStream) {
|
||||
ctx := context.WithValue(c.Request.Context(), ctxkey.IsMaxTokensOneHaikuRequest, true)
|
||||
ctx := service.WithIsMaxTokensOneHaikuRequest(c.Request.Context(), true, h.metadataBridgeEnabled())
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
}
|
||||
|
||||
// 检查是否为 Claude Code 客户端,设置到 context 中
|
||||
SetClaudeCodeClientContext(c, body)
|
||||
// 检查是否为 Claude Code 客户端,设置到 context 中(复用已解析请求,避免二次反序列化)。
|
||||
SetClaudeCodeClientContext(c, body, parsedReq)
|
||||
isClaudeCodeClient := service.IsClaudeCodeClient(c.Request.Context())
|
||||
|
||||
// 在请求上下文中记录 thinking 状态,供 Antigravity 最终模型 key 推导/模型维度限流使用
|
||||
c.Request = c.Request.WithContext(context.WithValue(c.Request.Context(), ctxkey.ThinkingEnabled, parsedReq.ThinkingEnabled))
|
||||
c.Request = c.Request.WithContext(service.WithThinkingEnabled(c.Request.Context(), parsedReq.ThinkingEnabled, h.metadataBridgeEnabled()))
|
||||
|
||||
setOpsRequestContext(c, reqModel, reqStream, body)
|
||||
|
||||
@@ -247,8 +254,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
if apiKey.GroupID != nil {
|
||||
prefetchedGroupID = *apiKey.GroupID
|
||||
}
|
||||
ctx := context.WithValue(c.Request.Context(), ctxkey.PrefetchedStickyAccountID, sessionBoundAccountID)
|
||||
ctx = context.WithValue(ctx, ctxkey.PrefetchedStickyGroupID, prefetchedGroupID)
|
||||
ctx := service.WithPrefetchedStickySession(c.Request.Context(), sessionBoundAccountID, prefetchedGroupID, h.metadataBridgeEnabled())
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
}
|
||||
}
|
||||
@@ -261,7 +267,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
// 单账号分组提前设置 SingleAccountRetry 标记,让 Service 层首次 503 就不设模型限流标记。
|
||||
// 避免单账号分组收到 503 (MODEL_CAPACITY_EXHAUSTED) 时设 29s 限流,导致后续请求连续快速失败。
|
||||
if h.gatewayService.IsSingleAntigravityAccountGroup(c.Request.Context(), apiKey.GroupID) {
|
||||
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
|
||||
ctx := service.WithSingleAccountRetry(c.Request.Context(), true, h.metadataBridgeEnabled())
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
}
|
||||
|
||||
@@ -275,7 +281,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
action := fs.HandleSelectionExhausted(c.Request.Context())
|
||||
switch action {
|
||||
case FailoverContinue:
|
||||
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
|
||||
ctx := service.WithSingleAccountRetry(c.Request.Context(), true, h.metadataBridgeEnabled())
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
continue
|
||||
case FailoverCanceled:
|
||||
@@ -364,7 +370,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
var result *service.ForwardResult
|
||||
requestCtx := c.Request.Context()
|
||||
if fs.SwitchCount > 0 {
|
||||
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, fs.SwitchCount)
|
||||
requestCtx = service.WithAccountSwitchCount(requestCtx, fs.SwitchCount, h.metadataBridgeEnabled())
|
||||
}
|
||||
if account.Platform == service.PlatformAntigravity {
|
||||
result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, reqModel, "generateContent", reqStream, body, hasBoundSession)
|
||||
@@ -439,7 +445,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
// 单账号分组提前设置 SingleAccountRetry 标记,让 Service 层首次 503 就不设模型限流标记。
|
||||
// 避免单账号分组收到 503 (MODEL_CAPACITY_EXHAUSTED) 时设 29s 限流,导致后续请求连续快速失败。
|
||||
if h.gatewayService.IsSingleAntigravityAccountGroup(c.Request.Context(), currentAPIKey.GroupID) {
|
||||
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
|
||||
ctx := service.WithSingleAccountRetry(c.Request.Context(), true, h.metadataBridgeEnabled())
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
}
|
||||
|
||||
@@ -458,7 +464,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
action := fs.HandleSelectionExhausted(c.Request.Context())
|
||||
switch action {
|
||||
case FailoverContinue:
|
||||
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
|
||||
ctx := service.WithSingleAccountRetry(c.Request.Context(), true, h.metadataBridgeEnabled())
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
continue
|
||||
case FailoverCanceled:
|
||||
@@ -547,7 +553,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
var result *service.ForwardResult
|
||||
requestCtx := c.Request.Context()
|
||||
if fs.SwitchCount > 0 {
|
||||
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, fs.SwitchCount)
|
||||
requestCtx = service.WithAccountSwitchCount(requestCtx, fs.SwitchCount, h.metadataBridgeEnabled())
|
||||
}
|
||||
if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey {
|
||||
result, err = h.antigravityGatewayService.Forward(requestCtx, c, account, body, hasBoundSession)
|
||||
@@ -956,20 +962,8 @@ func (h *GatewayHandler) handleStreamingAwareError(c *gin.Context, status int, e
|
||||
// Stream already started, send error as SSE event then close
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if ok {
|
||||
// Send error event in SSE format with proper JSON marshaling
|
||||
errorData := map[string]any{
|
||||
"type": "error",
|
||||
"error": map[string]string{
|
||||
"type": errType,
|
||||
"message": message,
|
||||
},
|
||||
}
|
||||
jsonBytes, err := json.Marshal(errorData)
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
}
|
||||
errorEvent := fmt.Sprintf("data: %s\n\n", string(jsonBytes))
|
||||
// SSE 错误事件固定 schema,使用 Quote 直拼可避免额外 Marshal 分配。
|
||||
errorEvent := `data: {"type":"error","error":{"type":` + strconv.Quote(errType) + `,"message":` + strconv.Quote(message) + `}}` + "\n\n"
|
||||
if _, err := fmt.Fprint(c.Writer, errorEvent); err != nil {
|
||||
_ = c.Error(err)
|
||||
}
|
||||
@@ -1024,9 +1018,10 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
|
||||
zap.Int64("api_key_id", apiKey.ID),
|
||||
zap.Any("group_id", apiKey.GroupID),
|
||||
)
|
||||
defer h.maybeLogCompatibilityFallbackMetrics(reqLog)
|
||||
|
||||
// 读取请求体
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
|
||||
if err != nil {
|
||||
if maxErr, ok := extractMaxBytesError(err); ok {
|
||||
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
|
||||
@@ -1041,9 +1036,6 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 检查是否为 Claude Code 客户端,设置到 context 中
|
||||
SetClaudeCodeClientContext(c, body)
|
||||
|
||||
setOpsRequestContext(c, "", false, body)
|
||||
|
||||
parsedReq, err := service.ParseGatewayRequest(body, domain.PlatformAnthropic)
|
||||
@@ -1051,9 +1043,11 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
|
||||
return
|
||||
}
|
||||
// count_tokens 走 messages 严格校验时,复用已解析请求,避免二次反序列化。
|
||||
SetClaudeCodeClientContext(c, body, parsedReq)
|
||||
reqLog = reqLog.With(zap.String("model", parsedReq.Model), zap.Bool("stream", parsedReq.Stream))
|
||||
// 在请求上下文中记录 thinking 状态,供 Antigravity 最终模型 key 推导/模型维度限流使用
|
||||
c.Request = c.Request.WithContext(context.WithValue(c.Request.Context(), ctxkey.ThinkingEnabled, parsedReq.ThinkingEnabled))
|
||||
c.Request = c.Request.WithContext(service.WithThinkingEnabled(c.Request.Context(), parsedReq.ThinkingEnabled, h.metadataBridgeEnabled()))
|
||||
|
||||
// 验证 model 必填
|
||||
if parsedReq.Model == "" {
|
||||
@@ -1217,24 +1211,8 @@ func sendMockInterceptStream(c *gin.Context, model string, interceptType Interce
|
||||
textDeltas = []string{"New", " Conversation"}
|
||||
}
|
||||
|
||||
// Build message_start event with proper JSON marshaling
|
||||
messageStart := map[string]any{
|
||||
"type": "message_start",
|
||||
"message": map[string]any{
|
||||
"id": msgID,
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": model,
|
||||
"content": []any{},
|
||||
"stop_reason": nil,
|
||||
"stop_sequence": nil,
|
||||
"usage": map[string]int{
|
||||
"input_tokens": 10,
|
||||
"output_tokens": 0,
|
||||
},
|
||||
},
|
||||
}
|
||||
messageStartJSON, _ := json.Marshal(messageStart)
|
||||
// Build message_start event with fixed schema.
|
||||
messageStartJSON := `{"type":"message_start","message":{"id":` + strconv.Quote(msgID) + `,"type":"message","role":"assistant","model":` + strconv.Quote(model) + `,"content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":10,"output_tokens":0}}}`
|
||||
|
||||
// Build events
|
||||
events := []string{
|
||||
@@ -1244,31 +1222,12 @@ func sendMockInterceptStream(c *gin.Context, model string, interceptType Interce
|
||||
|
||||
// Add text deltas
|
||||
for _, text := range textDeltas {
|
||||
delta := map[string]any{
|
||||
"type": "content_block_delta",
|
||||
"index": 0,
|
||||
"delta": map[string]string{
|
||||
"type": "text_delta",
|
||||
"text": text,
|
||||
},
|
||||
}
|
||||
deltaJSON, _ := json.Marshal(delta)
|
||||
deltaJSON := `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":` + strconv.Quote(text) + `}}`
|
||||
events = append(events, `event: content_block_delta`+"\n"+`data: `+string(deltaJSON))
|
||||
}
|
||||
|
||||
// Add final events
|
||||
messageDelta := map[string]any{
|
||||
"type": "message_delta",
|
||||
"delta": map[string]any{
|
||||
"stop_reason": "end_turn",
|
||||
"stop_sequence": nil,
|
||||
},
|
||||
"usage": map[string]int{
|
||||
"input_tokens": 10,
|
||||
"output_tokens": outputTokens,
|
||||
},
|
||||
}
|
||||
messageDeltaJSON, _ := json.Marshal(messageDelta)
|
||||
messageDeltaJSON := `{"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"input_tokens":10,"output_tokens":` + strconv.Itoa(outputTokens) + `}}`
|
||||
|
||||
events = append(events,
|
||||
`event: content_block_stop`+"\n"+`data: {"index":0,"type":"content_block_stop"}`,
|
||||
@@ -1366,6 +1325,30 @@ func billingErrorDetails(err error) (status int, code, message string) {
|
||||
return http.StatusForbidden, "billing_error", msg
|
||||
}
|
||||
|
||||
func (h *GatewayHandler) metadataBridgeEnabled() bool {
|
||||
if h == nil || h.cfg == nil {
|
||||
return true
|
||||
}
|
||||
return h.cfg.Gateway.OpenAIWS.MetadataBridgeEnabled
|
||||
}
|
||||
|
||||
func (h *GatewayHandler) maybeLogCompatibilityFallbackMetrics(reqLog *zap.Logger) {
|
||||
if reqLog == nil {
|
||||
return
|
||||
}
|
||||
if gatewayCompatibilityMetricsLogCounter.Add(1)%gatewayCompatibilityMetricsLogInterval != 0 {
|
||||
return
|
||||
}
|
||||
metrics := service.SnapshotOpenAICompatibilityFallbackMetrics()
|
||||
reqLog.Info("gateway.compatibility_fallback_metrics",
|
||||
zap.Int64("session_hash_legacy_read_fallback_total", metrics.SessionHashLegacyReadFallbackTotal),
|
||||
zap.Int64("session_hash_legacy_read_fallback_hit", metrics.SessionHashLegacyReadFallbackHit),
|
||||
zap.Int64("session_hash_legacy_dual_write_total", metrics.SessionHashLegacyDualWriteTotal),
|
||||
zap.Float64("session_hash_legacy_read_hit_rate", metrics.SessionHashLegacyReadHitRate),
|
||||
zap.Int64("metadata_legacy_fallback_total", metrics.MetadataLegacyFallbackTotal),
|
||||
)
|
||||
}
|
||||
|
||||
func (h *GatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) {
|
||||
if task == nil {
|
||||
return
|
||||
@@ -1377,5 +1360,13 @@ func (h *GatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) {
|
||||
// 回退路径:worker 池未注入时同步执行,避免退回到无界 goroutine 模式。
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
defer func() {
|
||||
if recovered := recover(); recovered != nil {
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.gateway.messages"),
|
||||
zap.Any("panic", recovered),
|
||||
).Error("gateway.usage_record_task_panic_recovered")
|
||||
}
|
||||
}()
|
||||
task(ctx)
|
||||
}
|
||||
|
||||
@@ -119,6 +119,13 @@ func (f *fakeConcurrencyCache) GetAccountsLoadBatch(context.Context, []service.A
|
||||
func (f *fakeConcurrencyCache) GetUsersLoadBatch(context.Context, []service.UserWithConcurrency) (map[int64]*service.UserLoadInfo, error) {
|
||||
return map[int64]*service.UserLoadInfo{}, nil
|
||||
}
|
||||
func (f *fakeConcurrencyCache) GetAccountConcurrencyBatch(_ context.Context, accountIDs []int64) (map[int64]int, error) {
|
||||
result := make(map[int64]int, len(accountIDs))
|
||||
for _, id := range accountIDs {
|
||||
result[id] = 0
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
func (f *fakeConcurrencyCache) CleanupExpiredAccountSlots(context.Context, int64) error { return nil }
|
||||
|
||||
func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*service.Account) (*GatewayHandler, func()) {
|
||||
|
||||
@@ -18,12 +18,17 @@ import (
|
||||
// claudeCodeValidator is a singleton validator for Claude Code client detection
|
||||
var claudeCodeValidator = service.NewClaudeCodeValidator()
|
||||
|
||||
const claudeCodeParsedRequestContextKey = "claude_code_parsed_request"
|
||||
|
||||
// SetClaudeCodeClientContext 检查请求是否来自 Claude Code 客户端,并设置到 context 中
|
||||
// 返回更新后的 context
|
||||
func SetClaudeCodeClientContext(c *gin.Context, body []byte) {
|
||||
func SetClaudeCodeClientContext(c *gin.Context, body []byte, parsedReq *service.ParsedRequest) {
|
||||
if c == nil || c.Request == nil {
|
||||
return
|
||||
}
|
||||
if parsedReq != nil {
|
||||
c.Set(claudeCodeParsedRequestContextKey, parsedReq)
|
||||
}
|
||||
// Fast path:非 Claude CLI UA 直接判定 false,避免热路径二次 JSON 反序列化。
|
||||
if !claudeCodeValidator.ValidateUserAgent(c.GetHeader("User-Agent")) {
|
||||
ctx := service.SetClaudeCodeClient(c.Request.Context(), false)
|
||||
@@ -37,8 +42,11 @@ func SetClaudeCodeClientContext(c *gin.Context, body []byte) {
|
||||
isClaudeCode = true
|
||||
} else {
|
||||
// 仅在确认为 Claude CLI 且 messages 路径时再做 body 解析。
|
||||
var bodyMap map[string]any
|
||||
if len(body) > 0 {
|
||||
bodyMap := claudeCodeBodyMapFromParsedRequest(parsedReq)
|
||||
if bodyMap == nil {
|
||||
bodyMap = claudeCodeBodyMapFromContextCache(c)
|
||||
}
|
||||
if bodyMap == nil && len(body) > 0 {
|
||||
_ = json.Unmarshal(body, &bodyMap)
|
||||
}
|
||||
isClaudeCode = claudeCodeValidator.Validate(c.Request, bodyMap)
|
||||
@@ -49,6 +57,42 @@ func SetClaudeCodeClientContext(c *gin.Context, body []byte) {
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
}
|
||||
|
||||
func claudeCodeBodyMapFromParsedRequest(parsedReq *service.ParsedRequest) map[string]any {
|
||||
if parsedReq == nil {
|
||||
return nil
|
||||
}
|
||||
bodyMap := map[string]any{
|
||||
"model": parsedReq.Model,
|
||||
}
|
||||
if parsedReq.System != nil || parsedReq.HasSystem {
|
||||
bodyMap["system"] = parsedReq.System
|
||||
}
|
||||
if parsedReq.MetadataUserID != "" {
|
||||
bodyMap["metadata"] = map[string]any{"user_id": parsedReq.MetadataUserID}
|
||||
}
|
||||
return bodyMap
|
||||
}
|
||||
|
||||
func claudeCodeBodyMapFromContextCache(c *gin.Context) map[string]any {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
if cached, ok := c.Get(service.OpenAIParsedRequestBodyKey); ok {
|
||||
if bodyMap, ok := cached.(map[string]any); ok {
|
||||
return bodyMap
|
||||
}
|
||||
}
|
||||
if cached, ok := c.Get(claudeCodeParsedRequestContextKey); ok {
|
||||
switch v := cached.(type) {
|
||||
case *service.ParsedRequest:
|
||||
return claudeCodeBodyMapFromParsedRequest(v)
|
||||
case service.ParsedRequest:
|
||||
return claudeCodeBodyMapFromParsedRequest(&v)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// 并发槽位等待相关常量
|
||||
//
|
||||
// 性能优化说明:
|
||||
|
||||
@@ -33,6 +33,14 @@ func (m *concurrencyCacheMock) GetAccountConcurrency(ctx context.Context, accoun
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *concurrencyCacheMock) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {
|
||||
result := make(map[int64]int, len(accountIDs))
|
||||
for _, accountID := range accountIDs {
|
||||
result[accountID] = 0
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (m *concurrencyCacheMock) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
@@ -49,6 +49,14 @@ func (s *helperConcurrencyCacheStub) GetAccountConcurrency(ctx context.Context,
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (s *helperConcurrencyCacheStub) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {
|
||||
out := make(map[int64]int, len(accountIDs))
|
||||
for _, accountID := range accountIDs {
|
||||
out[accountID] = 0
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (s *helperConcurrencyCacheStub) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
@@ -133,7 +141,7 @@ func TestSetClaudeCodeClientContext_FastPathAndStrictPath(t *testing.T) {
|
||||
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
|
||||
c.Request.Header.Set("User-Agent", "curl/8.6.0")
|
||||
|
||||
SetClaudeCodeClientContext(c, validClaudeCodeBodyJSON())
|
||||
SetClaudeCodeClientContext(c, validClaudeCodeBodyJSON(), nil)
|
||||
require.False(t, service.IsClaudeCodeClient(c.Request.Context()))
|
||||
})
|
||||
|
||||
@@ -141,7 +149,7 @@ func TestSetClaudeCodeClientContext_FastPathAndStrictPath(t *testing.T) {
|
||||
c, _ := newHelperTestContext(http.MethodGet, "/v1/models")
|
||||
c.Request.Header.Set("User-Agent", "claude-cli/1.0.1")
|
||||
|
||||
SetClaudeCodeClientContext(c, nil)
|
||||
SetClaudeCodeClientContext(c, nil, nil)
|
||||
require.True(t, service.IsClaudeCodeClient(c.Request.Context()))
|
||||
})
|
||||
|
||||
@@ -152,7 +160,7 @@ func TestSetClaudeCodeClientContext_FastPathAndStrictPath(t *testing.T) {
|
||||
c.Request.Header.Set("anthropic-beta", "message-batches-2024-09-24")
|
||||
c.Request.Header.Set("anthropic-version", "2023-06-01")
|
||||
|
||||
SetClaudeCodeClientContext(c, validClaudeCodeBodyJSON())
|
||||
SetClaudeCodeClientContext(c, validClaudeCodeBodyJSON(), nil)
|
||||
require.True(t, service.IsClaudeCodeClient(c.Request.Context()))
|
||||
})
|
||||
|
||||
@@ -160,11 +168,51 @@ func TestSetClaudeCodeClientContext_FastPathAndStrictPath(t *testing.T) {
|
||||
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
|
||||
c.Request.Header.Set("User-Agent", "claude-cli/1.0.1")
|
||||
// 缺少严格校验所需 header + body 字段
|
||||
SetClaudeCodeClientContext(c, []byte(`{"model":"x"}`))
|
||||
SetClaudeCodeClientContext(c, []byte(`{"model":"x"}`), nil)
|
||||
require.False(t, service.IsClaudeCodeClient(c.Request.Context()))
|
||||
})
|
||||
}
|
||||
|
||||
func TestSetClaudeCodeClientContext_ReuseParsedRequestAndContextCache(t *testing.T) {
|
||||
t.Run("reuse parsed request without body unmarshal", func(t *testing.T) {
|
||||
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
|
||||
c.Request.Header.Set("User-Agent", "claude-cli/1.0.1")
|
||||
c.Request.Header.Set("X-App", "claude-code")
|
||||
c.Request.Header.Set("anthropic-beta", "message-batches-2024-09-24")
|
||||
c.Request.Header.Set("anthropic-version", "2023-06-01")
|
||||
|
||||
parsedReq := &service.ParsedRequest{
|
||||
Model: "claude-3-5-sonnet-20241022",
|
||||
System: []any{
|
||||
map[string]any{"text": "You are Claude Code, Anthropic's official CLI for Claude."},
|
||||
},
|
||||
MetadataUserID: "user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_abc-123",
|
||||
}
|
||||
|
||||
// body 非法 JSON,如果函数复用 parsedReq 成功则仍应判定为 Claude Code。
|
||||
SetClaudeCodeClientContext(c, []byte(`{invalid`), parsedReq)
|
||||
require.True(t, service.IsClaudeCodeClient(c.Request.Context()))
|
||||
})
|
||||
|
||||
t.Run("reuse context cache without body unmarshal", func(t *testing.T) {
|
||||
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
|
||||
c.Request.Header.Set("User-Agent", "claude-cli/1.0.1")
|
||||
c.Request.Header.Set("X-App", "claude-code")
|
||||
c.Request.Header.Set("anthropic-beta", "message-batches-2024-09-24")
|
||||
c.Request.Header.Set("anthropic-version", "2023-06-01")
|
||||
c.Set(service.OpenAIParsedRequestBodyKey, map[string]any{
|
||||
"model": "claude-3-5-sonnet-20241022",
|
||||
"system": []any{
|
||||
map[string]any{"text": "You are Claude Code, Anthropic's official CLI for Claude."},
|
||||
},
|
||||
"metadata": map[string]any{"user_id": "user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_abc-123"},
|
||||
})
|
||||
|
||||
SetClaudeCodeClientContext(c, []byte(`{invalid`), nil)
|
||||
require.True(t, service.IsClaudeCodeClient(c.Request.Context()))
|
||||
})
|
||||
}
|
||||
|
||||
func TestWaitForSlotWithPingTimeout_AccountAndUserAcquire(t *testing.T) {
|
||||
cache := &helperConcurrencyCacheStub{
|
||||
accountSeq: []bool{false, true},
|
||||
|
||||
@@ -7,16 +7,15 @@ import (
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/domain"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/gemini"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
|
||||
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
@@ -168,7 +167,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
stream := action == "streamGenerateContent"
|
||||
reqLog = reqLog.With(zap.String("model", modelName), zap.String("action", action), zap.Bool("stream", stream))
|
||||
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
|
||||
if err != nil {
|
||||
if maxErr, ok := extractMaxBytesError(err); ok {
|
||||
googleError(c, http.StatusRequestEntityTooLarge, buildBodyTooLargeMessage(maxErr.Limit))
|
||||
@@ -268,8 +267,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
if apiKey.GroupID != nil {
|
||||
prefetchedGroupID = *apiKey.GroupID
|
||||
}
|
||||
ctx := context.WithValue(c.Request.Context(), ctxkey.PrefetchedStickyAccountID, sessionBoundAccountID)
|
||||
ctx = context.WithValue(ctx, ctxkey.PrefetchedStickyGroupID, prefetchedGroupID)
|
||||
ctx := service.WithPrefetchedStickySession(c.Request.Context(), sessionBoundAccountID, prefetchedGroupID, h.metadataBridgeEnabled())
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
}
|
||||
}
|
||||
@@ -349,7 +347,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
// 单账号分组提前设置 SingleAccountRetry 标记,让 Service 层首次 503 就不设模型限流标记。
|
||||
// 避免单账号分组收到 503 (MODEL_CAPACITY_EXHAUSTED) 时设 29s 限流,导致后续请求连续快速失败。
|
||||
if h.gatewayService.IsSingleAntigravityAccountGroup(c.Request.Context(), apiKey.GroupID) {
|
||||
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
|
||||
ctx := service.WithSingleAccountRetry(c.Request.Context(), true, h.metadataBridgeEnabled())
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
}
|
||||
|
||||
@@ -363,7 +361,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
action := fs.HandleSelectionExhausted(c.Request.Context())
|
||||
switch action {
|
||||
case FailoverContinue:
|
||||
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
|
||||
ctx := service.WithSingleAccountRetry(c.Request.Context(), true, h.metadataBridgeEnabled())
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
continue
|
||||
case FailoverCanceled:
|
||||
@@ -456,7 +454,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
var result *service.ForwardResult
|
||||
requestCtx := c.Request.Context()
|
||||
if fs.SwitchCount > 0 {
|
||||
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, fs.SwitchCount)
|
||||
requestCtx = service.WithAccountSwitchCount(requestCtx, fs.SwitchCount, h.metadataBridgeEnabled())
|
||||
}
|
||||
if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey {
|
||||
result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, modelName, action, stream, body, hasBoundSession)
|
||||
|
||||
@@ -11,6 +11,7 @@ type AdminHandlers struct {
|
||||
Group *admin.GroupHandler
|
||||
Account *admin.AccountHandler
|
||||
Announcement *admin.AnnouncementHandler
|
||||
DataManagement *admin.DataManagementHandler
|
||||
OAuth *admin.OAuthHandler
|
||||
OpenAIOAuth *admin.OpenAIOAuthHandler
|
||||
GeminiOAuth *admin.GeminiOAuthHandler
|
||||
@@ -40,6 +41,7 @@ type Handlers struct {
|
||||
Gateway *GatewayHandler
|
||||
OpenAIGateway *OpenAIGatewayHandler
|
||||
SoraGateway *SoraGatewayHandler
|
||||
SoraClient *SoraClientHandler
|
||||
Setting *SettingHandler
|
||||
Totp *TotpHandler
|
||||
}
|
||||
|
||||
@@ -5,17 +5,20 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"runtime/debug"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
coderws "github.com/coder/websocket"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/tidwall/gjson"
|
||||
"go.uber.org/zap"
|
||||
@@ -64,6 +67,11 @@ func NewOpenAIGatewayHandler(
|
||||
// Responses handles OpenAI Responses API endpoint
|
||||
// POST /openai/v1/responses
|
||||
func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
// 局部兜底:确保该 handler 内部任何 panic 都不会击穿到进程级。
|
||||
streamStarted := false
|
||||
defer h.recoverResponsesPanic(c, &streamStarted)
|
||||
setOpenAIClientTransportHTTP(c)
|
||||
|
||||
requestStart := time.Now()
|
||||
|
||||
// Get apiKey and user from context (set by ApiKeyAuth middleware)
|
||||
@@ -85,9 +93,12 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
zap.Int64("api_key_id", apiKey.ID),
|
||||
zap.Any("group_id", apiKey.GroupID),
|
||||
)
|
||||
if !h.ensureResponsesDependencies(c, reqLog) {
|
||||
return
|
||||
}
|
||||
|
||||
// Read request body
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
|
||||
if err != nil {
|
||||
if maxErr, ok := extractMaxBytesError(err); ok {
|
||||
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
|
||||
@@ -125,43 +136,30 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
}
|
||||
reqStream := streamResult.Bool()
|
||||
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
|
||||
previousResponseID := strings.TrimSpace(gjson.GetBytes(body, "previous_response_id").String())
|
||||
if previousResponseID != "" {
|
||||
previousResponseIDKind := service.ClassifyOpenAIPreviousResponseIDKind(previousResponseID)
|
||||
reqLog = reqLog.With(
|
||||
zap.Bool("has_previous_response_id", true),
|
||||
zap.String("previous_response_id_kind", previousResponseIDKind),
|
||||
zap.Int("previous_response_id_len", len(previousResponseID)),
|
||||
)
|
||||
if previousResponseIDKind == service.OpenAIPreviousResponseIDKindMessageID {
|
||||
reqLog.Warn("openai.request_validation_failed",
|
||||
zap.String("reason", "previous_response_id_looks_like_message_id"),
|
||||
)
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "previous_response_id must be a response.id (resp_*), not a message id")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
setOpsRequestContext(c, reqModel, reqStream, body)
|
||||
|
||||
// 提前校验 function_call_output 是否具备可关联上下文,避免上游 400。
|
||||
// 要求 previous_response_id,或 input 内存在带 call_id 的 tool_call/function_call,
|
||||
// 或带 id 且与 call_id 匹配的 item_reference。
|
||||
// 此路径需要遍历 input 数组做 call_id 关联检查,保留 Unmarshal
|
||||
if gjson.GetBytes(body, `input.#(type=="function_call_output")`).Exists() {
|
||||
var reqBody map[string]any
|
||||
if err := json.Unmarshal(body, &reqBody); err == nil {
|
||||
c.Set(service.OpenAIParsedRequestBodyKey, reqBody)
|
||||
if service.HasFunctionCallOutput(reqBody) {
|
||||
previousResponseID, _ := reqBody["previous_response_id"].(string)
|
||||
if strings.TrimSpace(previousResponseID) == "" && !service.HasToolCallContext(reqBody) {
|
||||
if service.HasFunctionCallOutputMissingCallID(reqBody) {
|
||||
reqLog.Warn("openai.request_validation_failed",
|
||||
zap.String("reason", "function_call_output_missing_call_id"),
|
||||
)
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires call_id or previous_response_id; if relying on history, ensure store=true and reuse previous_response_id")
|
||||
return
|
||||
}
|
||||
callIDs := service.FunctionCallOutputCallIDs(reqBody)
|
||||
if !service.HasItemReferenceForCallIDs(reqBody, callIDs) {
|
||||
reqLog.Warn("openai.request_validation_failed",
|
||||
zap.String("reason", "function_call_output_missing_item_reference"),
|
||||
)
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires item_reference ids matching each call_id, or previous_response_id/tool_call context; if relying on history, ensure store=true and reuse previous_response_id")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if !h.validateFunctionCallOutputRequest(c, body, reqLog) {
|
||||
return
|
||||
}
|
||||
|
||||
// Track if we've started streaming (for error handling)
|
||||
streamStarted := false
|
||||
|
||||
// 绑定错误透传服务,允许 service 层在非 failover 错误场景复用规则。
|
||||
if h.errorPassthroughService != nil {
|
||||
service.BindErrorPassthroughService(c, h.errorPassthroughService)
|
||||
@@ -173,51 +171,11 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds())
|
||||
routingStart := time.Now()
|
||||
|
||||
// 0. 先尝试直接抢占用户槽位(快速路径)
|
||||
userReleaseFunc, userAcquired, err := h.concurrencyHelper.TryAcquireUserSlot(c.Request.Context(), subject.UserID, subject.Concurrency)
|
||||
if err != nil {
|
||||
reqLog.Warn("openai.user_slot_acquire_failed", zap.Error(err))
|
||||
h.handleConcurrencyError(c, err, "user", streamStarted)
|
||||
userReleaseFunc, acquired := h.acquireResponsesUserSlot(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted, reqLog)
|
||||
if !acquired {
|
||||
return
|
||||
}
|
||||
|
||||
waitCounted := false
|
||||
if !userAcquired {
|
||||
// 仅在抢槽失败时才进入等待队列,减少常态请求 Redis 写入。
|
||||
maxWait := service.CalculateMaxWait(subject.Concurrency)
|
||||
canWait, waitErr := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
|
||||
if waitErr != nil {
|
||||
reqLog.Warn("openai.user_wait_counter_increment_failed", zap.Error(waitErr))
|
||||
// 按现有降级语义:等待计数异常时放行后续抢槽流程
|
||||
} else if !canWait {
|
||||
reqLog.Info("openai.user_wait_queue_full", zap.Int("max_wait", maxWait))
|
||||
h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
|
||||
return
|
||||
}
|
||||
if waitErr == nil && canWait {
|
||||
waitCounted = true
|
||||
}
|
||||
defer func() {
|
||||
if waitCounted {
|
||||
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
|
||||
}
|
||||
}()
|
||||
|
||||
userReleaseFunc, err = h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted)
|
||||
if err != nil {
|
||||
reqLog.Warn("openai.user_slot_acquire_failed_after_wait", zap.Error(err))
|
||||
h.handleConcurrencyError(c, err, "user", streamStarted)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 用户槽位已获取:退出等待队列计数。
|
||||
if waitCounted {
|
||||
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
|
||||
waitCounted = false
|
||||
}
|
||||
// 确保请求取消时也会释放槽位,避免长连接被动中断造成泄漏
|
||||
userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc)
|
||||
if userReleaseFunc != nil {
|
||||
defer userReleaseFunc()
|
||||
}
|
||||
@@ -241,7 +199,15 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
for {
|
||||
// Select account supporting the requested model
|
||||
reqLog.Debug("openai.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs)))
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs)
|
||||
selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler(
|
||||
c.Request.Context(),
|
||||
apiKey.GroupID,
|
||||
previousResponseID,
|
||||
sessionHash,
|
||||
reqModel,
|
||||
failedAccountIDs,
|
||||
service.OpenAIUpstreamTransportAny,
|
||||
)
|
||||
if err != nil {
|
||||
reqLog.Warn("openai.account_select_failed",
|
||||
zap.Error(err),
|
||||
@@ -258,80 +224,30 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
}
|
||||
return
|
||||
}
|
||||
if selection == nil || selection.Account == nil {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
|
||||
return
|
||||
}
|
||||
if previousResponseID != "" && selection != nil && selection.Account != nil {
|
||||
reqLog.Debug("openai.account_selected_with_previous_response_id", zap.Int64("account_id", selection.Account.ID))
|
||||
}
|
||||
reqLog.Debug("openai.account_schedule_decision",
|
||||
zap.String("layer", scheduleDecision.Layer),
|
||||
zap.Bool("sticky_previous_hit", scheduleDecision.StickyPreviousHit),
|
||||
zap.Bool("sticky_session_hit", scheduleDecision.StickySessionHit),
|
||||
zap.Int("candidate_count", scheduleDecision.CandidateCount),
|
||||
zap.Int("top_k", scheduleDecision.TopK),
|
||||
zap.Int64("latency_ms", scheduleDecision.LatencyMs),
|
||||
zap.Float64("load_skew", scheduleDecision.LoadSkew),
|
||||
)
|
||||
account := selection.Account
|
||||
reqLog.Debug("openai.account_selected", zap.Int64("account_id", account.ID), zap.String("account_name", account.Name))
|
||||
setOpsSelectedAccount(c, account.ID, account.Platform)
|
||||
|
||||
// 3. Acquire account concurrency slot
|
||||
accountReleaseFunc := selection.ReleaseFunc
|
||||
if !selection.Acquired {
|
||||
if selection.WaitPlan == nil {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
|
||||
return
|
||||
}
|
||||
|
||||
// 先快速尝试一次账号槽位,命中则跳过等待计数写入。
|
||||
fastReleaseFunc, fastAcquired, err := h.concurrencyHelper.TryAcquireAccountSlot(
|
||||
c.Request.Context(),
|
||||
account.ID,
|
||||
selection.WaitPlan.MaxConcurrency,
|
||||
)
|
||||
if err != nil {
|
||||
reqLog.Warn("openai.account_slot_quick_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||
return
|
||||
}
|
||||
if fastAcquired {
|
||||
accountReleaseFunc = fastReleaseFunc
|
||||
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionHash, account.ID); err != nil {
|
||||
reqLog.Warn("openai.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
}
|
||||
} else {
|
||||
accountWaitCounted := false
|
||||
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
|
||||
if err != nil {
|
||||
reqLog.Warn("openai.account_wait_counter_increment_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
} else if !canWait {
|
||||
reqLog.Info("openai.account_wait_queue_full",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("max_waiting", selection.WaitPlan.MaxWaiting),
|
||||
)
|
||||
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
|
||||
return
|
||||
}
|
||||
if err == nil && canWait {
|
||||
accountWaitCounted = true
|
||||
}
|
||||
releaseWait := func() {
|
||||
if accountWaitCounted {
|
||||
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
|
||||
accountWaitCounted = false
|
||||
}
|
||||
}
|
||||
|
||||
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
|
||||
c,
|
||||
account.ID,
|
||||
selection.WaitPlan.MaxConcurrency,
|
||||
selection.WaitPlan.Timeout,
|
||||
reqStream,
|
||||
&streamStarted,
|
||||
)
|
||||
if err != nil {
|
||||
reqLog.Warn("openai.account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
releaseWait()
|
||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||
return
|
||||
}
|
||||
// Slot acquired: no longer waiting in queue.
|
||||
releaseWait()
|
||||
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionHash, account.ID); err != nil {
|
||||
reqLog.Warn("openai.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
}
|
||||
}
|
||||
accountReleaseFunc, acquired := h.acquireResponsesAccountSlot(c, apiKey.GroupID, sessionHash, selection, reqStream, &streamStarted, reqLog)
|
||||
if !acquired {
|
||||
return
|
||||
}
|
||||
// 账号槽位/等待计数需要在超时或断开时安全回收
|
||||
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
|
||||
|
||||
// Forward request
|
||||
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
|
||||
@@ -353,6 +269,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
if err != nil {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
h.gatewayService.RecordOpenAIAccountSwitch()
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
lastFailoverErr = failoverErr
|
||||
if switchCount >= maxAccountSwitches {
|
||||
@@ -368,14 +286,25 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
)
|
||||
continue
|
||||
}
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
|
||||
reqLog.Error("openai.forward_failed",
|
||||
fields := []zap.Field{
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Bool("fallback_error_response_written", wroteFallback),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
if shouldLogOpenAIForwardFailureAsWarn(c, wroteFallback) {
|
||||
reqLog.Warn("openai.forward_failed", fields...)
|
||||
return
|
||||
}
|
||||
reqLog.Error("openai.forward_failed", fields...)
|
||||
return
|
||||
}
|
||||
if result != nil {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
|
||||
} else {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, nil)
|
||||
}
|
||||
|
||||
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
@@ -411,6 +340,525 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
func (h *OpenAIGatewayHandler) validateFunctionCallOutputRequest(c *gin.Context, body []byte, reqLog *zap.Logger) bool {
|
||||
if !gjson.GetBytes(body, `input.#(type=="function_call_output")`).Exists() {
|
||||
return true
|
||||
}
|
||||
|
||||
var reqBody map[string]any
|
||||
if err := json.Unmarshal(body, &reqBody); err != nil {
|
||||
// 保持原有容错语义:解析失败时跳过预校验,沿用后续上游校验结果。
|
||||
return true
|
||||
}
|
||||
|
||||
c.Set(service.OpenAIParsedRequestBodyKey, reqBody)
|
||||
validation := service.ValidateFunctionCallOutputContext(reqBody)
|
||||
if !validation.HasFunctionCallOutput {
|
||||
return true
|
||||
}
|
||||
|
||||
previousResponseID, _ := reqBody["previous_response_id"].(string)
|
||||
if strings.TrimSpace(previousResponseID) != "" || validation.HasToolCallContext {
|
||||
return true
|
||||
}
|
||||
|
||||
if validation.HasFunctionCallOutputMissingCallID {
|
||||
reqLog.Warn("openai.request_validation_failed",
|
||||
zap.String("reason", "function_call_output_missing_call_id"),
|
||||
)
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires call_id or previous_response_id; if relying on history, ensure store=true and reuse previous_response_id")
|
||||
return false
|
||||
}
|
||||
if validation.HasItemReferenceForAllCallIDs {
|
||||
return true
|
||||
}
|
||||
|
||||
reqLog.Warn("openai.request_validation_failed",
|
||||
zap.String("reason", "function_call_output_missing_item_reference"),
|
||||
)
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires item_reference ids matching each call_id, or previous_response_id/tool_call context; if relying on history, ensure store=true and reuse previous_response_id")
|
||||
return false
|
||||
}
|
||||
|
||||
func (h *OpenAIGatewayHandler) acquireResponsesUserSlot(
|
||||
c *gin.Context,
|
||||
userID int64,
|
||||
userConcurrency int,
|
||||
reqStream bool,
|
||||
streamStarted *bool,
|
||||
reqLog *zap.Logger,
|
||||
) (func(), bool) {
|
||||
ctx := c.Request.Context()
|
||||
userReleaseFunc, userAcquired, err := h.concurrencyHelper.TryAcquireUserSlot(ctx, userID, userConcurrency)
|
||||
if err != nil {
|
||||
reqLog.Warn("openai.user_slot_acquire_failed", zap.Error(err))
|
||||
h.handleConcurrencyError(c, err, "user", *streamStarted)
|
||||
return nil, false
|
||||
}
|
||||
if userAcquired {
|
||||
return wrapReleaseOnDone(ctx, userReleaseFunc), true
|
||||
}
|
||||
|
||||
maxWait := service.CalculateMaxWait(userConcurrency)
|
||||
canWait, waitErr := h.concurrencyHelper.IncrementWaitCount(ctx, userID, maxWait)
|
||||
if waitErr != nil {
|
||||
reqLog.Warn("openai.user_wait_counter_increment_failed", zap.Error(waitErr))
|
||||
// 按现有降级语义:等待计数异常时放行后续抢槽流程
|
||||
} else if !canWait {
|
||||
reqLog.Info("openai.user_wait_queue_full", zap.Int("max_wait", maxWait))
|
||||
h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
|
||||
return nil, false
|
||||
}
|
||||
|
||||
waitCounted := waitErr == nil && canWait
|
||||
defer func() {
|
||||
if waitCounted {
|
||||
h.concurrencyHelper.DecrementWaitCount(ctx, userID)
|
||||
}
|
||||
}()
|
||||
|
||||
userReleaseFunc, err = h.concurrencyHelper.AcquireUserSlotWithWait(c, userID, userConcurrency, reqStream, streamStarted)
|
||||
if err != nil {
|
||||
reqLog.Warn("openai.user_slot_acquire_failed_after_wait", zap.Error(err))
|
||||
h.handleConcurrencyError(c, err, "user", *streamStarted)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// 槽位获取成功后,立刻退出等待计数。
|
||||
if waitCounted {
|
||||
h.concurrencyHelper.DecrementWaitCount(ctx, userID)
|
||||
waitCounted = false
|
||||
}
|
||||
return wrapReleaseOnDone(ctx, userReleaseFunc), true
|
||||
}
|
||||
|
||||
func (h *OpenAIGatewayHandler) acquireResponsesAccountSlot(
|
||||
c *gin.Context,
|
||||
groupID *int64,
|
||||
sessionHash string,
|
||||
selection *service.AccountSelectionResult,
|
||||
reqStream bool,
|
||||
streamStarted *bool,
|
||||
reqLog *zap.Logger,
|
||||
) (func(), bool) {
|
||||
if selection == nil || selection.Account == nil {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", *streamStarted)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
ctx := c.Request.Context()
|
||||
account := selection.Account
|
||||
if selection.Acquired {
|
||||
return wrapReleaseOnDone(ctx, selection.ReleaseFunc), true
|
||||
}
|
||||
if selection.WaitPlan == nil {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", *streamStarted)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
fastReleaseFunc, fastAcquired, err := h.concurrencyHelper.TryAcquireAccountSlot(
|
||||
ctx,
|
||||
account.ID,
|
||||
selection.WaitPlan.MaxConcurrency,
|
||||
)
|
||||
if err != nil {
|
||||
reqLog.Warn("openai.account_slot_quick_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
h.handleConcurrencyError(c, err, "account", *streamStarted)
|
||||
return nil, false
|
||||
}
|
||||
if fastAcquired {
|
||||
if err := h.gatewayService.BindStickySession(ctx, groupID, sessionHash, account.ID); err != nil {
|
||||
reqLog.Warn("openai.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
}
|
||||
return wrapReleaseOnDone(ctx, fastReleaseFunc), true
|
||||
}
|
||||
|
||||
canWait, waitErr := h.concurrencyHelper.IncrementAccountWaitCount(ctx, account.ID, selection.WaitPlan.MaxWaiting)
|
||||
if waitErr != nil {
|
||||
reqLog.Warn("openai.account_wait_counter_increment_failed", zap.Int64("account_id", account.ID), zap.Error(waitErr))
|
||||
} else if !canWait {
|
||||
reqLog.Info("openai.account_wait_queue_full",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("max_waiting", selection.WaitPlan.MaxWaiting),
|
||||
)
|
||||
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", *streamStarted)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
accountWaitCounted := waitErr == nil && canWait
|
||||
releaseWait := func() {
|
||||
if accountWaitCounted {
|
||||
h.concurrencyHelper.DecrementAccountWaitCount(ctx, account.ID)
|
||||
accountWaitCounted = false
|
||||
}
|
||||
}
|
||||
defer releaseWait()
|
||||
|
||||
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
|
||||
c,
|
||||
account.ID,
|
||||
selection.WaitPlan.MaxConcurrency,
|
||||
selection.WaitPlan.Timeout,
|
||||
reqStream,
|
||||
streamStarted,
|
||||
)
|
||||
if err != nil {
|
||||
reqLog.Warn("openai.account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
h.handleConcurrencyError(c, err, "account", *streamStarted)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// Slot acquired: no longer waiting in queue.
|
||||
releaseWait()
|
||||
if err := h.gatewayService.BindStickySession(ctx, groupID, sessionHash, account.ID); err != nil {
|
||||
reqLog.Warn("openai.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
}
|
||||
return wrapReleaseOnDone(ctx, accountReleaseFunc), true
|
||||
}
|
||||
|
||||
// ResponsesWebSocket handles OpenAI Responses API WebSocket ingress endpoint
|
||||
// GET /openai/v1/responses (Upgrade: websocket)
|
||||
func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
|
||||
if !isOpenAIWSUpgradeRequest(c.Request) {
|
||||
h.errorResponse(c, http.StatusUpgradeRequired, "invalid_request_error", "WebSocket upgrade required (Upgrade: websocket)")
|
||||
return
|
||||
}
|
||||
setOpenAIClientTransportWS(c)
|
||||
|
||||
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
|
||||
if !ok {
|
||||
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
|
||||
return
|
||||
}
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
|
||||
return
|
||||
}
|
||||
|
||||
reqLog := requestLogger(
|
||||
c,
|
||||
"handler.openai_gateway.responses_ws",
|
||||
zap.Int64("user_id", subject.UserID),
|
||||
zap.Int64("api_key_id", apiKey.ID),
|
||||
zap.Any("group_id", apiKey.GroupID),
|
||||
zap.Bool("openai_ws_mode", true),
|
||||
)
|
||||
if !h.ensureResponsesDependencies(c, reqLog) {
|
||||
return
|
||||
}
|
||||
reqLog.Info("openai.websocket_ingress_started")
|
||||
clientIP := ip.GetClientIP(c)
|
||||
userAgent := strings.TrimSpace(c.GetHeader("User-Agent"))
|
||||
|
||||
wsConn, err := coderws.Accept(c.Writer, c.Request, &coderws.AcceptOptions{
|
||||
CompressionMode: coderws.CompressionContextTakeover,
|
||||
})
|
||||
if err != nil {
|
||||
reqLog.Warn("openai.websocket_accept_failed",
|
||||
zap.Error(err),
|
||||
zap.String("client_ip", clientIP),
|
||||
zap.String("request_user_agent", userAgent),
|
||||
zap.String("upgrade_header", strings.TrimSpace(c.GetHeader("Upgrade"))),
|
||||
zap.String("connection_header", strings.TrimSpace(c.GetHeader("Connection"))),
|
||||
zap.String("sec_websocket_version", strings.TrimSpace(c.GetHeader("Sec-WebSocket-Version"))),
|
||||
zap.Bool("has_sec_websocket_key", strings.TrimSpace(c.GetHeader("Sec-WebSocket-Key")) != ""),
|
||||
)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = wsConn.CloseNow()
|
||||
}()
|
||||
wsConn.SetReadLimit(16 * 1024 * 1024)
|
||||
|
||||
ctx := c.Request.Context()
|
||||
readCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
msgType, firstMessage, err := wsConn.Read(readCtx)
|
||||
cancel()
|
||||
if err != nil {
|
||||
closeStatus, closeReason := summarizeWSCloseErrorForLog(err)
|
||||
reqLog.Warn("openai.websocket_read_first_message_failed",
|
||||
zap.Error(err),
|
||||
zap.String("client_ip", clientIP),
|
||||
zap.String("close_status", closeStatus),
|
||||
zap.String("close_reason", closeReason),
|
||||
zap.Duration("read_timeout", 30*time.Second),
|
||||
)
|
||||
closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "missing first response.create message")
|
||||
return
|
||||
}
|
||||
if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
|
||||
closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "unsupported websocket message type")
|
||||
return
|
||||
}
|
||||
if !gjson.ValidBytes(firstMessage) {
|
||||
closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "invalid JSON payload")
|
||||
return
|
||||
}
|
||||
|
||||
reqModel := strings.TrimSpace(gjson.GetBytes(firstMessage, "model").String())
|
||||
if reqModel == "" {
|
||||
closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "model is required in first response.create payload")
|
||||
return
|
||||
}
|
||||
previousResponseID := strings.TrimSpace(gjson.GetBytes(firstMessage, "previous_response_id").String())
|
||||
previousResponseIDKind := service.ClassifyOpenAIPreviousResponseIDKind(previousResponseID)
|
||||
if previousResponseID != "" && previousResponseIDKind == service.OpenAIPreviousResponseIDKindMessageID {
|
||||
closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "previous_response_id must be a response.id (resp_*), not a message id")
|
||||
return
|
||||
}
|
||||
reqLog = reqLog.With(
|
||||
zap.Bool("ws_ingress", true),
|
||||
zap.String("model", reqModel),
|
||||
zap.Bool("has_previous_response_id", previousResponseID != ""),
|
||||
zap.String("previous_response_id_kind", previousResponseIDKind),
|
||||
)
|
||||
setOpsRequestContext(c, reqModel, true, firstMessage)
|
||||
|
||||
var currentUserRelease func()
|
||||
var currentAccountRelease func()
|
||||
releaseTurnSlots := func() {
|
||||
if currentAccountRelease != nil {
|
||||
currentAccountRelease()
|
||||
currentAccountRelease = nil
|
||||
}
|
||||
if currentUserRelease != nil {
|
||||
currentUserRelease()
|
||||
currentUserRelease = nil
|
||||
}
|
||||
}
|
||||
// 必须尽早注册,确保任何 early return 都能释放已获取的并发槽位。
|
||||
defer releaseTurnSlots()
|
||||
|
||||
userReleaseFunc, userAcquired, err := h.concurrencyHelper.TryAcquireUserSlot(ctx, subject.UserID, subject.Concurrency)
|
||||
if err != nil {
|
||||
reqLog.Warn("openai.websocket_user_slot_acquire_failed", zap.Error(err))
|
||||
closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "failed to acquire user concurrency slot")
|
||||
return
|
||||
}
|
||||
if !userAcquired {
|
||||
closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "too many concurrent requests, please retry later")
|
||||
return
|
||||
}
|
||||
currentUserRelease = wrapReleaseOnDone(ctx, userReleaseFunc)
|
||||
|
||||
subscription, _ := middleware2.GetSubscriptionFromContext(c)
|
||||
if err := h.billingCacheService.CheckBillingEligibility(ctx, apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||
reqLog.Info("openai.websocket_billing_eligibility_check_failed", zap.Error(err))
|
||||
closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "billing check failed")
|
||||
return
|
||||
}
|
||||
|
||||
sessionHash := h.gatewayService.GenerateSessionHashWithFallback(
|
||||
c,
|
||||
firstMessage,
|
||||
openAIWSIngressFallbackSessionSeed(subject.UserID, apiKey.ID, apiKey.GroupID),
|
||||
)
|
||||
selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler(
|
||||
ctx,
|
||||
apiKey.GroupID,
|
||||
previousResponseID,
|
||||
sessionHash,
|
||||
reqModel,
|
||||
nil,
|
||||
service.OpenAIUpstreamTransportResponsesWebsocketV2,
|
||||
)
|
||||
if err != nil {
|
||||
reqLog.Warn("openai.websocket_account_select_failed", zap.Error(err))
|
||||
closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "no available account")
|
||||
return
|
||||
}
|
||||
if selection == nil || selection.Account == nil {
|
||||
closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "no available account")
|
||||
return
|
||||
}
|
||||
|
||||
account := selection.Account
|
||||
accountMaxConcurrency := account.Concurrency
|
||||
if selection.WaitPlan != nil && selection.WaitPlan.MaxConcurrency > 0 {
|
||||
accountMaxConcurrency = selection.WaitPlan.MaxConcurrency
|
||||
}
|
||||
accountReleaseFunc := selection.ReleaseFunc
|
||||
if !selection.Acquired {
|
||||
if selection.WaitPlan == nil {
|
||||
closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "account is busy, please retry later")
|
||||
return
|
||||
}
|
||||
fastReleaseFunc, fastAcquired, err := h.concurrencyHelper.TryAcquireAccountSlot(
|
||||
ctx,
|
||||
account.ID,
|
||||
selection.WaitPlan.MaxConcurrency,
|
||||
)
|
||||
if err != nil {
|
||||
reqLog.Warn("openai.websocket_account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "failed to acquire account concurrency slot")
|
||||
return
|
||||
}
|
||||
if !fastAcquired {
|
||||
closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "account is busy, please retry later")
|
||||
return
|
||||
}
|
||||
accountReleaseFunc = fastReleaseFunc
|
||||
}
|
||||
currentAccountRelease = wrapReleaseOnDone(ctx, accountReleaseFunc)
|
||||
if err := h.gatewayService.BindStickySession(ctx, apiKey.GroupID, sessionHash, account.ID); err != nil {
|
||||
reqLog.Warn("openai.websocket_bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
}
|
||||
|
||||
token, _, err := h.gatewayService.GetAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
reqLog.Warn("openai.websocket_get_access_token_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "failed to get access token")
|
||||
return
|
||||
}
|
||||
|
||||
reqLog.Debug("openai.websocket_account_selected",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.String("account_name", account.Name),
|
||||
zap.String("schedule_layer", scheduleDecision.Layer),
|
||||
zap.Int("candidate_count", scheduleDecision.CandidateCount),
|
||||
)
|
||||
|
||||
hooks := &service.OpenAIWSIngressHooks{
|
||||
BeforeTurn: func(turn int) error {
|
||||
if turn == 1 {
|
||||
return nil
|
||||
}
|
||||
// 防御式清理:避免异常路径下旧槽位覆盖导致泄漏。
|
||||
releaseTurnSlots()
|
||||
// 非首轮 turn 需要重新抢占并发槽位,避免长连接空闲占槽。
|
||||
userReleaseFunc, userAcquired, err := h.concurrencyHelper.TryAcquireUserSlot(ctx, subject.UserID, subject.Concurrency)
|
||||
if err != nil {
|
||||
return service.NewOpenAIWSClientCloseError(coderws.StatusInternalError, "failed to acquire user concurrency slot", err)
|
||||
}
|
||||
if !userAcquired {
|
||||
return service.NewOpenAIWSClientCloseError(coderws.StatusTryAgainLater, "too many concurrent requests, please retry later", nil)
|
||||
}
|
||||
accountReleaseFunc, accountAcquired, err := h.concurrencyHelper.TryAcquireAccountSlot(ctx, account.ID, accountMaxConcurrency)
|
||||
if err != nil {
|
||||
if userReleaseFunc != nil {
|
||||
userReleaseFunc()
|
||||
}
|
||||
return service.NewOpenAIWSClientCloseError(coderws.StatusInternalError, "failed to acquire account concurrency slot", err)
|
||||
}
|
||||
if !accountAcquired {
|
||||
if userReleaseFunc != nil {
|
||||
userReleaseFunc()
|
||||
}
|
||||
return service.NewOpenAIWSClientCloseError(coderws.StatusTryAgainLater, "account is busy, please retry later", nil)
|
||||
}
|
||||
currentUserRelease = wrapReleaseOnDone(ctx, userReleaseFunc)
|
||||
currentAccountRelease = wrapReleaseOnDone(ctx, accountReleaseFunc)
|
||||
return nil
|
||||
},
|
||||
AfterTurn: func(turn int, result *service.OpenAIForwardResult, turnErr error) {
|
||||
releaseTurnSlots()
|
||||
if turnErr != nil || result == nil {
|
||||
return
|
||||
}
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
|
||||
h.submitUsageRecordTask(func(taskCtx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(taskCtx, &service.OpenAIRecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
APIKeyService: h.apiKeyService,
|
||||
}); err != nil {
|
||||
reqLog.Error("openai.websocket_record_usage_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.String("request_id", result.RequestID),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
})
|
||||
},
|
||||
}
|
||||
|
||||
if err := h.gatewayService.ProxyResponsesWebSocketFromClient(ctx, c, wsConn, account, token, firstMessage, hooks); err != nil {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
closeStatus, closeReason := summarizeWSCloseErrorForLog(err)
|
||||
reqLog.Warn("openai.websocket_proxy_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Error(err),
|
||||
zap.String("close_status", closeStatus),
|
||||
zap.String("close_reason", closeReason),
|
||||
)
|
||||
var closeErr *service.OpenAIWSClientCloseError
|
||||
if errors.As(err, &closeErr) {
|
||||
closeOpenAIClientWS(wsConn, closeErr.StatusCode(), closeErr.Reason())
|
||||
return
|
||||
}
|
||||
closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "upstream websocket proxy failed")
|
||||
return
|
||||
}
|
||||
reqLog.Info("openai.websocket_ingress_closed", zap.Int64("account_id", account.ID))
|
||||
}
|
||||
|
||||
func (h *OpenAIGatewayHandler) recoverResponsesPanic(c *gin.Context, streamStarted *bool) {
|
||||
recovered := recover()
|
||||
if recovered == nil {
|
||||
return
|
||||
}
|
||||
|
||||
started := false
|
||||
if streamStarted != nil {
|
||||
started = *streamStarted
|
||||
}
|
||||
wroteFallback := h.ensureForwardErrorResponse(c, started)
|
||||
requestLogger(c, "handler.openai_gateway.responses").Error(
|
||||
"openai.responses_panic_recovered",
|
||||
zap.Bool("fallback_error_response_written", wroteFallback),
|
||||
zap.Any("panic", recovered),
|
||||
zap.ByteString("stack", debug.Stack()),
|
||||
)
|
||||
}
|
||||
|
||||
func (h *OpenAIGatewayHandler) ensureResponsesDependencies(c *gin.Context, reqLog *zap.Logger) bool {
|
||||
missing := h.missingResponsesDependencies()
|
||||
if len(missing) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
if reqLog == nil {
|
||||
reqLog = requestLogger(c, "handler.openai_gateway.responses")
|
||||
}
|
||||
reqLog.Error("openai.handler_dependencies_missing", zap.Strings("missing_dependencies", missing))
|
||||
|
||||
if c != nil && c.Writer != nil && !c.Writer.Written() {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{
|
||||
"error": gin.H{
|
||||
"type": "api_error",
|
||||
"message": "Service temporarily unavailable",
|
||||
},
|
||||
})
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (h *OpenAIGatewayHandler) missingResponsesDependencies() []string {
|
||||
missing := make([]string, 0, 5)
|
||||
if h == nil {
|
||||
return append(missing, "handler")
|
||||
}
|
||||
if h.gatewayService == nil {
|
||||
missing = append(missing, "gatewayService")
|
||||
}
|
||||
if h.billingCacheService == nil {
|
||||
missing = append(missing, "billingCacheService")
|
||||
}
|
||||
if h.apiKeyService == nil {
|
||||
missing = append(missing, "apiKeyService")
|
||||
}
|
||||
if h.concurrencyHelper == nil || h.concurrencyHelper.concurrencyService == nil {
|
||||
missing = append(missing, "concurrencyHelper")
|
||||
}
|
||||
return missing
|
||||
}
|
||||
|
||||
func getContextInt64(c *gin.Context, key string) (int64, bool) {
|
||||
if c == nil || key == "" {
|
||||
return 0, false
|
||||
@@ -444,6 +892,14 @@ func (h *OpenAIGatewayHandler) submitUsageRecordTask(task service.UsageRecordTas
|
||||
// 回退路径:worker 池未注入时同步执行,避免退回到无界 goroutine 模式。
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
defer func() {
|
||||
if recovered := recover(); recovered != nil {
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.openai_gateway.responses"),
|
||||
zap.Any("panic", recovered),
|
||||
).Error("openai.usage_record_task_panic_recovered")
|
||||
}
|
||||
}()
|
||||
task(ctx)
|
||||
}
|
||||
|
||||
@@ -515,19 +971,8 @@ func (h *OpenAIGatewayHandler) handleStreamingAwareError(c *gin.Context, status
|
||||
// Stream already started, send error as SSE event then close
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if ok {
|
||||
// Send error event in OpenAI SSE format with proper JSON marshaling
|
||||
errorData := map[string]any{
|
||||
"error": map[string]string{
|
||||
"type": errType,
|
||||
"message": message,
|
||||
},
|
||||
}
|
||||
jsonBytes, err := json.Marshal(errorData)
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
}
|
||||
errorEvent := fmt.Sprintf("event: error\ndata: %s\n\n", string(jsonBytes))
|
||||
// SSE 错误事件固定 schema,使用 Quote 直拼可避免额外 Marshal 分配。
|
||||
errorEvent := "event: error\ndata: " + `{"error":{"type":` + strconv.Quote(errType) + `,"message":` + strconv.Quote(message) + `}}` + "\n\n"
|
||||
if _, err := fmt.Fprint(c.Writer, errorEvent); err != nil {
|
||||
_ = c.Error(err)
|
||||
}
|
||||
@@ -549,6 +994,16 @@ func (h *OpenAIGatewayHandler) ensureForwardErrorResponse(c *gin.Context, stream
|
||||
return true
|
||||
}
|
||||
|
||||
func shouldLogOpenAIForwardFailureAsWarn(c *gin.Context, wroteFallback bool) bool {
|
||||
if wroteFallback {
|
||||
return false
|
||||
}
|
||||
if c == nil || c.Writer == nil {
|
||||
return false
|
||||
}
|
||||
return c.Writer.Written()
|
||||
}
|
||||
|
||||
// errorResponse returns OpenAI API format error response
|
||||
func (h *OpenAIGatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) {
|
||||
c.JSON(status, gin.H{
|
||||
@@ -558,3 +1013,61 @@ func (h *OpenAIGatewayHandler) errorResponse(c *gin.Context, status int, errType
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func setOpenAIClientTransportHTTP(c *gin.Context) {
|
||||
service.SetOpenAIClientTransport(c, service.OpenAIClientTransportHTTP)
|
||||
}
|
||||
|
||||
func setOpenAIClientTransportWS(c *gin.Context) {
|
||||
service.SetOpenAIClientTransport(c, service.OpenAIClientTransportWS)
|
||||
}
|
||||
|
||||
func openAIWSIngressFallbackSessionSeed(userID, apiKeyID int64, groupID *int64) string {
|
||||
gid := int64(0)
|
||||
if groupID != nil {
|
||||
gid = *groupID
|
||||
}
|
||||
return fmt.Sprintf("openai_ws_ingress:%d:%d:%d", gid, userID, apiKeyID)
|
||||
}
|
||||
|
||||
func isOpenAIWSUpgradeRequest(r *http.Request) bool {
|
||||
if r == nil {
|
||||
return false
|
||||
}
|
||||
if !strings.EqualFold(strings.TrimSpace(r.Header.Get("Upgrade")), "websocket") {
|
||||
return false
|
||||
}
|
||||
return strings.Contains(strings.ToLower(strings.TrimSpace(r.Header.Get("Connection"))), "upgrade")
|
||||
}
|
||||
|
||||
func closeOpenAIClientWS(conn *coderws.Conn, status coderws.StatusCode, reason string) {
|
||||
if conn == nil {
|
||||
return
|
||||
}
|
||||
reason = strings.TrimSpace(reason)
|
||||
if len(reason) > 120 {
|
||||
reason = reason[:120]
|
||||
}
|
||||
_ = conn.Close(status, reason)
|
||||
_ = conn.CloseNow()
|
||||
}
|
||||
|
||||
func summarizeWSCloseErrorForLog(err error) (string, string) {
|
||||
if err == nil {
|
||||
return "-", "-"
|
||||
}
|
||||
statusCode := coderws.CloseStatus(err)
|
||||
if statusCode == -1 {
|
||||
return "-", "-"
|
||||
}
|
||||
closeStatus := fmt.Sprintf("%d(%s)", int(statusCode), statusCode.String())
|
||||
closeReason := "-"
|
||||
var closeErr coderws.CloseError
|
||||
if errors.As(err, &closeErr) {
|
||||
reason := strings.TrimSpace(closeErr.Reason)
|
||||
if reason != "" {
|
||||
closeReason = reason
|
||||
}
|
||||
}
|
||||
return closeStatus, closeReason
|
||||
}
|
||||
|
||||
@@ -1,12 +1,19 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
coderws "github.com/coder/websocket"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -105,6 +112,27 @@ func TestOpenAIHandleStreamingAwareError_NonStreaming(t *testing.T) {
|
||||
assert.Equal(t, "test error", errorObj["message"])
|
||||
}
|
||||
|
||||
func TestReadRequestBodyWithPrealloc(t *testing.T) {
|
||||
payload := `{"model":"gpt-5","input":"hello"}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(payload))
|
||||
req.ContentLength = int64(len(payload))
|
||||
|
||||
body, err := pkghttputil.ReadRequestBodyWithPrealloc(req)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, payload, string(body))
|
||||
}
|
||||
|
||||
func TestReadRequestBodyWithPrealloc_MaxBytesError(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(strings.Repeat("x", 8)))
|
||||
req.Body = http.MaxBytesReader(rec, req.Body, 4)
|
||||
|
||||
_, err := pkghttputil.ReadRequestBodyWithPrealloc(req)
|
||||
require.Error(t, err)
|
||||
var maxErr *http.MaxBytesError
|
||||
require.ErrorAs(t, err, &maxErr)
|
||||
}
|
||||
|
||||
func TestOpenAIEnsureForwardErrorResponse_WritesFallbackWhenNotWritten(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
@@ -141,6 +169,387 @@ func TestOpenAIEnsureForwardErrorResponse_DoesNotOverrideWrittenResponse(t *test
|
||||
assert.Equal(t, "already written", w.Body.String())
|
||||
}
|
||||
|
||||
func TestShouldLogOpenAIForwardFailureAsWarn(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
t.Run("fallback_written_should_not_downgrade", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
require.False(t, shouldLogOpenAIForwardFailureAsWarn(c, true))
|
||||
})
|
||||
|
||||
t.Run("context_nil_should_not_downgrade", func(t *testing.T) {
|
||||
require.False(t, shouldLogOpenAIForwardFailureAsWarn(nil, false))
|
||||
})
|
||||
|
||||
t.Run("response_not_written_should_not_downgrade", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
require.False(t, shouldLogOpenAIForwardFailureAsWarn(c, false))
|
||||
})
|
||||
|
||||
t.Run("response_already_written_should_downgrade", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
c.String(http.StatusForbidden, "already written")
|
||||
require.True(t, shouldLogOpenAIForwardFailureAsWarn(c, false))
|
||||
})
|
||||
}
|
||||
|
||||
func TestOpenAIRecoverResponsesPanic_WritesFallbackResponse(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
|
||||
h := &OpenAIGatewayHandler{}
|
||||
streamStarted := false
|
||||
require.NotPanics(t, func() {
|
||||
func() {
|
||||
defer h.recoverResponsesPanic(c, &streamStarted)
|
||||
panic("test panic")
|
||||
}()
|
||||
})
|
||||
|
||||
require.Equal(t, http.StatusBadGateway, w.Code)
|
||||
|
||||
var parsed map[string]any
|
||||
err := json.Unmarshal(w.Body.Bytes(), &parsed)
|
||||
require.NoError(t, err)
|
||||
|
||||
errorObj, ok := parsed["error"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "upstream_error", errorObj["type"])
|
||||
assert.Equal(t, "Upstream request failed", errorObj["message"])
|
||||
}
|
||||
|
||||
func TestOpenAIRecoverResponsesPanic_NoPanicNoWrite(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
|
||||
h := &OpenAIGatewayHandler{}
|
||||
streamStarted := false
|
||||
require.NotPanics(t, func() {
|
||||
func() {
|
||||
defer h.recoverResponsesPanic(c, &streamStarted)
|
||||
}()
|
||||
})
|
||||
|
||||
require.False(t, c.Writer.Written())
|
||||
assert.Equal(t, "", w.Body.String())
|
||||
}
|
||||
|
||||
func TestOpenAIRecoverResponsesPanic_DoesNotOverrideWrittenResponse(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
c.String(http.StatusTeapot, "already written")
|
||||
|
||||
h := &OpenAIGatewayHandler{}
|
||||
streamStarted := false
|
||||
require.NotPanics(t, func() {
|
||||
func() {
|
||||
defer h.recoverResponsesPanic(c, &streamStarted)
|
||||
panic("test panic")
|
||||
}()
|
||||
})
|
||||
|
||||
require.Equal(t, http.StatusTeapot, w.Code)
|
||||
assert.Equal(t, "already written", w.Body.String())
|
||||
}
|
||||
|
||||
func TestOpenAIMissingResponsesDependencies(t *testing.T) {
|
||||
t.Run("nil_handler", func(t *testing.T) {
|
||||
var h *OpenAIGatewayHandler
|
||||
require.Equal(t, []string{"handler"}, h.missingResponsesDependencies())
|
||||
})
|
||||
|
||||
t.Run("all_dependencies_missing", func(t *testing.T) {
|
||||
h := &OpenAIGatewayHandler{}
|
||||
require.Equal(t,
|
||||
[]string{"gatewayService", "billingCacheService", "apiKeyService", "concurrencyHelper"},
|
||||
h.missingResponsesDependencies(),
|
||||
)
|
||||
})
|
||||
|
||||
t.Run("all_dependencies_present", func(t *testing.T) {
|
||||
h := &OpenAIGatewayHandler{
|
||||
gatewayService: &service.OpenAIGatewayService{},
|
||||
billingCacheService: &service.BillingCacheService{},
|
||||
apiKeyService: &service.APIKeyService{},
|
||||
concurrencyHelper: &ConcurrencyHelper{
|
||||
concurrencyService: &service.ConcurrencyService{},
|
||||
},
|
||||
}
|
||||
require.Empty(t, h.missingResponsesDependencies())
|
||||
})
|
||||
}
|
||||
|
||||
func TestOpenAIEnsureResponsesDependencies(t *testing.T) {
|
||||
t.Run("missing_dependencies_returns_503", func(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
|
||||
h := &OpenAIGatewayHandler{}
|
||||
ok := h.ensureResponsesDependencies(c, nil)
|
||||
|
||||
require.False(t, ok)
|
||||
require.Equal(t, http.StatusServiceUnavailable, w.Code)
|
||||
var parsed map[string]any
|
||||
err := json.Unmarshal(w.Body.Bytes(), &parsed)
|
||||
require.NoError(t, err)
|
||||
errorObj, exists := parsed["error"].(map[string]any)
|
||||
require.True(t, exists)
|
||||
assert.Equal(t, "api_error", errorObj["type"])
|
||||
assert.Equal(t, "Service temporarily unavailable", errorObj["message"])
|
||||
})
|
||||
|
||||
t.Run("already_written_response_not_overridden", func(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
c.String(http.StatusTeapot, "already written")
|
||||
|
||||
h := &OpenAIGatewayHandler{}
|
||||
ok := h.ensureResponsesDependencies(c, nil)
|
||||
|
||||
require.False(t, ok)
|
||||
require.Equal(t, http.StatusTeapot, w.Code)
|
||||
assert.Equal(t, "already written", w.Body.String())
|
||||
})
|
||||
|
||||
t.Run("dependencies_ready_returns_true_and_no_write", func(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
|
||||
h := &OpenAIGatewayHandler{
|
||||
gatewayService: &service.OpenAIGatewayService{},
|
||||
billingCacheService: &service.BillingCacheService{},
|
||||
apiKeyService: &service.APIKeyService{},
|
||||
concurrencyHelper: &ConcurrencyHelper{
|
||||
concurrencyService: &service.ConcurrencyService{},
|
||||
},
|
||||
}
|
||||
ok := h.ensureResponsesDependencies(c, nil)
|
||||
|
||||
require.True(t, ok)
|
||||
require.False(t, c.Writer.Written())
|
||||
assert.Equal(t, "", w.Body.String())
|
||||
})
|
||||
}
|
||||
|
||||
func TestOpenAIResponses_MissingDependencies_ReturnsServiceUnavailable(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(`{"model":"gpt-5","stream":false}`))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
groupID := int64(2)
|
||||
c.Set(string(middleware.ContextKeyAPIKey), &service.APIKey{
|
||||
ID: 10,
|
||||
GroupID: &groupID,
|
||||
})
|
||||
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{
|
||||
UserID: 1,
|
||||
Concurrency: 1,
|
||||
})
|
||||
|
||||
// 故意使用未初始化依赖,验证快速失败而不是崩溃。
|
||||
h := &OpenAIGatewayHandler{}
|
||||
require.NotPanics(t, func() {
|
||||
h.Responses(c)
|
||||
})
|
||||
|
||||
require.Equal(t, http.StatusServiceUnavailable, w.Code)
|
||||
|
||||
var parsed map[string]any
|
||||
err := json.Unmarshal(w.Body.Bytes(), &parsed)
|
||||
require.NoError(t, err)
|
||||
|
||||
errorObj, ok := parsed["error"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "api_error", errorObj["type"])
|
||||
assert.Equal(t, "Service temporarily unavailable", errorObj["message"])
|
||||
}
|
||||
|
||||
func TestOpenAIResponses_SetsClientTransportHTTP(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", strings.NewReader(`{"model":"gpt-5"}`))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h := &OpenAIGatewayHandler{}
|
||||
h.Responses(c)
|
||||
|
||||
require.Equal(t, http.StatusUnauthorized, w.Code)
|
||||
require.Equal(t, service.OpenAIClientTransportHTTP, service.GetOpenAIClientTransport(c))
|
||||
}
|
||||
|
||||
func TestOpenAIResponses_RejectsMessageIDAsPreviousResponseID(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", strings.NewReader(
|
||||
`{"model":"gpt-5.1","stream":false,"previous_response_id":"msg_123456","input":[{"type":"input_text","text":"hello"}]}`,
|
||||
))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
groupID := int64(2)
|
||||
c.Set(string(middleware.ContextKeyAPIKey), &service.APIKey{
|
||||
ID: 101,
|
||||
GroupID: &groupID,
|
||||
User: &service.User{ID: 1},
|
||||
})
|
||||
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{
|
||||
UserID: 1,
|
||||
Concurrency: 1,
|
||||
})
|
||||
|
||||
h := newOpenAIHandlerForPreviousResponseIDValidation(t, nil)
|
||||
h.Responses(c)
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, w.Code)
|
||||
require.Contains(t, w.Body.String(), "previous_response_id must be a response.id")
|
||||
}
|
||||
|
||||
func TestOpenAIResponsesWebSocket_SetsClientTransportWSWhenUpgradeValid(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/openai/v1/responses", nil)
|
||||
c.Request.Header.Set("Upgrade", "websocket")
|
||||
c.Request.Header.Set("Connection", "Upgrade")
|
||||
|
||||
h := &OpenAIGatewayHandler{}
|
||||
h.ResponsesWebSocket(c)
|
||||
|
||||
require.Equal(t, http.StatusUnauthorized, w.Code)
|
||||
require.Equal(t, service.OpenAIClientTransportWS, service.GetOpenAIClientTransport(c))
|
||||
}
|
||||
|
||||
func TestOpenAIResponsesWebSocket_InvalidUpgradeDoesNotSetTransport(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/openai/v1/responses", nil)
|
||||
|
||||
h := &OpenAIGatewayHandler{}
|
||||
h.ResponsesWebSocket(c)
|
||||
|
||||
require.Equal(t, http.StatusUpgradeRequired, w.Code)
|
||||
require.Equal(t, service.OpenAIClientTransportUnknown, service.GetOpenAIClientTransport(c))
|
||||
}
|
||||
|
||||
func TestOpenAIResponsesWebSocket_RejectsMessageIDAsPreviousResponseID(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
h := newOpenAIHandlerForPreviousResponseIDValidation(t, nil)
|
||||
wsServer := newOpenAIWSHandlerTestServer(t, h, middleware.AuthSubject{UserID: 1, Concurrency: 1})
|
||||
defer wsServer.Close()
|
||||
|
||||
dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http")+"/openai/v1/responses", nil)
|
||||
cancelDial()
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
_ = clientConn.CloseNow()
|
||||
}()
|
||||
|
||||
writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
err = clientConn.Write(writeCtx, coderws.MessageText, []byte(
|
||||
`{"type":"response.create","model":"gpt-5.1","stream":false,"previous_response_id":"msg_abc123"}`,
|
||||
))
|
||||
cancelWrite()
|
||||
require.NoError(t, err)
|
||||
|
||||
readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
_, _, err = clientConn.Read(readCtx)
|
||||
cancelRead()
|
||||
require.Error(t, err)
|
||||
var closeErr coderws.CloseError
|
||||
require.ErrorAs(t, err, &closeErr)
|
||||
require.Equal(t, coderws.StatusPolicyViolation, closeErr.Code)
|
||||
require.Contains(t, strings.ToLower(closeErr.Reason), "previous_response_id")
|
||||
}
|
||||
|
||||
func TestOpenAIResponsesWebSocket_PreviousResponseIDKindLoggedBeforeAcquireFailure(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
cache := &concurrencyCacheMock{
|
||||
acquireUserSlotFn: func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||
return false, errors.New("user slot unavailable")
|
||||
},
|
||||
}
|
||||
h := newOpenAIHandlerForPreviousResponseIDValidation(t, cache)
|
||||
wsServer := newOpenAIWSHandlerTestServer(t, h, middleware.AuthSubject{UserID: 1, Concurrency: 1})
|
||||
defer wsServer.Close()
|
||||
|
||||
dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http")+"/openai/v1/responses", nil)
|
||||
cancelDial()
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
_ = clientConn.CloseNow()
|
||||
}()
|
||||
|
||||
writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
err = clientConn.Write(writeCtx, coderws.MessageText, []byte(
|
||||
`{"type":"response.create","model":"gpt-5.1","stream":false,"previous_response_id":"resp_prev_123"}`,
|
||||
))
|
||||
cancelWrite()
|
||||
require.NoError(t, err)
|
||||
|
||||
readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
_, _, err = clientConn.Read(readCtx)
|
||||
cancelRead()
|
||||
require.Error(t, err)
|
||||
var closeErr coderws.CloseError
|
||||
require.ErrorAs(t, err, &closeErr)
|
||||
require.Equal(t, coderws.StatusInternalError, closeErr.Code)
|
||||
require.Contains(t, strings.ToLower(closeErr.Reason), "failed to acquire user concurrency slot")
|
||||
}
|
||||
|
||||
func TestSetOpenAIClientTransportHTTP(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
setOpenAIClientTransportHTTP(c)
|
||||
require.Equal(t, service.OpenAIClientTransportHTTP, service.GetOpenAIClientTransport(c))
|
||||
}
|
||||
|
||||
func TestSetOpenAIClientTransportWS(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
setOpenAIClientTransportWS(c)
|
||||
require.Equal(t, service.OpenAIClientTransportWS, service.GetOpenAIClientTransport(c))
|
||||
}
|
||||
|
||||
// TestOpenAIHandler_GjsonExtraction 验证 gjson 从请求体中提取 model/stream 的正确性
|
||||
func TestOpenAIHandler_GjsonExtraction(t *testing.T) {
|
||||
tests := []struct {
|
||||
@@ -228,3 +637,41 @@ func TestOpenAIHandler_InstructionsInjection(t *testing.T) {
|
||||
require.NoError(t, setErr)
|
||||
require.True(t, gjson.ValidBytes(result))
|
||||
}
|
||||
|
||||
func newOpenAIHandlerForPreviousResponseIDValidation(t *testing.T, cache *concurrencyCacheMock) *OpenAIGatewayHandler {
|
||||
t.Helper()
|
||||
if cache == nil {
|
||||
cache = &concurrencyCacheMock{
|
||||
acquireUserSlotFn: func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||
return true, nil
|
||||
},
|
||||
acquireAccountSlotFn: func(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||
return true, nil
|
||||
},
|
||||
}
|
||||
}
|
||||
return &OpenAIGatewayHandler{
|
||||
gatewayService: &service.OpenAIGatewayService{},
|
||||
billingCacheService: &service.BillingCacheService{},
|
||||
apiKeyService: &service.APIKeyService{},
|
||||
concurrencyHelper: NewConcurrencyHelper(service.NewConcurrencyService(cache), SSEPingFormatNone, time.Second),
|
||||
}
|
||||
}
|
||||
|
||||
func newOpenAIWSHandlerTestServer(t *testing.T, h *OpenAIGatewayHandler, subject middleware.AuthSubject) *httptest.Server {
|
||||
t.Helper()
|
||||
groupID := int64(2)
|
||||
apiKey := &service.APIKey{
|
||||
ID: 101,
|
||||
GroupID: &groupID,
|
||||
User: &service.User{ID: subject.UserID},
|
||||
}
|
||||
router := gin.New()
|
||||
router.Use(func(c *gin.Context) {
|
||||
c.Set(string(middleware.ContextKeyAPIKey), apiKey)
|
||||
c.Set(string(middleware.ContextKeyUser), subject)
|
||||
c.Next()
|
||||
})
|
||||
router.GET("/openai/v1/responses", h.ResponsesWebSocket)
|
||||
return httptest.NewServer(router)
|
||||
}
|
||||
|
||||
@@ -311,6 +311,35 @@ type opsCaptureWriter struct {
|
||||
buf bytes.Buffer
|
||||
}
|
||||
|
||||
const opsCaptureWriterLimit = 64 * 1024
|
||||
|
||||
var opsCaptureWriterPool = sync.Pool{
|
||||
New: func() any {
|
||||
return &opsCaptureWriter{limit: opsCaptureWriterLimit}
|
||||
},
|
||||
}
|
||||
|
||||
func acquireOpsCaptureWriter(rw gin.ResponseWriter) *opsCaptureWriter {
|
||||
w, ok := opsCaptureWriterPool.Get().(*opsCaptureWriter)
|
||||
if !ok || w == nil {
|
||||
w = &opsCaptureWriter{}
|
||||
}
|
||||
w.ResponseWriter = rw
|
||||
w.limit = opsCaptureWriterLimit
|
||||
w.buf.Reset()
|
||||
return w
|
||||
}
|
||||
|
||||
func releaseOpsCaptureWriter(w *opsCaptureWriter) {
|
||||
if w == nil {
|
||||
return
|
||||
}
|
||||
w.ResponseWriter = nil
|
||||
w.limit = opsCaptureWriterLimit
|
||||
w.buf.Reset()
|
||||
opsCaptureWriterPool.Put(w)
|
||||
}
|
||||
|
||||
func (w *opsCaptureWriter) Write(b []byte) (int, error) {
|
||||
if w.Status() >= 400 && w.limit > 0 && w.buf.Len() < w.limit {
|
||||
remaining := w.limit - w.buf.Len()
|
||||
@@ -342,7 +371,16 @@ func (w *opsCaptureWriter) WriteString(s string) (int, error) {
|
||||
// - Streaming errors after the response has started (SSE) may still need explicit logging.
|
||||
func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
w := &opsCaptureWriter{ResponseWriter: c.Writer, limit: 64 * 1024}
|
||||
originalWriter := c.Writer
|
||||
w := acquireOpsCaptureWriter(originalWriter)
|
||||
defer func() {
|
||||
// Restore the original writer before returning so outer middlewares
|
||||
// don't observe a pooled wrapper that has been released.
|
||||
if c.Writer == w {
|
||||
c.Writer = originalWriter
|
||||
}
|
||||
releaseOpsCaptureWriter(w)
|
||||
}()
|
||||
c.Writer = w
|
||||
c.Next()
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -173,3 +174,43 @@ func TestEnqueueOpsErrorLog_EarlyReturnBranches(t *testing.T) {
|
||||
enqueueOpsErrorLog(ops, entry)
|
||||
require.Equal(t, int64(0), OpsErrorLogEnqueuedTotal())
|
||||
}
|
||||
|
||||
func TestOpsCaptureWriterPool_ResetOnRelease(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
|
||||
writer := acquireOpsCaptureWriter(c.Writer)
|
||||
require.NotNil(t, writer)
|
||||
_, err := writer.buf.WriteString("temp-error-body")
|
||||
require.NoError(t, err)
|
||||
|
||||
releaseOpsCaptureWriter(writer)
|
||||
|
||||
reused := acquireOpsCaptureWriter(c.Writer)
|
||||
defer releaseOpsCaptureWriter(reused)
|
||||
|
||||
require.Zero(t, reused.buf.Len(), "writer should be reset before reuse")
|
||||
}
|
||||
|
||||
func TestOpsErrorLoggerMiddleware_DoesNotBreakOuterMiddlewares(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
r := gin.New()
|
||||
r.Use(middleware2.Recovery())
|
||||
r.Use(middleware2.RequestLogger())
|
||||
r.Use(middleware2.Logger())
|
||||
r.GET("/v1/messages", OpsErrorLoggerMiddleware(nil), func(c *gin.Context) {
|
||||
c.Status(http.StatusNoContent)
|
||||
})
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/v1/messages", nil)
|
||||
|
||||
require.NotPanics(t, func() {
|
||||
r.ServeHTTP(rec, req)
|
||||
})
|
||||
require.Equal(t, http.StatusNoContent, rec.Code)
|
||||
}
|
||||
|
||||
@@ -51,6 +51,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
|
||||
PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
|
||||
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
|
||||
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
|
||||
SoraClientEnabled: settings.SoraClientEnabled,
|
||||
Version: h.version,
|
||||
})
|
||||
}
|
||||
|
||||
979
backend/internal/handler/sora_client_handler.go
Normal file
979
backend/internal/handler/sora_client_handler.go
Normal file
@@ -0,0 +1,979 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
const (
|
||||
// 上游模型缓存 TTL
|
||||
modelCacheTTL = 1 * time.Hour // 上游获取成功
|
||||
modelCacheFailedTTL = 2 * time.Minute // 上游获取失败(降级到本地)
|
||||
)
|
||||
|
||||
// SoraClientHandler 处理 Sora 客户端 API 请求。
|
||||
type SoraClientHandler struct {
|
||||
genService *service.SoraGenerationService
|
||||
quotaService *service.SoraQuotaService
|
||||
s3Storage *service.SoraS3Storage
|
||||
soraGatewayService *service.SoraGatewayService
|
||||
gatewayService *service.GatewayService
|
||||
mediaStorage *service.SoraMediaStorage
|
||||
apiKeyService *service.APIKeyService
|
||||
|
||||
// 上游模型缓存
|
||||
modelCacheMu sync.RWMutex
|
||||
cachedFamilies []service.SoraModelFamily
|
||||
modelCacheTime time.Time
|
||||
modelCacheUpstream bool // 是否来自上游(决定 TTL)
|
||||
}
|
||||
|
||||
// NewSoraClientHandler 创建 Sora 客户端 Handler。
|
||||
func NewSoraClientHandler(
|
||||
genService *service.SoraGenerationService,
|
||||
quotaService *service.SoraQuotaService,
|
||||
s3Storage *service.SoraS3Storage,
|
||||
soraGatewayService *service.SoraGatewayService,
|
||||
gatewayService *service.GatewayService,
|
||||
mediaStorage *service.SoraMediaStorage,
|
||||
apiKeyService *service.APIKeyService,
|
||||
) *SoraClientHandler {
|
||||
return &SoraClientHandler{
|
||||
genService: genService,
|
||||
quotaService: quotaService,
|
||||
s3Storage: s3Storage,
|
||||
soraGatewayService: soraGatewayService,
|
||||
gatewayService: gatewayService,
|
||||
mediaStorage: mediaStorage,
|
||||
apiKeyService: apiKeyService,
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateRequest 生成请求。
|
||||
type GenerateRequest struct {
|
||||
Model string `json:"model" binding:"required"`
|
||||
Prompt string `json:"prompt" binding:"required"`
|
||||
MediaType string `json:"media_type"` // video / image,默认 video
|
||||
VideoCount int `json:"video_count,omitempty"` // 视频数量(1-3)
|
||||
ImageInput string `json:"image_input,omitempty"` // 参考图(base64 或 URL)
|
||||
APIKeyID *int64 `json:"api_key_id,omitempty"` // 前端传递的 API Key ID
|
||||
}
|
||||
|
||||
// Generate 异步生成 — 创建 pending 记录后立即返回。
|
||||
// POST /api/v1/sora/generate
|
||||
func (h *SoraClientHandler) Generate(c *gin.Context) {
|
||||
userID := getUserIDFromContext(c)
|
||||
if userID == 0 {
|
||||
response.Error(c, http.StatusUnauthorized, "未登录")
|
||||
return
|
||||
}
|
||||
|
||||
var req GenerateRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.Error(c, http.StatusBadRequest, "参数错误: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if req.MediaType == "" {
|
||||
req.MediaType = "video"
|
||||
}
|
||||
req.VideoCount = normalizeVideoCount(req.MediaType, req.VideoCount)
|
||||
|
||||
// 并发数检查(最多 3 个)
|
||||
activeCount, err := h.genService.CountActiveByUser(c.Request.Context(), userID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
if activeCount >= 3 {
|
||||
response.Error(c, http.StatusTooManyRequests, "同时进行中的任务不能超过 3 个")
|
||||
return
|
||||
}
|
||||
|
||||
// 配额检查(粗略检查,实际文件大小在上传后才知道)
|
||||
if h.quotaService != nil {
|
||||
if err := h.quotaService.CheckQuota(c.Request.Context(), userID, 0); err != nil {
|
||||
var quotaErr *service.QuotaExceededError
|
||||
if errors.As(err, "aErr) {
|
||||
response.Error(c, http.StatusTooManyRequests, "存储配额已满,请删除不需要的作品释放空间")
|
||||
return
|
||||
}
|
||||
response.Error(c, http.StatusForbidden, err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 获取 API Key ID 和 Group ID
|
||||
var apiKeyID *int64
|
||||
var groupID *int64
|
||||
|
||||
if req.APIKeyID != nil && h.apiKeyService != nil {
|
||||
// 前端传递了 api_key_id,需要校验
|
||||
apiKey, err := h.apiKeyService.GetByID(c.Request.Context(), *req.APIKeyID)
|
||||
if err != nil {
|
||||
response.Error(c, http.StatusBadRequest, "API Key 不存在")
|
||||
return
|
||||
}
|
||||
if apiKey.UserID != userID {
|
||||
response.Error(c, http.StatusForbidden, "API Key 不属于当前用户")
|
||||
return
|
||||
}
|
||||
if apiKey.Status != service.StatusAPIKeyActive {
|
||||
response.Error(c, http.StatusForbidden, "API Key 不可用")
|
||||
return
|
||||
}
|
||||
apiKeyID = &apiKey.ID
|
||||
groupID = apiKey.GroupID
|
||||
} else if id, ok := c.Get("api_key_id"); ok {
|
||||
// 兼容 API Key 认证路径(/sora/v1/ 网关路由)
|
||||
if v, ok := id.(int64); ok {
|
||||
apiKeyID = &v
|
||||
}
|
||||
}
|
||||
|
||||
gen, err := h.genService.CreatePending(c.Request.Context(), userID, apiKeyID, req.Model, req.Prompt, req.MediaType)
|
||||
if err != nil {
|
||||
if errors.Is(err, service.ErrSoraGenerationConcurrencyLimit) {
|
||||
response.Error(c, http.StatusTooManyRequests, "同时进行中的任务不能超过 3 个")
|
||||
return
|
||||
}
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 启动后台异步生成 goroutine
|
||||
go h.processGeneration(gen.ID, userID, groupID, req.Model, req.Prompt, req.MediaType, req.ImageInput, req.VideoCount)
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"generation_id": gen.ID,
|
||||
"status": gen.Status,
|
||||
})
|
||||
}
|
||||
|
||||
// processGeneration 后台异步执行 Sora 生成任务。
|
||||
// 流程:选择账号 → Forward → 提取媒体 URL → 三层降级存储(S3 → 本地 → 上游)→ 更新记录。
|
||||
func (h *SoraClientHandler) processGeneration(genID int64, userID int64, groupID *int64, model, prompt, mediaType, imageInput string, videoCount int) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
// 标记为生成中
|
||||
if err := h.genService.MarkGenerating(ctx, genID, ""); err != nil {
|
||||
if errors.Is(err, service.ErrSoraGenerationStateConflict) {
|
||||
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 任务状态已变化,跳过生成 id=%d", genID)
|
||||
return
|
||||
}
|
||||
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 标记生成中失败 id=%d err=%v", genID, err)
|
||||
return
|
||||
}
|
||||
|
||||
logger.LegacyPrintf(
|
||||
"handler.sora_client",
|
||||
"[SoraClient] 开始生成 id=%d user=%d group=%d model=%s media_type=%s video_count=%d has_image=%v prompt_len=%d",
|
||||
genID,
|
||||
userID,
|
||||
groupIDForLog(groupID),
|
||||
model,
|
||||
mediaType,
|
||||
videoCount,
|
||||
strings.TrimSpace(imageInput) != "",
|
||||
len(strings.TrimSpace(prompt)),
|
||||
)
|
||||
|
||||
// 有 groupID 时由分组决定平台,无 groupID 时用 ForcePlatform 兜底
|
||||
if groupID == nil {
|
||||
ctx = context.WithValue(ctx, ctxkey.ForcePlatform, service.PlatformSora)
|
||||
}
|
||||
|
||||
if h.gatewayService == nil {
|
||||
_ = h.genService.MarkFailed(ctx, genID, "内部错误: gatewayService 未初始化")
|
||||
return
|
||||
}
|
||||
|
||||
// 选择 Sora 账号
|
||||
account, err := h.gatewayService.SelectAccountForModel(ctx, groupID, "", model)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf(
|
||||
"handler.sora_client",
|
||||
"[SoraClient] 选择账号失败 id=%d user=%d group=%d model=%s err=%v",
|
||||
genID,
|
||||
userID,
|
||||
groupIDForLog(groupID),
|
||||
model,
|
||||
err,
|
||||
)
|
||||
_ = h.genService.MarkFailed(ctx, genID, "选择账号失败: "+err.Error())
|
||||
return
|
||||
}
|
||||
logger.LegacyPrintf(
|
||||
"handler.sora_client",
|
||||
"[SoraClient] 选中账号 id=%d user=%d group=%d model=%s account_id=%d account_name=%s platform=%s type=%s",
|
||||
genID,
|
||||
userID,
|
||||
groupIDForLog(groupID),
|
||||
model,
|
||||
account.ID,
|
||||
account.Name,
|
||||
account.Platform,
|
||||
account.Type,
|
||||
)
|
||||
|
||||
// 构建 chat completions 请求体(非流式)
|
||||
body := buildAsyncRequestBody(model, prompt, imageInput, normalizeVideoCount(mediaType, videoCount))
|
||||
|
||||
if h.soraGatewayService == nil {
|
||||
_ = h.genService.MarkFailed(ctx, genID, "内部错误: soraGatewayService 未初始化")
|
||||
return
|
||||
}
|
||||
|
||||
// 创建 mock gin 上下文用于 Forward(捕获响应以提取媒体 URL)
|
||||
recorder := httptest.NewRecorder()
|
||||
mockGinCtx, _ := gin.CreateTestContext(recorder)
|
||||
mockGinCtx.Request, _ = http.NewRequest("POST", "/", nil)
|
||||
|
||||
// 调用 Forward(非流式)
|
||||
result, err := h.soraGatewayService.Forward(ctx, mockGinCtx, account, body, false)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf(
|
||||
"handler.sora_client",
|
||||
"[SoraClient] Forward失败 id=%d account_id=%d model=%s status=%d body=%s err=%v",
|
||||
genID,
|
||||
account.ID,
|
||||
model,
|
||||
recorder.Code,
|
||||
trimForLog(recorder.Body.String(), 400),
|
||||
err,
|
||||
)
|
||||
// 检查是否已取消
|
||||
gen, _ := h.genService.GetByID(ctx, genID, userID)
|
||||
if gen != nil && gen.Status == service.SoraGenStatusCancelled {
|
||||
return
|
||||
}
|
||||
_ = h.genService.MarkFailed(ctx, genID, "生成失败: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 提取媒体 URL(优先从 ForwardResult,其次从响应体解析)
|
||||
mediaURL, mediaURLs := extractMediaURLsFromResult(result, recorder)
|
||||
if mediaURL == "" {
|
||||
logger.LegacyPrintf(
|
||||
"handler.sora_client",
|
||||
"[SoraClient] 未提取到媒体URL id=%d account_id=%d model=%s status=%d body=%s",
|
||||
genID,
|
||||
account.ID,
|
||||
model,
|
||||
recorder.Code,
|
||||
trimForLog(recorder.Body.String(), 400),
|
||||
)
|
||||
_ = h.genService.MarkFailed(ctx, genID, "未获取到媒体 URL")
|
||||
return
|
||||
}
|
||||
|
||||
// 检查任务是否已被取消
|
||||
gen, _ := h.genService.GetByID(ctx, genID, userID)
|
||||
if gen != nil && gen.Status == service.SoraGenStatusCancelled {
|
||||
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 任务已取消,跳过存储 id=%d", genID)
|
||||
return
|
||||
}
|
||||
|
||||
// 三层降级存储:S3 → 本地 → 上游临时 URL
|
||||
storedURL, storedURLs, storageType, s3Keys, fileSize := h.storeMediaWithDegradation(ctx, userID, mediaType, mediaURL, mediaURLs)
|
||||
|
||||
usageAdded := false
|
||||
if (storageType == service.SoraStorageTypeS3 || storageType == service.SoraStorageTypeLocal) && fileSize > 0 && h.quotaService != nil {
|
||||
if err := h.quotaService.AddUsage(ctx, userID, fileSize); err != nil {
|
||||
h.cleanupStoredMedia(ctx, storageType, s3Keys, storedURLs)
|
||||
var quotaErr *service.QuotaExceededError
|
||||
if errors.As(err, "aErr) {
|
||||
_ = h.genService.MarkFailed(ctx, genID, "存储配额已满,请删除不需要的作品释放空间")
|
||||
return
|
||||
}
|
||||
_ = h.genService.MarkFailed(ctx, genID, "存储配额更新失败: "+err.Error())
|
||||
return
|
||||
}
|
||||
usageAdded = true
|
||||
}
|
||||
|
||||
// 存储完成后再做一次取消检查,防止取消被 completed 覆盖。
|
||||
gen, _ = h.genService.GetByID(ctx, genID, userID)
|
||||
if gen != nil && gen.Status == service.SoraGenStatusCancelled {
|
||||
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 存储后检测到任务已取消,回滚存储 id=%d", genID)
|
||||
h.cleanupStoredMedia(ctx, storageType, s3Keys, storedURLs)
|
||||
if usageAdded && h.quotaService != nil {
|
||||
_ = h.quotaService.ReleaseUsage(ctx, userID, fileSize)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 标记完成
|
||||
if err := h.genService.MarkCompleted(ctx, genID, storedURL, storedURLs, storageType, s3Keys, fileSize); err != nil {
|
||||
if errors.Is(err, service.ErrSoraGenerationStateConflict) {
|
||||
h.cleanupStoredMedia(ctx, storageType, s3Keys, storedURLs)
|
||||
if usageAdded && h.quotaService != nil {
|
||||
_ = h.quotaService.ReleaseUsage(ctx, userID, fileSize)
|
||||
}
|
||||
return
|
||||
}
|
||||
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 标记完成失败 id=%d err=%v", genID, err)
|
||||
return
|
||||
}
|
||||
|
||||
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 生成完成 id=%d storage=%s size=%d", genID, storageType, fileSize)
|
||||
}
|
||||
|
||||
// storeMediaWithDegradation 实现三层降级存储链:S3 → 本地 → 上游。
|
||||
func (h *SoraClientHandler) storeMediaWithDegradation(
|
||||
ctx context.Context, userID int64, mediaType string,
|
||||
mediaURL string, mediaURLs []string,
|
||||
) (storedURL string, storedURLs []string, storageType string, s3Keys []string, fileSize int64) {
|
||||
urls := mediaURLs
|
||||
if len(urls) == 0 {
|
||||
urls = []string{mediaURL}
|
||||
}
|
||||
|
||||
// 第一层:尝试 S3
|
||||
if h.s3Storage != nil && h.s3Storage.Enabled(ctx) {
|
||||
keys := make([]string, 0, len(urls))
|
||||
var totalSize int64
|
||||
allOK := true
|
||||
for _, u := range urls {
|
||||
key, size, err := h.s3Storage.UploadFromURL(ctx, userID, u)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("handler.sora_client", "[SoraClient] S3 上传失败 err=%v", err)
|
||||
allOK = false
|
||||
// 清理已上传的文件
|
||||
if len(keys) > 0 {
|
||||
_ = h.s3Storage.DeleteObjects(ctx, keys)
|
||||
}
|
||||
break
|
||||
}
|
||||
keys = append(keys, key)
|
||||
totalSize += size
|
||||
}
|
||||
if allOK && len(keys) > 0 {
|
||||
accessURLs := make([]string, 0, len(keys))
|
||||
for _, key := range keys {
|
||||
accessURL, err := h.s3Storage.GetAccessURL(ctx, key)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 生成 S3 访问 URL 失败 err=%v", err)
|
||||
_ = h.s3Storage.DeleteObjects(ctx, keys)
|
||||
allOK = false
|
||||
break
|
||||
}
|
||||
accessURLs = append(accessURLs, accessURL)
|
||||
}
|
||||
if allOK && len(accessURLs) > 0 {
|
||||
return accessURLs[0], accessURLs, service.SoraStorageTypeS3, keys, totalSize
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 第二层:尝试本地存储
|
||||
if h.mediaStorage != nil && h.mediaStorage.Enabled() {
|
||||
storedPaths, err := h.mediaStorage.StoreFromURLs(ctx, mediaType, urls)
|
||||
if err == nil && len(storedPaths) > 0 {
|
||||
firstPath := storedPaths[0]
|
||||
totalSize, sizeErr := h.mediaStorage.TotalSizeByRelativePaths(storedPaths)
|
||||
if sizeErr != nil {
|
||||
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 统计本地文件大小失败 err=%v", sizeErr)
|
||||
}
|
||||
return firstPath, storedPaths, service.SoraStorageTypeLocal, nil, totalSize
|
||||
}
|
||||
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 本地存储失败 err=%v", err)
|
||||
}
|
||||
|
||||
// 第三层:保留上游临时 URL
|
||||
return urls[0], urls, service.SoraStorageTypeUpstream, nil, 0
|
||||
}
|
||||
|
||||
// buildAsyncRequestBody 构建 Sora 异步生成的 chat completions 请求体。
|
||||
func buildAsyncRequestBody(model, prompt, imageInput string, videoCount int) []byte {
|
||||
body := map[string]any{
|
||||
"model": model,
|
||||
"messages": []map[string]string{
|
||||
{"role": "user", "content": prompt},
|
||||
},
|
||||
"stream": false,
|
||||
}
|
||||
if imageInput != "" {
|
||||
body["image_input"] = imageInput
|
||||
}
|
||||
if videoCount > 1 {
|
||||
body["video_count"] = videoCount
|
||||
}
|
||||
b, _ := json.Marshal(body)
|
||||
return b
|
||||
}
|
||||
|
||||
func normalizeVideoCount(mediaType string, videoCount int) int {
|
||||
if mediaType != "video" {
|
||||
return 1
|
||||
}
|
||||
if videoCount <= 0 {
|
||||
return 1
|
||||
}
|
||||
if videoCount > 3 {
|
||||
return 3
|
||||
}
|
||||
return videoCount
|
||||
}
|
||||
|
||||
// extractMediaURLsFromResult 从 Forward 结果和响应体中提取媒体 URL。
|
||||
// OAuth 路径:ForwardResult.MediaURL 已填充。
|
||||
// APIKey 路径:需从响应体解析 media_url / media_urls 字段。
|
||||
func extractMediaURLsFromResult(result *service.ForwardResult, recorder *httptest.ResponseRecorder) (string, []string) {
|
||||
// 优先从 ForwardResult 获取(OAuth 路径)
|
||||
if result != nil && result.MediaURL != "" {
|
||||
// 尝试从响应体获取完整 URL 列表
|
||||
if urls := parseMediaURLsFromBody(recorder.Body.Bytes()); len(urls) > 0 {
|
||||
return urls[0], urls
|
||||
}
|
||||
return result.MediaURL, []string{result.MediaURL}
|
||||
}
|
||||
|
||||
// 从响应体解析(APIKey 路径)
|
||||
if urls := parseMediaURLsFromBody(recorder.Body.Bytes()); len(urls) > 0 {
|
||||
return urls[0], urls
|
||||
}
|
||||
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// parseMediaURLsFromBody 从 JSON 响应体中解析 media_url / media_urls 字段。
|
||||
func parseMediaURLsFromBody(body []byte) []string {
|
||||
if len(body) == 0 {
|
||||
return nil
|
||||
}
|
||||
var resp map[string]any
|
||||
if err := json.Unmarshal(body, &resp); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 优先 media_urls(多图数组)
|
||||
if rawURLs, ok := resp["media_urls"]; ok {
|
||||
if arr, ok := rawURLs.([]any); ok && len(arr) > 0 {
|
||||
urls := make([]string, 0, len(arr))
|
||||
for _, item := range arr {
|
||||
if s, ok := item.(string); ok && s != "" {
|
||||
urls = append(urls, s)
|
||||
}
|
||||
}
|
||||
if len(urls) > 0 {
|
||||
return urls
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 回退到 media_url(单个 URL)
|
||||
if url, ok := resp["media_url"].(string); ok && url != "" {
|
||||
return []string{url}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListGenerations 查询生成记录列表。
|
||||
// GET /api/v1/sora/generations
|
||||
func (h *SoraClientHandler) ListGenerations(c *gin.Context) {
|
||||
userID := getUserIDFromContext(c)
|
||||
if userID == 0 {
|
||||
response.Error(c, http.StatusUnauthorized, "未登录")
|
||||
return
|
||||
}
|
||||
|
||||
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
|
||||
|
||||
params := service.SoraGenerationListParams{
|
||||
UserID: userID,
|
||||
Status: c.Query("status"),
|
||||
StorageType: c.Query("storage_type"),
|
||||
MediaType: c.Query("media_type"),
|
||||
Page: page,
|
||||
PageSize: pageSize,
|
||||
}
|
||||
|
||||
gens, total, err := h.genService.List(c.Request.Context(), params)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 为 S3 记录动态生成预签名 URL
|
||||
for _, gen := range gens {
|
||||
_ = h.genService.ResolveMediaURLs(c.Request.Context(), gen)
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"data": gens,
|
||||
"total": total,
|
||||
"page": page,
|
||||
})
|
||||
}
|
||||
|
||||
// GetGeneration 查询生成记录详情。
|
||||
// GET /api/v1/sora/generations/:id
|
||||
func (h *SoraClientHandler) GetGeneration(c *gin.Context) {
|
||||
userID := getUserIDFromContext(c)
|
||||
if userID == 0 {
|
||||
response.Error(c, http.StatusUnauthorized, "未登录")
|
||||
return
|
||||
}
|
||||
|
||||
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.Error(c, http.StatusBadRequest, "无效的 ID")
|
||||
return
|
||||
}
|
||||
|
||||
gen, err := h.genService.GetByID(c.Request.Context(), id, userID)
|
||||
if err != nil {
|
||||
response.Error(c, http.StatusNotFound, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
_ = h.genService.ResolveMediaURLs(c.Request.Context(), gen)
|
||||
response.Success(c, gen)
|
||||
}
|
||||
|
||||
// DeleteGeneration 删除生成记录。
|
||||
// DELETE /api/v1/sora/generations/:id
|
||||
func (h *SoraClientHandler) DeleteGeneration(c *gin.Context) {
|
||||
userID := getUserIDFromContext(c)
|
||||
if userID == 0 {
|
||||
response.Error(c, http.StatusUnauthorized, "未登录")
|
||||
return
|
||||
}
|
||||
|
||||
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.Error(c, http.StatusBadRequest, "无效的 ID")
|
||||
return
|
||||
}
|
||||
|
||||
gen, err := h.genService.GetByID(c.Request.Context(), id, userID)
|
||||
if err != nil {
|
||||
response.Error(c, http.StatusNotFound, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 先尝试清理本地文件,再删除记录(清理失败不阻塞删除)。
|
||||
if gen.StorageType == service.SoraStorageTypeLocal && h.mediaStorage != nil {
|
||||
paths := gen.MediaURLs
|
||||
if len(paths) == 0 && gen.MediaURL != "" {
|
||||
paths = []string{gen.MediaURL}
|
||||
}
|
||||
if err := h.mediaStorage.DeleteByRelativePaths(paths); err != nil {
|
||||
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 删除本地文件失败 id=%d err=%v", id, err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := h.genService.Delete(c.Request.Context(), id, userID); err != nil {
|
||||
response.Error(c, http.StatusNotFound, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"message": "已删除"})
|
||||
}
|
||||
|
||||
// GetQuota 查询用户存储配额。
|
||||
// GET /api/v1/sora/quota
|
||||
func (h *SoraClientHandler) GetQuota(c *gin.Context) {
|
||||
userID := getUserIDFromContext(c)
|
||||
if userID == 0 {
|
||||
response.Error(c, http.StatusUnauthorized, "未登录")
|
||||
return
|
||||
}
|
||||
|
||||
if h.quotaService == nil {
|
||||
response.Success(c, service.QuotaInfo{QuotaSource: "unlimited", Source: "unlimited"})
|
||||
return
|
||||
}
|
||||
|
||||
quota, err := h.quotaService.GetQuota(c.Request.Context(), userID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, quota)
|
||||
}
|
||||
|
||||
// CancelGeneration 取消生成任务。
|
||||
// POST /api/v1/sora/generations/:id/cancel
|
||||
func (h *SoraClientHandler) CancelGeneration(c *gin.Context) {
|
||||
userID := getUserIDFromContext(c)
|
||||
if userID == 0 {
|
||||
response.Error(c, http.StatusUnauthorized, "未登录")
|
||||
return
|
||||
}
|
||||
|
||||
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.Error(c, http.StatusBadRequest, "无效的 ID")
|
||||
return
|
||||
}
|
||||
|
||||
// 权限校验
|
||||
gen, err := h.genService.GetByID(c.Request.Context(), id, userID)
|
||||
if err != nil {
|
||||
response.Error(c, http.StatusNotFound, err.Error())
|
||||
return
|
||||
}
|
||||
_ = gen
|
||||
|
||||
if err := h.genService.MarkCancelled(c.Request.Context(), id); err != nil {
|
||||
if errors.Is(err, service.ErrSoraGenerationNotActive) {
|
||||
response.Error(c, http.StatusConflict, "任务已结束,无法取消")
|
||||
return
|
||||
}
|
||||
response.Error(c, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"message": "已取消"})
|
||||
}
|
||||
|
||||
// SaveToStorage 手动保存 upstream 记录到 S3。
|
||||
// POST /api/v1/sora/generations/:id/save
|
||||
func (h *SoraClientHandler) SaveToStorage(c *gin.Context) {
|
||||
userID := getUserIDFromContext(c)
|
||||
if userID == 0 {
|
||||
response.Error(c, http.StatusUnauthorized, "未登录")
|
||||
return
|
||||
}
|
||||
|
||||
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.Error(c, http.StatusBadRequest, "无效的 ID")
|
||||
return
|
||||
}
|
||||
|
||||
gen, err := h.genService.GetByID(c.Request.Context(), id, userID)
|
||||
if err != nil {
|
||||
response.Error(c, http.StatusNotFound, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if gen.StorageType != service.SoraStorageTypeUpstream {
|
||||
response.Error(c, http.StatusBadRequest, "仅 upstream 类型的记录可手动保存")
|
||||
return
|
||||
}
|
||||
if gen.MediaURL == "" {
|
||||
response.Error(c, http.StatusBadRequest, "媒体 URL 为空,可能已过期")
|
||||
return
|
||||
}
|
||||
|
||||
if h.s3Storage == nil || !h.s3Storage.Enabled(c.Request.Context()) {
|
||||
response.Error(c, http.StatusServiceUnavailable, "云存储未配置,请联系管理员")
|
||||
return
|
||||
}
|
||||
|
||||
sourceURLs := gen.MediaURLs
|
||||
if len(sourceURLs) == 0 && gen.MediaURL != "" {
|
||||
sourceURLs = []string{gen.MediaURL}
|
||||
}
|
||||
if len(sourceURLs) == 0 {
|
||||
response.Error(c, http.StatusBadRequest, "媒体 URL 为空,可能已过期")
|
||||
return
|
||||
}
|
||||
|
||||
uploadedKeys := make([]string, 0, len(sourceURLs))
|
||||
accessURLs := make([]string, 0, len(sourceURLs))
|
||||
var totalSize int64
|
||||
|
||||
for _, sourceURL := range sourceURLs {
|
||||
objectKey, fileSize, uploadErr := h.s3Storage.UploadFromURL(c.Request.Context(), userID, sourceURL)
|
||||
if uploadErr != nil {
|
||||
if len(uploadedKeys) > 0 {
|
||||
_ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys)
|
||||
}
|
||||
var upstreamErr *service.UpstreamDownloadError
|
||||
if errors.As(uploadErr, &upstreamErr) && (upstreamErr.StatusCode == http.StatusForbidden || upstreamErr.StatusCode == http.StatusNotFound) {
|
||||
response.Error(c, http.StatusGone, "媒体链接已过期,无法保存")
|
||||
return
|
||||
}
|
||||
response.Error(c, http.StatusInternalServerError, "上传到 S3 失败: "+uploadErr.Error())
|
||||
return
|
||||
}
|
||||
accessURL, err := h.s3Storage.GetAccessURL(c.Request.Context(), objectKey)
|
||||
if err != nil {
|
||||
uploadedKeys = append(uploadedKeys, objectKey)
|
||||
_ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys)
|
||||
response.Error(c, http.StatusInternalServerError, "生成 S3 访问链接失败: "+err.Error())
|
||||
return
|
||||
}
|
||||
uploadedKeys = append(uploadedKeys, objectKey)
|
||||
accessURLs = append(accessURLs, accessURL)
|
||||
totalSize += fileSize
|
||||
}
|
||||
|
||||
usageAdded := false
|
||||
if totalSize > 0 && h.quotaService != nil {
|
||||
if err := h.quotaService.AddUsage(c.Request.Context(), userID, totalSize); err != nil {
|
||||
_ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys)
|
||||
var quotaErr *service.QuotaExceededError
|
||||
if errors.As(err, "aErr) {
|
||||
response.Error(c, http.StatusTooManyRequests, "存储配额已满,请删除不需要的作品释放空间")
|
||||
return
|
||||
}
|
||||
response.Error(c, http.StatusInternalServerError, "配额更新失败: "+err.Error())
|
||||
return
|
||||
}
|
||||
usageAdded = true
|
||||
}
|
||||
|
||||
if err := h.genService.UpdateStorageForCompleted(
|
||||
c.Request.Context(),
|
||||
id,
|
||||
accessURLs[0],
|
||||
accessURLs,
|
||||
service.SoraStorageTypeS3,
|
||||
uploadedKeys,
|
||||
totalSize,
|
||||
); err != nil {
|
||||
_ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys)
|
||||
if usageAdded && h.quotaService != nil {
|
||||
_ = h.quotaService.ReleaseUsage(c.Request.Context(), userID, totalSize)
|
||||
}
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"message": "已保存到 S3",
|
||||
"object_key": uploadedKeys[0],
|
||||
"object_keys": uploadedKeys,
|
||||
})
|
||||
}
|
||||
|
||||
// GetStorageStatus 返回存储状态。
|
||||
// GET /api/v1/sora/storage-status
|
||||
func (h *SoraClientHandler) GetStorageStatus(c *gin.Context) {
|
||||
s3Enabled := h.s3Storage != nil && h.s3Storage.Enabled(c.Request.Context())
|
||||
s3Healthy := false
|
||||
if s3Enabled {
|
||||
s3Healthy = h.s3Storage.IsHealthy(c.Request.Context())
|
||||
}
|
||||
localEnabled := h.mediaStorage != nil && h.mediaStorage.Enabled()
|
||||
response.Success(c, gin.H{
|
||||
"s3_enabled": s3Enabled,
|
||||
"s3_healthy": s3Healthy,
|
||||
"local_enabled": localEnabled,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *SoraClientHandler) cleanupStoredMedia(ctx context.Context, storageType string, s3Keys []string, localPaths []string) {
|
||||
switch storageType {
|
||||
case service.SoraStorageTypeS3:
|
||||
if h.s3Storage != nil && len(s3Keys) > 0 {
|
||||
if err := h.s3Storage.DeleteObjects(ctx, s3Keys); err != nil {
|
||||
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 清理 S3 文件失败 keys=%v err=%v", s3Keys, err)
|
||||
}
|
||||
}
|
||||
case service.SoraStorageTypeLocal:
|
||||
if h.mediaStorage != nil && len(localPaths) > 0 {
|
||||
if err := h.mediaStorage.DeleteByRelativePaths(localPaths); err != nil {
|
||||
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 清理本地文件失败 paths=%v err=%v", localPaths, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// getUserIDFromContext 从 gin 上下文中提取用户 ID。
|
||||
func getUserIDFromContext(c *gin.Context) int64 {
|
||||
if subject, ok := middleware2.GetAuthSubjectFromContext(c); ok && subject.UserID > 0 {
|
||||
return subject.UserID
|
||||
}
|
||||
|
||||
if id, ok := c.Get("user_id"); ok {
|
||||
switch v := id.(type) {
|
||||
case int64:
|
||||
return v
|
||||
case float64:
|
||||
return int64(v)
|
||||
case string:
|
||||
n, _ := strconv.ParseInt(v, 10, 64)
|
||||
return n
|
||||
}
|
||||
}
|
||||
// 尝试从 JWT claims 获取
|
||||
if id, ok := c.Get("userID"); ok {
|
||||
if v, ok := id.(int64); ok {
|
||||
return v
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func groupIDForLog(groupID *int64) int64 {
|
||||
if groupID == nil {
|
||||
return 0
|
||||
}
|
||||
return *groupID
|
||||
}
|
||||
|
||||
func trimForLog(raw string, maxLen int) string {
|
||||
trimmed := strings.TrimSpace(raw)
|
||||
if maxLen <= 0 || len(trimmed) <= maxLen {
|
||||
return trimmed
|
||||
}
|
||||
return trimmed[:maxLen] + "...(truncated)"
|
||||
}
|
||||
|
||||
// GetModels 获取可用 Sora 模型家族列表。
|
||||
// 优先从上游 Sora API 同步模型列表,失败时降级到本地配置。
|
||||
// GET /api/v1/sora/models
|
||||
func (h *SoraClientHandler) GetModels(c *gin.Context) {
|
||||
families := h.getModelFamilies(c.Request.Context())
|
||||
response.Success(c, families)
|
||||
}
|
||||
|
||||
// getModelFamilies 获取模型家族列表(带缓存)。
|
||||
func (h *SoraClientHandler) getModelFamilies(ctx context.Context) []service.SoraModelFamily {
|
||||
// 读锁检查缓存
|
||||
h.modelCacheMu.RLock()
|
||||
ttl := modelCacheTTL
|
||||
if !h.modelCacheUpstream {
|
||||
ttl = modelCacheFailedTTL
|
||||
}
|
||||
if h.cachedFamilies != nil && time.Since(h.modelCacheTime) < ttl {
|
||||
families := h.cachedFamilies
|
||||
h.modelCacheMu.RUnlock()
|
||||
return families
|
||||
}
|
||||
h.modelCacheMu.RUnlock()
|
||||
|
||||
// 写锁更新缓存
|
||||
h.modelCacheMu.Lock()
|
||||
defer h.modelCacheMu.Unlock()
|
||||
|
||||
// double-check
|
||||
ttl = modelCacheTTL
|
||||
if !h.modelCacheUpstream {
|
||||
ttl = modelCacheFailedTTL
|
||||
}
|
||||
if h.cachedFamilies != nil && time.Since(h.modelCacheTime) < ttl {
|
||||
return h.cachedFamilies
|
||||
}
|
||||
|
||||
// 尝试从上游获取
|
||||
families, err := h.fetchUpstreamModels(ctx)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 上游模型获取失败,使用本地配置: %v", err)
|
||||
families = service.BuildSoraModelFamilies()
|
||||
h.cachedFamilies = families
|
||||
h.modelCacheTime = time.Now()
|
||||
h.modelCacheUpstream = false
|
||||
return families
|
||||
}
|
||||
|
||||
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 从上游同步到 %d 个模型家族", len(families))
|
||||
h.cachedFamilies = families
|
||||
h.modelCacheTime = time.Now()
|
||||
h.modelCacheUpstream = true
|
||||
return families
|
||||
}
|
||||
|
||||
// fetchUpstreamModels 从上游 Sora API 获取模型列表。
|
||||
func (h *SoraClientHandler) fetchUpstreamModels(ctx context.Context) ([]service.SoraModelFamily, error) {
|
||||
if h.gatewayService == nil {
|
||||
return nil, fmt.Errorf("gatewayService 未初始化")
|
||||
}
|
||||
|
||||
// 设置 ForcePlatform 用于 Sora 账号选择
|
||||
ctx = context.WithValue(ctx, ctxkey.ForcePlatform, service.PlatformSora)
|
||||
|
||||
// 选择一个 Sora 账号
|
||||
account, err := h.gatewayService.SelectAccountForModel(ctx, nil, "", "sora2-landscape-10s")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("选择 Sora 账号失败: %w", err)
|
||||
}
|
||||
|
||||
// 仅支持 API Key 类型账号
|
||||
if account.Type != service.AccountTypeAPIKey {
|
||||
return nil, fmt.Errorf("当前账号类型 %s 不支持模型同步", account.Type)
|
||||
}
|
||||
|
||||
apiKey := account.GetCredential("api_key")
|
||||
if apiKey == "" {
|
||||
return nil, fmt.Errorf("账号缺少 api_key")
|
||||
}
|
||||
|
||||
baseURL := account.GetBaseURL()
|
||||
if baseURL == "" {
|
||||
return nil, fmt.Errorf("账号缺少 base_url")
|
||||
}
|
||||
|
||||
// 构建上游模型列表请求
|
||||
modelsURL := strings.TrimRight(baseURL, "/") + "/sora/v1/models"
|
||||
|
||||
reqCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(reqCtx, http.MethodGet, modelsURL, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建请求失败: %w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
|
||||
client := &http.Client{Timeout: 10 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("请求上游失败: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = resp.Body.Close()
|
||||
}()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("上游返回状态码 %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, 1*1024*1024))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取响应失败: %w", err)
|
||||
}
|
||||
|
||||
// 解析 OpenAI 格式的模型列表
|
||||
var modelsResp struct {
|
||||
Data []struct {
|
||||
ID string `json:"id"`
|
||||
} `json:"data"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &modelsResp); err != nil {
|
||||
return nil, fmt.Errorf("解析响应失败: %w", err)
|
||||
}
|
||||
|
||||
if len(modelsResp.Data) == 0 {
|
||||
return nil, fmt.Errorf("上游返回空模型列表")
|
||||
}
|
||||
|
||||
// 提取模型 ID
|
||||
modelIDs := make([]string, 0, len(modelsResp.Data))
|
||||
for _, m := range modelsResp.Data {
|
||||
modelIDs = append(modelIDs, m.ID)
|
||||
}
|
||||
|
||||
// 转换为模型家族
|
||||
families := service.BuildSoraModelFamiliesFromIDs(modelIDs)
|
||||
if len(families) == 0 {
|
||||
return nil, fmt.Errorf("未能从上游模型列表中识别出有效的模型家族")
|
||||
}
|
||||
|
||||
return families, nil
|
||||
}
|
||||
3135
backend/internal/handler/sora_client_handler_test.go
Normal file
3135
backend/internal/handler/sora_client_handler_test.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -7,7 +7,6 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path"
|
||||
@@ -17,6 +16,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
@@ -107,7 +107,7 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
zap.Any("group_id", apiKey.GroupID),
|
||||
)
|
||||
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
|
||||
if err != nil {
|
||||
if maxErr, ok := extractMaxBytesError(err); ok {
|
||||
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
|
||||
@@ -461,6 +461,14 @@ func (h *SoraGatewayHandler) submitUsageRecordTask(task service.UsageRecordTask)
|
||||
// 回退路径:worker 池未注入时同步执行,避免退回到无界 goroutine 模式。
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
defer func() {
|
||||
if recovered := recover(); recovered != nil {
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.sora_gateway.chat_completions"),
|
||||
zap.Any("panic", recovered),
|
||||
).Error("sora.usage_record_task_panic_recovered")
|
||||
}
|
||||
}()
|
||||
task(ctx)
|
||||
}
|
||||
|
||||
|
||||
@@ -314,10 +314,10 @@ func (s *stubUsageLogRepo) GetAccountTodayStats(ctx context.Context, accountID i
|
||||
func (s *stubUsageLogRepo) GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) {
|
||||
func (s *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) {
|
||||
func (s *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) {
|
||||
|
||||
@@ -2,6 +2,7 @@ package handler
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
@@ -65,8 +66,17 @@ func (h *UsageHandler) List(c *gin.Context) {
|
||||
// Parse additional filters
|
||||
model := c.Query("model")
|
||||
|
||||
var requestType *int16
|
||||
var stream *bool
|
||||
if streamStr := c.Query("stream"); streamStr != "" {
|
||||
if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" {
|
||||
parsed, err := service.ParseUsageRequestType(requestTypeStr)
|
||||
if err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
value := int16(parsed)
|
||||
requestType = &value
|
||||
} else if streamStr := c.Query("stream"); streamStr != "" {
|
||||
val, err := strconv.ParseBool(streamStr)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid stream value, use true or false")
|
||||
@@ -114,6 +124,7 @@ func (h *UsageHandler) List(c *gin.Context) {
|
||||
UserID: subject.UserID, // Always filter by current user for security
|
||||
APIKeyID: apiKeyID,
|
||||
Model: model,
|
||||
RequestType: requestType,
|
||||
Stream: stream,
|
||||
BillingType: billingType,
|
||||
StartTime: startTime,
|
||||
|
||||
80
backend/internal/handler/usage_handler_request_type_test.go
Normal file
80
backend/internal/handler/usage_handler_request_type_test.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type userUsageRepoCapture struct {
|
||||
service.UsageLogRepository
|
||||
listFilters usagestats.UsageLogFilters
|
||||
}
|
||||
|
||||
func (s *userUsageRepoCapture) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters usagestats.UsageLogFilters) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||
s.listFilters = filters
|
||||
return []service.UsageLog{}, &pagination.PaginationResult{
|
||||
Total: 0,
|
||||
Page: params.Page,
|
||||
PageSize: params.PageSize,
|
||||
Pages: 0,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func newUserUsageRequestTypeTestRouter(repo *userUsageRepoCapture) *gin.Engine {
|
||||
gin.SetMode(gin.TestMode)
|
||||
usageSvc := service.NewUsageService(repo, nil, nil, nil)
|
||||
handler := NewUsageHandler(usageSvc, nil)
|
||||
router := gin.New()
|
||||
router.Use(func(c *gin.Context) {
|
||||
c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 42})
|
||||
c.Next()
|
||||
})
|
||||
router.GET("/usage", handler.List)
|
||||
return router
|
||||
}
|
||||
|
||||
func TestUserUsageListRequestTypePriority(t *testing.T) {
|
||||
repo := &userUsageRepoCapture{}
|
||||
router := newUserUsageRequestTypeTestRouter(repo)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/usage?request_type=ws_v2&stream=bad", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
require.Equal(t, int64(42), repo.listFilters.UserID)
|
||||
require.NotNil(t, repo.listFilters.RequestType)
|
||||
require.Equal(t, int16(service.RequestTypeWSV2), *repo.listFilters.RequestType)
|
||||
require.Nil(t, repo.listFilters.Stream)
|
||||
}
|
||||
|
||||
func TestUserUsageListInvalidRequestType(t *testing.T) {
|
||||
repo := &userUsageRepoCapture{}
|
||||
router := newUserUsageRequestTypeTestRouter(repo)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/usage?request_type=invalid", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
}
|
||||
|
||||
func TestUserUsageListInvalidStream(t *testing.T) {
|
||||
repo := &userUsageRepoCapture{}
|
||||
router := newUserUsageRequestTypeTestRouter(repo)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/usage?stream=invalid", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
}
|
||||
@@ -61,6 +61,22 @@ func TestGatewayHandlerSubmitUsageRecordTask_NilTask(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestGatewayHandlerSubmitUsageRecordTask_WithoutPool_TaskPanicRecovered(t *testing.T) {
|
||||
h := &GatewayHandler{}
|
||||
var called atomic.Bool
|
||||
|
||||
require.NotPanics(t, func() {
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
panic("usage task panic")
|
||||
})
|
||||
})
|
||||
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
called.Store(true)
|
||||
})
|
||||
require.True(t, called.Load(), "panic 后后续任务应仍可执行")
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithPool(t *testing.T) {
|
||||
pool := newUsageRecordTestPool(t)
|
||||
h := &OpenAIGatewayHandler{usageRecordWorkerPool: pool}
|
||||
@@ -98,6 +114,22 @@ func TestOpenAIGatewayHandlerSubmitUsageRecordTask_NilTask(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithoutPool_TaskPanicRecovered(t *testing.T) {
|
||||
h := &OpenAIGatewayHandler{}
|
||||
var called atomic.Bool
|
||||
|
||||
require.NotPanics(t, func() {
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
panic("usage task panic")
|
||||
})
|
||||
})
|
||||
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
called.Store(true)
|
||||
})
|
||||
require.True(t, called.Load(), "panic 后后续任务应仍可执行")
|
||||
}
|
||||
|
||||
func TestSoraGatewayHandlerSubmitUsageRecordTask_WithPool(t *testing.T) {
|
||||
pool := newUsageRecordTestPool(t)
|
||||
h := &SoraGatewayHandler{usageRecordWorkerPool: pool}
|
||||
@@ -134,3 +166,19 @@ func TestSoraGatewayHandlerSubmitUsageRecordTask_NilTask(t *testing.T) {
|
||||
h.submitUsageRecordTask(nil)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSoraGatewayHandlerSubmitUsageRecordTask_WithoutPool_TaskPanicRecovered(t *testing.T) {
|
||||
h := &SoraGatewayHandler{}
|
||||
var called atomic.Bool
|
||||
|
||||
require.NotPanics(t, func() {
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
panic("usage task panic")
|
||||
})
|
||||
})
|
||||
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
called.Store(true)
|
||||
})
|
||||
require.True(t, called.Load(), "panic 后后续任务应仍可执行")
|
||||
}
|
||||
|
||||
@@ -14,6 +14,7 @@ func ProvideAdminHandlers(
|
||||
groupHandler *admin.GroupHandler,
|
||||
accountHandler *admin.AccountHandler,
|
||||
announcementHandler *admin.AnnouncementHandler,
|
||||
dataManagementHandler *admin.DataManagementHandler,
|
||||
oauthHandler *admin.OAuthHandler,
|
||||
openaiOAuthHandler *admin.OpenAIOAuthHandler,
|
||||
geminiOAuthHandler *admin.GeminiOAuthHandler,
|
||||
@@ -35,6 +36,7 @@ func ProvideAdminHandlers(
|
||||
Group: groupHandler,
|
||||
Account: accountHandler,
|
||||
Announcement: announcementHandler,
|
||||
DataManagement: dataManagementHandler,
|
||||
OAuth: oauthHandler,
|
||||
OpenAIOAuth: openaiOAuthHandler,
|
||||
GeminiOAuth: geminiOAuthHandler,
|
||||
@@ -75,6 +77,7 @@ func ProvideHandlers(
|
||||
gatewayHandler *GatewayHandler,
|
||||
openaiGatewayHandler *OpenAIGatewayHandler,
|
||||
soraGatewayHandler *SoraGatewayHandler,
|
||||
soraClientHandler *SoraClientHandler,
|
||||
settingHandler *SettingHandler,
|
||||
totpHandler *TotpHandler,
|
||||
_ *service.IdempotencyCoordinator,
|
||||
@@ -92,6 +95,7 @@ func ProvideHandlers(
|
||||
Gateway: gatewayHandler,
|
||||
OpenAIGateway: openaiGatewayHandler,
|
||||
SoraGateway: soraGatewayHandler,
|
||||
SoraClient: soraClientHandler,
|
||||
Setting: settingHandler,
|
||||
Totp: totpHandler,
|
||||
}
|
||||
@@ -119,6 +123,7 @@ var ProviderSet = wire.NewSet(
|
||||
admin.NewGroupHandler,
|
||||
admin.NewAccountHandler,
|
||||
admin.NewAnnouncementHandler,
|
||||
admin.NewDataManagementHandler,
|
||||
admin.NewOAuthHandler,
|
||||
admin.NewOpenAIOAuthHandler,
|
||||
admin.NewGeminiOAuthHandler,
|
||||
|
||||
Reference in New Issue
Block a user