Merge branch 'dev-release'

This commit is contained in:
yangjianbo
2026-02-07 19:58:00 +08:00
62 changed files with 2674 additions and 564 deletions

View File

@@ -3,6 +3,7 @@ package admin
import (
"errors"
"fmt"
"strconv"
"strings"
"sync"
@@ -789,57 +790,40 @@ func (h *AccountHandler) BatchUpdateCredentials(c *gin.Context) {
}
ctx := c.Request.Context()
success := 0
failed := 0
results := []gin.H{}
// 阶段一:预验证所有账号存在,收集 credentials
type accountUpdate struct {
ID int64
Credentials map[string]any
}
updates := make([]accountUpdate, 0, len(req.AccountIDs))
for _, accountID := range req.AccountIDs {
// Get account
account, err := h.adminService.GetAccount(ctx, accountID)
if err != nil {
failed++
results = append(results, gin.H{
"account_id": accountID,
"success": false,
"error": "Account not found",
})
continue
response.Error(c, 404, fmt.Sprintf("Account %d not found", accountID))
return
}
// Update credentials field
if account.Credentials == nil {
account.Credentials = make(map[string]any)
}
account.Credentials[req.Field] = req.Value
updates = append(updates, accountUpdate{ID: accountID, Credentials: account.Credentials})
}
// Update account
// 阶段二:依次更新,任何失败立即返回(避免部分成功部分失败)
for _, u := range updates {
updateInput := &service.UpdateAccountInput{
Credentials: account.Credentials,
Credentials: u.Credentials,
}
_, err = h.adminService.UpdateAccount(ctx, accountID, updateInput)
if err != nil {
failed++
results = append(results, gin.H{
"account_id": accountID,
"success": false,
"error": err.Error(),
})
continue
if _, err := h.adminService.UpdateAccount(ctx, u.ID, updateInput); err != nil {
response.Error(c, 500, fmt.Sprintf("Failed to update account %d: %v", u.ID, err))
return
}
success++
results = append(results, gin.H{
"account_id": accountID,
"success": true,
})
}
response.Success(c, gin.H{
"success": success,
"failed": failed,
"results": results,
"success": len(updates),
"failed": 0,
})
}

View File

@@ -0,0 +1,200 @@
//go:build unit
package admin
import (
"bytes"
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
"github.com/Wei-Shaw/sub2api/internal/service"
)
// failingAdminService 嵌入 stubAdminService可配置 UpdateAccount 在指定 ID 时失败。
type failingAdminService struct {
*stubAdminService
failOnAccountID int64
updateCallCount atomic.Int64
}
func (f *failingAdminService) UpdateAccount(ctx context.Context, id int64, input *service.UpdateAccountInput) (*service.Account, error) {
f.updateCallCount.Add(1)
if id == f.failOnAccountID {
return nil, errors.New("database error")
}
return f.stubAdminService.UpdateAccount(ctx, id, input)
}
func setupAccountHandlerWithService(adminSvc service.AdminService) (*gin.Engine, *AccountHandler) {
gin.SetMode(gin.TestMode)
router := gin.New()
handler := NewAccountHandler(adminSvc, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
router.POST("/api/v1/admin/accounts/batch-update-credentials", handler.BatchUpdateCredentials)
return router, handler
}
func TestBatchUpdateCredentials_AllSuccess(t *testing.T) {
svc := &failingAdminService{stubAdminService: newStubAdminService()}
router, _ := setupAccountHandlerWithService(svc)
body, _ := json.Marshal(BatchUpdateCredentialsRequest{
AccountIDs: []int64{1, 2, 3},
Field: "account_uuid",
Value: "test-uuid",
})
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code, "全部成功时应返回 200")
require.Equal(t, int64(3), svc.updateCallCount.Load(), "应调用 3 次 UpdateAccount")
}
func TestBatchUpdateCredentials_FailFast(t *testing.T) {
// 让第 2 个账号ID=2更新时失败
svc := &failingAdminService{
stubAdminService: newStubAdminService(),
failOnAccountID: 2,
}
router, _ := setupAccountHandlerWithService(svc)
body, _ := json.Marshal(BatchUpdateCredentialsRequest{
AccountIDs: []int64{1, 2, 3},
Field: "org_uuid",
Value: "test-org",
})
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(w, req)
require.Equal(t, http.StatusInternalServerError, w.Code, "ID=2 失败时应返回 500")
// 验证 fail-fastID=1 更新成功ID=2 失败ID=3 不应被调用
require.Equal(t, int64(2), svc.updateCallCount.Load(),
"fail-fast: 应只调用 2 次 UpdateAccountID=1 成功、ID=2 失败后停止)")
}
func TestBatchUpdateCredentials_FirstAccountNotFound(t *testing.T) {
// GetAccount 在 stubAdminService 中总是成功的,需要创建一个 GetAccount 会失败的 stub
svc := &getAccountFailingService{
stubAdminService: newStubAdminService(),
failOnAccountID: 1,
}
router, _ := setupAccountHandlerWithService(svc)
body, _ := json.Marshal(BatchUpdateCredentialsRequest{
AccountIDs: []int64{1, 2, 3},
Field: "account_uuid",
Value: "test",
})
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(w, req)
require.Equal(t, http.StatusNotFound, w.Code, "第一阶段验证失败应返回 404")
}
// getAccountFailingService 模拟 GetAccount 在特定 ID 时返回 not found。
type getAccountFailingService struct {
*stubAdminService
failOnAccountID int64
}
func (f *getAccountFailingService) GetAccount(ctx context.Context, id int64) (*service.Account, error) {
if id == f.failOnAccountID {
return nil, errors.New("not found")
}
return f.stubAdminService.GetAccount(ctx, id)
}
func TestBatchUpdateCredentials_InterceptWarmupRequests_NonBool(t *testing.T) {
svc := &failingAdminService{stubAdminService: newStubAdminService()}
router, _ := setupAccountHandlerWithService(svc)
// intercept_warmup_requests 传入非 bool 类型string应返回 400
body, _ := json.Marshal(map[string]any{
"account_ids": []int64{1},
"field": "intercept_warmup_requests",
"value": "not-a-bool",
})
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(w, req)
require.Equal(t, http.StatusBadRequest, w.Code,
"intercept_warmup_requests 传入非 bool 值应返回 400")
}
func TestBatchUpdateCredentials_InterceptWarmupRequests_ValidBool(t *testing.T) {
svc := &failingAdminService{stubAdminService: newStubAdminService()}
router, _ := setupAccountHandlerWithService(svc)
body, _ := json.Marshal(map[string]any{
"account_ids": []int64{1},
"field": "intercept_warmup_requests",
"value": true,
})
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code,
"intercept_warmup_requests 传入合法 bool 值应返回 200")
}
func TestBatchUpdateCredentials_AccountUUID_NonString(t *testing.T) {
svc := &failingAdminService{stubAdminService: newStubAdminService()}
router, _ := setupAccountHandlerWithService(svc)
// account_uuid 传入非 string 类型number应返回 400
body, _ := json.Marshal(map[string]any{
"account_ids": []int64{1},
"field": "account_uuid",
"value": 12345,
})
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(w, req)
require.Equal(t, http.StatusBadRequest, w.Code,
"account_uuid 传入非 string 值应返回 400")
}
func TestBatchUpdateCredentials_AccountUUID_NullValue(t *testing.T) {
svc := &failingAdminService{stubAdminService: newStubAdminService()}
router, _ := setupAccountHandlerWithService(svc)
// account_uuid 传入 null设置为空应正常通过
body, _ := json.Marshal(map[string]any{
"account_ids": []int64{1},
"field": "account_uuid",
"value": nil,
})
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code,
"account_uuid 传入 null 应返回 200")
}

View File

@@ -379,7 +379,7 @@ func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) {
return
}
stats, err := h.dashboardService.GetBatchUserUsageStats(c.Request.Context(), req.UserIDs)
stats, err := h.dashboardService.GetBatchUserUsageStats(c.Request.Context(), req.UserIDs, time.Time{}, time.Time{})
if err != nil {
response.Error(c, 500, "Failed to get user usage stats")
return
@@ -407,7 +407,7 @@ func (h *DashboardHandler) GetBatchAPIKeysUsage(c *gin.Context) {
return
}
stats, err := h.dashboardService.GetBatchAPIKeyUsageStats(c.Request.Context(), req.APIKeyIDs)
stats, err := h.dashboardService.GetBatchAPIKeyUsageStats(c.Request.Context(), req.APIKeyIDs, time.Time{}, time.Time{})
if err != nil {
response.Error(c, 500, "Failed to get API key usage stats")
return

View File

@@ -0,0 +1,97 @@
//go:build unit
package admin
import (
"testing"
"github.com/stretchr/testify/require"
)
// truncateSearchByRune 模拟 user_handler.go 中的 search 截断逻辑
func truncateSearchByRune(search string, maxRunes int) string {
if runes := []rune(search); len(runes) > maxRunes {
return string(runes[:maxRunes])
}
return search
}
func TestTruncateSearchByRune(t *testing.T) {
tests := []struct {
name string
input string
maxRunes int
wantLen int // 期望的 rune 长度
}{
{
name: "纯中文超长",
input: string(make([]rune, 150)),
maxRunes: 100,
wantLen: 100,
},
{
name: "纯 ASCII 超长",
input: string(make([]byte, 150)),
maxRunes: 100,
wantLen: 100,
},
{
name: "空字符串",
input: "",
maxRunes: 100,
wantLen: 0,
},
{
name: "恰好 100 个字符",
input: string(make([]rune, 100)),
maxRunes: 100,
wantLen: 100,
},
{
name: "不足 100 字符不截断",
input: "hello世界",
maxRunes: 100,
wantLen: 7,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
result := truncateSearchByRune(tc.input, tc.maxRunes)
require.Equal(t, tc.wantLen, len([]rune(result)))
})
}
}
func TestTruncateSearchByRune_PreservesMultibyte(t *testing.T) {
// 101 个中文字符,截断到 100 个后应该仍然是有效 UTF-8
input := ""
for i := 0; i < 101; i++ {
input += "中"
}
result := truncateSearchByRune(input, 100)
require.Equal(t, 100, len([]rune(result)))
// 验证截断结果是有效的 UTF-8每个中文字符 3 字节)
require.Equal(t, 300, len(result))
}
func TestTruncateSearchByRune_MixedASCIIAndMultibyte(t *testing.T) {
// 50 个 ASCII + 51 个中文 = 101 个 rune
input := ""
for i := 0; i < 50; i++ {
input += "a"
}
for i := 0; i < 51; i++ {
input += "中"
}
result := truncateSearchByRune(input, 100)
runes := []rune(result)
require.Equal(t, 100, len(runes))
// 前 50 个应该是 'a',后 50 个应该是 '中'
require.Equal(t, 'a', runes[0])
require.Equal(t, 'a', runes[49])
require.Equal(t, '中', runes[50])
require.Equal(t, '中', runes[99])
}

View File

@@ -70,8 +70,8 @@ func (h *UserHandler) List(c *gin.Context) {
search := c.Query("search")
// 标准化和验证 search 参数
search = strings.TrimSpace(search)
if len(search) > 100 {
search = search[:100]
if runes := []rune(search); len(runes) > 100 {
search = string(runes[:100])
}
filters := service.UserListFilters{

View File

@@ -2,6 +2,7 @@ package handler
import (
"log/slog"
"strings"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
@@ -448,17 +449,12 @@ func (h *AuthHandler) ForgotPassword(c *gin.Context) {
return
}
// Build frontend base URL from request
scheme := "https"
if c.Request.TLS == nil {
// Check X-Forwarded-Proto header (common in reverse proxy setups)
if proto := c.GetHeader("X-Forwarded-Proto"); proto != "" {
scheme = proto
} else {
scheme = "http"
}
frontendBaseURL := strings.TrimSpace(h.cfg.Server.FrontendURL)
if frontendBaseURL == "" {
slog.Error("server.frontend_url not configured; cannot build password reset link")
response.InternalError(c, "Password reset is not configured")
return
}
frontendBaseURL := scheme + "://" + c.Request.Host
// Request password reset (async)
// Note: This returns success even if email doesn't exist (to prevent enumeration)

View File

@@ -236,7 +236,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs, "") // Gemini 不使用会话限制
if err != nil {
if len(failedAccountIDs) == 0 {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
log.Printf("[Gateway] SelectAccount failed: %v", err)
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
return
}
if lastFailoverErr != nil {
@@ -284,12 +285,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if err == nil && canWait {
accountWaitCounted = true
}
// Ensure the wait counter is decremented if we exit before acquiring the slot.
defer func() {
releaseWait := func() {
if accountWaitCounted {
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
accountWaitCounted = false
}
}()
}
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
c,
@@ -301,14 +302,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
)
if err != nil {
log.Printf("Account concurrency acquire failed: %v", err)
releaseWait()
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}
// Slot acquired: no longer waiting in queue.
if accountWaitCounted {
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
accountWaitCounted = false
}
releaseWait()
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionKey, account.ID); err != nil {
log.Printf("Bind sticky session failed: %v", err)
}
@@ -398,7 +397,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, failedAccountIDs, parsedReq.MetadataUserID)
if err != nil {
if len(failedAccountIDs) == 0 {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
log.Printf("[Gateway] SelectAccount failed: %v", err)
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
return
}
if lastFailoverErr != nil {
@@ -446,11 +446,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if err == nil && canWait {
accountWaitCounted = true
}
defer func() {
releaseWait := func() {
if accountWaitCounted {
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
accountWaitCounted = false
}
}()
}
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
c,
@@ -462,13 +463,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
)
if err != nil {
log.Printf("Account concurrency acquire failed: %v", err)
releaseWait()
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}
if accountWaitCounted {
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
accountWaitCounted = false
}
// Slot acquired: no longer waiting in queue.
releaseWait()
if err := h.gatewayService.BindStickySession(c.Request.Context(), currentAPIKey.GroupID, sessionKey, account.ID); err != nil {
log.Printf("Bind sticky session failed: %v", err)
}
@@ -967,7 +967,8 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
// 选择支持该模型的账号
account, err := h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, parsedReq.Model)
if err != nil {
h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error())
log.Printf("[Gateway] SelectAccountForModel failed: %v", err)
h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable")
return
}
setOpsSelectedAccount(c, account.ID)
@@ -1238,7 +1239,8 @@ func billingErrorDetails(err error) (status int, code, message string) {
}
msg := pkgerrors.Message(err)
if msg == "" {
msg = err.Error()
log.Printf("[Gateway] billing error details: %v", err)
msg = "Billing error"
}
return http.StatusForbidden, "billing_error", msg
}

View File

@@ -28,6 +28,7 @@ type OpenAIGatewayHandler struct {
errorPassthroughService *service.ErrorPassthroughService
concurrencyHelper *ConcurrencyHelper
maxAccountSwitches int
cfg *config.Config
}
// NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler
@@ -54,6 +55,7 @@ func NewOpenAIGatewayHandler(
errorPassthroughService: errorPassthroughService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
maxAccountSwitches: maxAccountSwitches,
cfg: cfg,
}
}
@@ -109,7 +111,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
}
userAgent := c.GetHeader("User-Agent")
if !openai.IsCodexCLIRequest(userAgent) {
isCodexCLI := openai.IsCodexCLIRequest(userAgent) || (h.cfg != nil && h.cfg.Gateway.ForceCodexCLI)
if !isCodexCLI {
existingInstructions, _ := reqBody["instructions"].(string)
if strings.TrimSpace(existingInstructions) == "" {
if instructions := strings.TrimSpace(service.GetOpenCodeInstructions()); instructions != "" {
@@ -218,7 +221,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
if err != nil {
log.Printf("[OpenAI Handler] SelectAccount failed: %v", err)
if len(failedAccountIDs) == 0 {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
log.Printf("[OpenAI Gateway] SelectAccount failed: %v", err)
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
return
}
if lastFailoverErr != nil {
@@ -251,11 +255,12 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
if err == nil && canWait {
accountWaitCounted = true
}
defer func() {
releaseWait := func() {
if accountWaitCounted {
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
accountWaitCounted = false
}
}()
}
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
c,
@@ -267,13 +272,12 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
)
if err != nil {
log.Printf("Account concurrency acquire failed: %v", err)
releaseWait()
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}
if accountWaitCounted {
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
accountWaitCounted = false
}
// Slot acquired: no longer waiting in queue.
releaseWait()
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionHash, account.ID); err != nil {
log.Printf("Bind sticky session failed: %v", err)
}

View File

@@ -392,7 +392,7 @@ func (h *UsageHandler) DashboardAPIKeysUsage(c *gin.Context) {
return
}
stats, err := h.usageService.GetBatchAPIKeyUsageStats(c.Request.Context(), validAPIKeyIDs)
stats, err := h.usageService.GetBatchAPIKeyUsageStats(c.Request.Context(), validAPIKeyIDs, time.Time{}, time.Time{})
if err != nil {
response.ErrorFrom(c, err)
return