merge: 合并 main 分支到 test,解决 config 和 modelWhitelist 冲突

- config.go: 保留 Sora 配置,合入 SubscriptionCache 配置
- useModelWhitelist.ts: 同时保留 soraModels 和 antigravityModels

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
yangjianbo
2026-02-07 20:18:07 +08:00
156 changed files with 14550 additions and 2206 deletions

View File

@@ -0,0 +1,544 @@
package admin
import (
"context"
"errors"
"fmt"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
const (
dataType = "sub2api-data"
legacyDataType = "sub2api-bundle"
dataVersion = 1
dataPageCap = 1000
)
type DataPayload struct {
Type string `json:"type,omitempty"`
Version int `json:"version,omitempty"`
ExportedAt string `json:"exported_at"`
Proxies []DataProxy `json:"proxies"`
Accounts []DataAccount `json:"accounts"`
}
type DataProxy struct {
ProxyKey string `json:"proxy_key"`
Name string `json:"name"`
Protocol string `json:"protocol"`
Host string `json:"host"`
Port int `json:"port"`
Username string `json:"username,omitempty"`
Password string `json:"password,omitempty"`
Status string `json:"status"`
}
type DataAccount struct {
Name string `json:"name"`
Notes *string `json:"notes,omitempty"`
Platform string `json:"platform"`
Type string `json:"type"`
Credentials map[string]any `json:"credentials"`
Extra map[string]any `json:"extra,omitempty"`
ProxyKey *string `json:"proxy_key,omitempty"`
Concurrency int `json:"concurrency"`
Priority int `json:"priority"`
RateMultiplier *float64 `json:"rate_multiplier,omitempty"`
ExpiresAt *int64 `json:"expires_at,omitempty"`
AutoPauseOnExpired *bool `json:"auto_pause_on_expired,omitempty"`
}
type DataImportRequest struct {
Data DataPayload `json:"data"`
SkipDefaultGroupBind *bool `json:"skip_default_group_bind"`
}
type DataImportResult struct {
ProxyCreated int `json:"proxy_created"`
ProxyReused int `json:"proxy_reused"`
ProxyFailed int `json:"proxy_failed"`
AccountCreated int `json:"account_created"`
AccountFailed int `json:"account_failed"`
Errors []DataImportError `json:"errors,omitempty"`
}
type DataImportError struct {
Kind string `json:"kind"`
Name string `json:"name,omitempty"`
ProxyKey string `json:"proxy_key,omitempty"`
Message string `json:"message"`
}
func buildProxyKey(protocol, host string, port int, username, password string) string {
return fmt.Sprintf("%s|%s|%d|%s|%s", strings.TrimSpace(protocol), strings.TrimSpace(host), port, strings.TrimSpace(username), strings.TrimSpace(password))
}
func (h *AccountHandler) ExportData(c *gin.Context) {
ctx := c.Request.Context()
selectedIDs, err := parseAccountIDs(c)
if err != nil {
response.BadRequest(c, err.Error())
return
}
accounts, err := h.resolveExportAccounts(ctx, selectedIDs, c)
if err != nil {
response.ErrorFrom(c, err)
return
}
includeProxies, err := parseIncludeProxies(c)
if err != nil {
response.BadRequest(c, err.Error())
return
}
var proxies []service.Proxy
if includeProxies {
proxies, err = h.resolveExportProxies(ctx, accounts)
if err != nil {
response.ErrorFrom(c, err)
return
}
} else {
proxies = []service.Proxy{}
}
proxyKeyByID := make(map[int64]string, len(proxies))
dataProxies := make([]DataProxy, 0, len(proxies))
for i := range proxies {
p := proxies[i]
key := buildProxyKey(p.Protocol, p.Host, p.Port, p.Username, p.Password)
proxyKeyByID[p.ID] = key
dataProxies = append(dataProxies, DataProxy{
ProxyKey: key,
Name: p.Name,
Protocol: p.Protocol,
Host: p.Host,
Port: p.Port,
Username: p.Username,
Password: p.Password,
Status: p.Status,
})
}
dataAccounts := make([]DataAccount, 0, len(accounts))
for i := range accounts {
acc := accounts[i]
var proxyKey *string
if acc.ProxyID != nil {
if key, ok := proxyKeyByID[*acc.ProxyID]; ok {
proxyKey = &key
}
}
var expiresAt *int64
if acc.ExpiresAt != nil {
v := acc.ExpiresAt.Unix()
expiresAt = &v
}
dataAccounts = append(dataAccounts, DataAccount{
Name: acc.Name,
Notes: acc.Notes,
Platform: acc.Platform,
Type: acc.Type,
Credentials: acc.Credentials,
Extra: acc.Extra,
ProxyKey: proxyKey,
Concurrency: acc.Concurrency,
Priority: acc.Priority,
RateMultiplier: acc.RateMultiplier,
ExpiresAt: expiresAt,
AutoPauseOnExpired: &acc.AutoPauseOnExpired,
})
}
payload := DataPayload{
ExportedAt: time.Now().UTC().Format(time.RFC3339),
Proxies: dataProxies,
Accounts: dataAccounts,
}
response.Success(c, payload)
}
func (h *AccountHandler) ImportData(c *gin.Context) {
var req DataImportRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
dataPayload := req.Data
if err := validateDataHeader(dataPayload); err != nil {
response.BadRequest(c, err.Error())
return
}
skipDefaultGroupBind := true
if req.SkipDefaultGroupBind != nil {
skipDefaultGroupBind = *req.SkipDefaultGroupBind
}
result := DataImportResult{}
existingProxies, err := h.listAllProxies(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
proxyKeyToID := make(map[string]int64, len(existingProxies))
for i := range existingProxies {
p := existingProxies[i]
key := buildProxyKey(p.Protocol, p.Host, p.Port, p.Username, p.Password)
proxyKeyToID[key] = p.ID
}
for i := range dataPayload.Proxies {
item := dataPayload.Proxies[i]
key := item.ProxyKey
if key == "" {
key = buildProxyKey(item.Protocol, item.Host, item.Port, item.Username, item.Password)
}
if err := validateDataProxy(item); err != nil {
result.ProxyFailed++
result.Errors = append(result.Errors, DataImportError{
Kind: "proxy",
Name: item.Name,
ProxyKey: key,
Message: err.Error(),
})
continue
}
normalizedStatus := normalizeProxyStatus(item.Status)
if existingID, ok := proxyKeyToID[key]; ok {
proxyKeyToID[key] = existingID
result.ProxyReused++
if normalizedStatus != "" {
if proxy, err := h.adminService.GetProxy(c.Request.Context(), existingID); err == nil && proxy != nil && proxy.Status != normalizedStatus {
_, _ = h.adminService.UpdateProxy(c.Request.Context(), existingID, &service.UpdateProxyInput{
Status: normalizedStatus,
})
}
}
continue
}
created, err := h.adminService.CreateProxy(c.Request.Context(), &service.CreateProxyInput{
Name: defaultProxyName(item.Name),
Protocol: item.Protocol,
Host: item.Host,
Port: item.Port,
Username: item.Username,
Password: item.Password,
})
if err != nil {
result.ProxyFailed++
result.Errors = append(result.Errors, DataImportError{
Kind: "proxy",
Name: item.Name,
ProxyKey: key,
Message: err.Error(),
})
continue
}
proxyKeyToID[key] = created.ID
result.ProxyCreated++
if normalizedStatus != "" && normalizedStatus != created.Status {
_, _ = h.adminService.UpdateProxy(c.Request.Context(), created.ID, &service.UpdateProxyInput{
Status: normalizedStatus,
})
}
}
for i := range dataPayload.Accounts {
item := dataPayload.Accounts[i]
if err := validateDataAccount(item); err != nil {
result.AccountFailed++
result.Errors = append(result.Errors, DataImportError{
Kind: "account",
Name: item.Name,
Message: err.Error(),
})
continue
}
var proxyID *int64
if item.ProxyKey != nil && *item.ProxyKey != "" {
if id, ok := proxyKeyToID[*item.ProxyKey]; ok {
proxyID = &id
} else {
result.AccountFailed++
result.Errors = append(result.Errors, DataImportError{
Kind: "account",
Name: item.Name,
ProxyKey: *item.ProxyKey,
Message: "proxy_key not found",
})
continue
}
}
accountInput := &service.CreateAccountInput{
Name: item.Name,
Notes: item.Notes,
Platform: item.Platform,
Type: item.Type,
Credentials: item.Credentials,
Extra: item.Extra,
ProxyID: proxyID,
Concurrency: item.Concurrency,
Priority: item.Priority,
RateMultiplier: item.RateMultiplier,
GroupIDs: nil,
ExpiresAt: item.ExpiresAt,
AutoPauseOnExpired: item.AutoPauseOnExpired,
SkipDefaultGroupBind: skipDefaultGroupBind,
}
if _, err := h.adminService.CreateAccount(c.Request.Context(), accountInput); err != nil {
result.AccountFailed++
result.Errors = append(result.Errors, DataImportError{
Kind: "account",
Name: item.Name,
Message: err.Error(),
})
continue
}
result.AccountCreated++
}
response.Success(c, result)
}
func (h *AccountHandler) listAllProxies(ctx context.Context) ([]service.Proxy, error) {
page := 1
pageSize := dataPageCap
var out []service.Proxy
for {
items, total, err := h.adminService.ListProxies(ctx, page, pageSize, "", "", "")
if err != nil {
return nil, err
}
out = append(out, items...)
if len(out) >= int(total) || len(items) == 0 {
break
}
page++
}
return out, nil
}
func (h *AccountHandler) listAccountsFiltered(ctx context.Context, platform, accountType, status, search string) ([]service.Account, error) {
page := 1
pageSize := dataPageCap
var out []service.Account
for {
items, total, err := h.adminService.ListAccounts(ctx, page, pageSize, platform, accountType, status, search)
if err != nil {
return nil, err
}
out = append(out, items...)
if len(out) >= int(total) || len(items) == 0 {
break
}
page++
}
return out, nil
}
func (h *AccountHandler) resolveExportAccounts(ctx context.Context, ids []int64, c *gin.Context) ([]service.Account, error) {
if len(ids) > 0 {
accounts, err := h.adminService.GetAccountsByIDs(ctx, ids)
if err != nil {
return nil, err
}
out := make([]service.Account, 0, len(accounts))
for _, acc := range accounts {
if acc == nil {
continue
}
out = append(out, *acc)
}
return out, nil
}
platform := c.Query("platform")
accountType := c.Query("type")
status := c.Query("status")
search := strings.TrimSpace(c.Query("search"))
if len(search) > 100 {
search = search[:100]
}
return h.listAccountsFiltered(ctx, platform, accountType, status, search)
}
func (h *AccountHandler) resolveExportProxies(ctx context.Context, accounts []service.Account) ([]service.Proxy, error) {
if len(accounts) == 0 {
return []service.Proxy{}, nil
}
seen := make(map[int64]struct{})
ids := make([]int64, 0)
for i := range accounts {
if accounts[i].ProxyID == nil {
continue
}
id := *accounts[i].ProxyID
if id <= 0 {
continue
}
if _, ok := seen[id]; ok {
continue
}
seen[id] = struct{}{}
ids = append(ids, id)
}
if len(ids) == 0 {
return []service.Proxy{}, nil
}
return h.adminService.GetProxiesByIDs(ctx, ids)
}
func parseAccountIDs(c *gin.Context) ([]int64, error) {
values := c.QueryArray("ids")
if len(values) == 0 {
raw := strings.TrimSpace(c.Query("ids"))
if raw != "" {
values = []string{raw}
}
}
if len(values) == 0 {
return nil, nil
}
ids := make([]int64, 0, len(values))
for _, item := range values {
for _, part := range strings.Split(item, ",") {
part = strings.TrimSpace(part)
if part == "" {
continue
}
id, err := strconv.ParseInt(part, 10, 64)
if err != nil || id <= 0 {
return nil, fmt.Errorf("invalid account id: %s", part)
}
ids = append(ids, id)
}
}
return ids, nil
}
func parseIncludeProxies(c *gin.Context) (bool, error) {
raw := strings.TrimSpace(strings.ToLower(c.Query("include_proxies")))
if raw == "" {
return true, nil
}
switch raw {
case "1", "true", "yes", "on":
return true, nil
case "0", "false", "no", "off":
return false, nil
default:
return true, fmt.Errorf("invalid include_proxies value: %s", raw)
}
}
func validateDataHeader(payload DataPayload) error {
if payload.Type != "" && payload.Type != dataType && payload.Type != legacyDataType {
return fmt.Errorf("unsupported data type: %s", payload.Type)
}
if payload.Version != 0 && payload.Version != dataVersion {
return fmt.Errorf("unsupported data version: %d", payload.Version)
}
if payload.Proxies == nil {
return errors.New("proxies is required")
}
if payload.Accounts == nil {
return errors.New("accounts is required")
}
return nil
}
func validateDataProxy(item DataProxy) error {
if strings.TrimSpace(item.Protocol) == "" {
return errors.New("proxy protocol is required")
}
if strings.TrimSpace(item.Host) == "" {
return errors.New("proxy host is required")
}
if item.Port <= 0 || item.Port > 65535 {
return errors.New("proxy port is invalid")
}
switch item.Protocol {
case "http", "https", "socks5", "socks5h":
default:
return fmt.Errorf("proxy protocol is invalid: %s", item.Protocol)
}
if item.Status != "" {
normalizedStatus := normalizeProxyStatus(item.Status)
if normalizedStatus != service.StatusActive && normalizedStatus != "inactive" {
return fmt.Errorf("proxy status is invalid: %s", item.Status)
}
}
return nil
}
func validateDataAccount(item DataAccount) error {
if strings.TrimSpace(item.Name) == "" {
return errors.New("account name is required")
}
if strings.TrimSpace(item.Platform) == "" {
return errors.New("account platform is required")
}
if strings.TrimSpace(item.Type) == "" {
return errors.New("account type is required")
}
if len(item.Credentials) == 0 {
return errors.New("account credentials is required")
}
switch item.Type {
case service.AccountTypeOAuth, service.AccountTypeSetupToken, service.AccountTypeAPIKey, service.AccountTypeUpstream:
default:
return fmt.Errorf("account type is invalid: %s", item.Type)
}
if item.RateMultiplier != nil && *item.RateMultiplier < 0 {
return errors.New("rate_multiplier must be >= 0")
}
if item.Concurrency < 0 {
return errors.New("concurrency must be >= 0")
}
if item.Priority < 0 {
return errors.New("priority must be >= 0")
}
return nil
}
func defaultProxyName(name string) string {
if strings.TrimSpace(name) == "" {
return "imported-proxy"
}
return name
}
func normalizeProxyStatus(status string) string {
normalized := strings.TrimSpace(strings.ToLower(status))
switch normalized {
case "":
return ""
case service.StatusActive:
return service.StatusActive
case "inactive", service.StatusDisabled:
return "inactive"
default:
return normalized
}
}

View File

@@ -0,0 +1,231 @@
package admin
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
type dataResponse struct {
Code int `json:"code"`
Data dataPayload `json:"data"`
}
type dataPayload struct {
Type string `json:"type"`
Version int `json:"version"`
Proxies []dataProxy `json:"proxies"`
Accounts []dataAccount `json:"accounts"`
}
type dataProxy struct {
ProxyKey string `json:"proxy_key"`
Name string `json:"name"`
Protocol string `json:"protocol"`
Host string `json:"host"`
Port int `json:"port"`
Username string `json:"username"`
Password string `json:"password"`
Status string `json:"status"`
}
type dataAccount struct {
Name string `json:"name"`
Platform string `json:"platform"`
Type string `json:"type"`
Credentials map[string]any `json:"credentials"`
Extra map[string]any `json:"extra"`
ProxyKey *string `json:"proxy_key"`
Concurrency int `json:"concurrency"`
Priority int `json:"priority"`
}
func setupAccountDataRouter() (*gin.Engine, *stubAdminService) {
gin.SetMode(gin.TestMode)
router := gin.New()
adminSvc := newStubAdminService()
h := NewAccountHandler(
adminSvc,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
)
router.GET("/api/v1/admin/accounts/data", h.ExportData)
router.POST("/api/v1/admin/accounts/data", h.ImportData)
return router, adminSvc
}
func TestExportDataIncludesSecrets(t *testing.T) {
router, adminSvc := setupAccountDataRouter()
proxyID := int64(11)
adminSvc.proxies = []service.Proxy{
{
ID: proxyID,
Name: "proxy",
Protocol: "http",
Host: "127.0.0.1",
Port: 8080,
Username: "user",
Password: "pass",
Status: service.StatusActive,
},
{
ID: 12,
Name: "orphan",
Protocol: "https",
Host: "10.0.0.1",
Port: 443,
Username: "o",
Password: "p",
Status: service.StatusActive,
},
}
adminSvc.accounts = []service.Account{
{
ID: 21,
Name: "account",
Platform: service.PlatformOpenAI,
Type: service.AccountTypeOAuth,
Credentials: map[string]any{"token": "secret"},
Extra: map[string]any{"note": "x"},
ProxyID: &proxyID,
Concurrency: 3,
Priority: 50,
Status: service.StatusDisabled,
},
}
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/accounts/data", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
var resp dataResponse
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, 0, resp.Code)
require.Empty(t, resp.Data.Type)
require.Equal(t, 0, resp.Data.Version)
require.Len(t, resp.Data.Proxies, 1)
require.Equal(t, "pass", resp.Data.Proxies[0].Password)
require.Len(t, resp.Data.Accounts, 1)
require.Equal(t, "secret", resp.Data.Accounts[0].Credentials["token"])
}
func TestExportDataWithoutProxies(t *testing.T) {
router, adminSvc := setupAccountDataRouter()
proxyID := int64(11)
adminSvc.proxies = []service.Proxy{
{
ID: proxyID,
Name: "proxy",
Protocol: "http",
Host: "127.0.0.1",
Port: 8080,
Username: "user",
Password: "pass",
Status: service.StatusActive,
},
}
adminSvc.accounts = []service.Account{
{
ID: 21,
Name: "account",
Platform: service.PlatformOpenAI,
Type: service.AccountTypeOAuth,
Credentials: map[string]any{"token": "secret"},
ProxyID: &proxyID,
Concurrency: 3,
Priority: 50,
Status: service.StatusDisabled,
},
}
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/accounts/data?include_proxies=false", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
var resp dataResponse
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, 0, resp.Code)
require.Len(t, resp.Data.Proxies, 0)
require.Len(t, resp.Data.Accounts, 1)
require.Nil(t, resp.Data.Accounts[0].ProxyKey)
}
func TestImportDataReusesProxyAndSkipsDefaultGroup(t *testing.T) {
router, adminSvc := setupAccountDataRouter()
adminSvc.proxies = []service.Proxy{
{
ID: 1,
Name: "proxy",
Protocol: "socks5",
Host: "1.2.3.4",
Port: 1080,
Username: "u",
Password: "p",
Status: service.StatusActive,
},
}
dataPayload := map[string]any{
"data": map[string]any{
"type": dataType,
"version": dataVersion,
"proxies": []map[string]any{
{
"proxy_key": "socks5|1.2.3.4|1080|u|p",
"name": "proxy",
"protocol": "socks5",
"host": "1.2.3.4",
"port": 1080,
"username": "u",
"password": "p",
"status": "active",
},
},
"accounts": []map[string]any{
{
"name": "acc",
"platform": service.PlatformOpenAI,
"type": service.AccountTypeOAuth,
"credentials": map[string]any{"token": "x"},
"proxy_key": "socks5|1.2.3.4|1080|u|p",
"concurrency": 3,
"priority": 50,
},
},
},
"skip_default_group_bind": true,
}
body, _ := json.Marshal(dataPayload)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/data", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
require.Len(t, adminSvc.createdProxies, 0)
require.Len(t, adminSvc.createdAccounts, 1)
require.True(t, adminSvc.createdAccounts[0].SkipDefaultGroupBind)
}

View File

@@ -3,11 +3,13 @@ package admin
import (
"errors"
"fmt"
"strconv"
"strings"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/domain"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
@@ -696,11 +698,61 @@ func (h *AccountHandler) BatchCreate(c *gin.Context) {
return
}
// Return mock data for now
ctx := c.Request.Context()
success := 0
failed := 0
results := make([]gin.H, 0, len(req.Accounts))
for _, item := range req.Accounts {
if item.RateMultiplier != nil && *item.RateMultiplier < 0 {
failed++
results = append(results, gin.H{
"name": item.Name,
"success": false,
"error": "rate_multiplier must be >= 0",
})
continue
}
skipCheck := item.ConfirmMixedChannelRisk != nil && *item.ConfirmMixedChannelRisk
account, err := h.adminService.CreateAccount(ctx, &service.CreateAccountInput{
Name: item.Name,
Notes: item.Notes,
Platform: item.Platform,
Type: item.Type,
Credentials: item.Credentials,
Extra: item.Extra,
ProxyID: item.ProxyID,
Concurrency: item.Concurrency,
Priority: item.Priority,
RateMultiplier: item.RateMultiplier,
GroupIDs: item.GroupIDs,
ExpiresAt: item.ExpiresAt,
AutoPauseOnExpired: item.AutoPauseOnExpired,
SkipMixedChannelCheck: skipCheck,
})
if err != nil {
failed++
results = append(results, gin.H{
"name": item.Name,
"success": false,
"error": err.Error(),
})
continue
}
success++
results = append(results, gin.H{
"name": item.Name,
"id": account.ID,
"success": true,
})
}
response.Success(c, gin.H{
"success": len(req.Accounts),
"failed": 0,
"results": []gin.H{},
"success": success,
"failed": failed,
"results": results,
})
}
@@ -738,57 +790,40 @@ func (h *AccountHandler) BatchUpdateCredentials(c *gin.Context) {
}
ctx := c.Request.Context()
success := 0
failed := 0
results := []gin.H{}
// 阶段一:预验证所有账号存在,收集 credentials
type accountUpdate struct {
ID int64
Credentials map[string]any
}
updates := make([]accountUpdate, 0, len(req.AccountIDs))
for _, accountID := range req.AccountIDs {
// Get account
account, err := h.adminService.GetAccount(ctx, accountID)
if err != nil {
failed++
results = append(results, gin.H{
"account_id": accountID,
"success": false,
"error": "Account not found",
})
continue
response.Error(c, 404, fmt.Sprintf("Account %d not found", accountID))
return
}
// Update credentials field
if account.Credentials == nil {
account.Credentials = make(map[string]any)
}
account.Credentials[req.Field] = req.Value
updates = append(updates, accountUpdate{ID: accountID, Credentials: account.Credentials})
}
// Update account
// 阶段二:依次更新,任何失败立即返回(避免部分成功部分失败)
for _, u := range updates {
updateInput := &service.UpdateAccountInput{
Credentials: account.Credentials,
Credentials: u.Credentials,
}
_, err = h.adminService.UpdateAccount(ctx, accountID, updateInput)
if err != nil {
failed++
results = append(results, gin.H{
"account_id": accountID,
"success": false,
"error": err.Error(),
})
continue
if _, err := h.adminService.UpdateAccount(ctx, u.ID, updateInput); err != nil {
response.Error(c, 500, fmt.Sprintf("Failed to update account %d: %v", u.ID, err))
return
}
success++
results = append(results, gin.H{
"account_id": accountID,
"success": true,
})
}
response.Success(c, gin.H{
"success": success,
"failed": failed,
"results": results,
"success": len(updates),
"failed": 0,
})
}
@@ -1440,3 +1475,9 @@ func (h *AccountHandler) BatchRefreshTier(c *gin.Context) {
response.Success(c, results)
}
// GetAntigravityDefaultModelMapping 获取 Antigravity 平台的默认模型映射
// GET /api/v1/admin/accounts/antigravity/default-model-mapping
func (h *AccountHandler) GetAntigravityDefaultModelMapping(c *gin.Context) {
response.Success(c, domain.DefaultAntigravityModelMapping)
}

View File

@@ -2,19 +2,27 @@ package admin
import (
"context"
"strings"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
)
type stubAdminService struct {
users []service.User
apiKeys []service.APIKey
groups []service.Group
accounts []service.Account
proxies []service.Proxy
proxyCounts []service.ProxyWithAccountCount
redeems []service.RedeemCode
users []service.User
apiKeys []service.APIKey
groups []service.Group
accounts []service.Account
proxies []service.Proxy
proxyCounts []service.ProxyWithAccountCount
redeems []service.RedeemCode
createdAccounts []*service.CreateAccountInput
createdProxies []*service.CreateProxyInput
updatedProxyIDs []int64
updatedProxies []*service.UpdateProxyInput
testedProxyIDs []int64
mu sync.Mutex
}
func newStubAdminService() *stubAdminService {
@@ -177,6 +185,9 @@ func (s *stubAdminService) GetAccountsByIDs(ctx context.Context, ids []int64) ([
}
func (s *stubAdminService) CreateAccount(ctx context.Context, input *service.CreateAccountInput) (*service.Account, error) {
s.mu.Lock()
s.createdAccounts = append(s.createdAccounts, input)
s.mu.Unlock()
account := service.Account{ID: 300, Name: input.Name, Status: service.StatusActive}
return &account, nil
}
@@ -214,7 +225,25 @@ func (s *stubAdminService) BulkUpdateAccounts(ctx context.Context, input *servic
}
func (s *stubAdminService) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]service.Proxy, int64, error) {
return s.proxies, int64(len(s.proxies)), nil
search = strings.TrimSpace(strings.ToLower(search))
filtered := make([]service.Proxy, 0, len(s.proxies))
for _, proxy := range s.proxies {
if protocol != "" && proxy.Protocol != protocol {
continue
}
if status != "" && proxy.Status != status {
continue
}
if search != "" {
name := strings.ToLower(proxy.Name)
host := strings.ToLower(proxy.Host)
if !strings.Contains(name, search) && !strings.Contains(host, search) {
continue
}
}
filtered = append(filtered, proxy)
}
return filtered, int64(len(filtered)), nil
}
func (s *stubAdminService) ListProxiesWithAccountCount(ctx context.Context, page, pageSize int, protocol, status, search string) ([]service.ProxyWithAccountCount, int64, error) {
@@ -230,16 +259,47 @@ func (s *stubAdminService) GetAllProxiesWithAccountCount(ctx context.Context) ([
}
func (s *stubAdminService) GetProxy(ctx context.Context, id int64) (*service.Proxy, error) {
for i := range s.proxies {
proxy := s.proxies[i]
if proxy.ID == id {
return &proxy, nil
}
}
proxy := service.Proxy{ID: id, Name: "proxy", Status: service.StatusActive}
return &proxy, nil
}
func (s *stubAdminService) GetProxiesByIDs(ctx context.Context, ids []int64) ([]service.Proxy, error) {
if len(ids) == 0 {
return []service.Proxy{}, nil
}
out := make([]service.Proxy, 0, len(ids))
seen := make(map[int64]struct{}, len(ids))
for _, id := range ids {
seen[id] = struct{}{}
}
for i := range s.proxies {
proxy := s.proxies[i]
if _, ok := seen[proxy.ID]; ok {
out = append(out, proxy)
}
}
return out, nil
}
func (s *stubAdminService) CreateProxy(ctx context.Context, input *service.CreateProxyInput) (*service.Proxy, error) {
s.mu.Lock()
s.createdProxies = append(s.createdProxies, input)
s.mu.Unlock()
proxy := service.Proxy{ID: 400, Name: input.Name, Status: service.StatusActive}
return &proxy, nil
}
func (s *stubAdminService) UpdateProxy(ctx context.Context, id int64, input *service.UpdateProxyInput) (*service.Proxy, error) {
s.mu.Lock()
s.updatedProxyIDs = append(s.updatedProxyIDs, id)
s.updatedProxies = append(s.updatedProxies, input)
s.mu.Unlock()
proxy := service.Proxy{ID: id, Name: input.Name, Status: service.StatusActive}
return &proxy, nil
}
@@ -261,6 +321,9 @@ func (s *stubAdminService) CheckProxyExists(ctx context.Context, host string, po
}
func (s *stubAdminService) TestProxy(ctx context.Context, id int64) (*service.ProxyTestResult, error) {
s.mu.Lock()
s.testedProxyIDs = append(s.testedProxyIDs, id)
s.mu.Unlock()
return &service.ProxyTestResult{Success: true, Message: "ok"}, nil
}

View File

@@ -0,0 +1,200 @@
//go:build unit
package admin
import (
"bytes"
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
"github.com/Wei-Shaw/sub2api/internal/service"
)
// failingAdminService 嵌入 stubAdminService可配置 UpdateAccount 在指定 ID 时失败。
type failingAdminService struct {
*stubAdminService
failOnAccountID int64
updateCallCount atomic.Int64
}
func (f *failingAdminService) UpdateAccount(ctx context.Context, id int64, input *service.UpdateAccountInput) (*service.Account, error) {
f.updateCallCount.Add(1)
if id == f.failOnAccountID {
return nil, errors.New("database error")
}
return f.stubAdminService.UpdateAccount(ctx, id, input)
}
func setupAccountHandlerWithService(adminSvc service.AdminService) (*gin.Engine, *AccountHandler) {
gin.SetMode(gin.TestMode)
router := gin.New()
handler := NewAccountHandler(adminSvc, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
router.POST("/api/v1/admin/accounts/batch-update-credentials", handler.BatchUpdateCredentials)
return router, handler
}
func TestBatchUpdateCredentials_AllSuccess(t *testing.T) {
svc := &failingAdminService{stubAdminService: newStubAdminService()}
router, _ := setupAccountHandlerWithService(svc)
body, _ := json.Marshal(BatchUpdateCredentialsRequest{
AccountIDs: []int64{1, 2, 3},
Field: "account_uuid",
Value: "test-uuid",
})
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code, "全部成功时应返回 200")
require.Equal(t, int64(3), svc.updateCallCount.Load(), "应调用 3 次 UpdateAccount")
}
func TestBatchUpdateCredentials_FailFast(t *testing.T) {
// 让第 2 个账号ID=2更新时失败
svc := &failingAdminService{
stubAdminService: newStubAdminService(),
failOnAccountID: 2,
}
router, _ := setupAccountHandlerWithService(svc)
body, _ := json.Marshal(BatchUpdateCredentialsRequest{
AccountIDs: []int64{1, 2, 3},
Field: "org_uuid",
Value: "test-org",
})
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(w, req)
require.Equal(t, http.StatusInternalServerError, w.Code, "ID=2 失败时应返回 500")
// 验证 fail-fastID=1 更新成功ID=2 失败ID=3 不应被调用
require.Equal(t, int64(2), svc.updateCallCount.Load(),
"fail-fast: 应只调用 2 次 UpdateAccountID=1 成功、ID=2 失败后停止)")
}
func TestBatchUpdateCredentials_FirstAccountNotFound(t *testing.T) {
// GetAccount 在 stubAdminService 中总是成功的,需要创建一个 GetAccount 会失败的 stub
svc := &getAccountFailingService{
stubAdminService: newStubAdminService(),
failOnAccountID: 1,
}
router, _ := setupAccountHandlerWithService(svc)
body, _ := json.Marshal(BatchUpdateCredentialsRequest{
AccountIDs: []int64{1, 2, 3},
Field: "account_uuid",
Value: "test",
})
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(w, req)
require.Equal(t, http.StatusNotFound, w.Code, "第一阶段验证失败应返回 404")
}
// getAccountFailingService 模拟 GetAccount 在特定 ID 时返回 not found。
type getAccountFailingService struct {
*stubAdminService
failOnAccountID int64
}
func (f *getAccountFailingService) GetAccount(ctx context.Context, id int64) (*service.Account, error) {
if id == f.failOnAccountID {
return nil, errors.New("not found")
}
return f.stubAdminService.GetAccount(ctx, id)
}
func TestBatchUpdateCredentials_InterceptWarmupRequests_NonBool(t *testing.T) {
svc := &failingAdminService{stubAdminService: newStubAdminService()}
router, _ := setupAccountHandlerWithService(svc)
// intercept_warmup_requests 传入非 bool 类型string应返回 400
body, _ := json.Marshal(map[string]any{
"account_ids": []int64{1},
"field": "intercept_warmup_requests",
"value": "not-a-bool",
})
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(w, req)
require.Equal(t, http.StatusBadRequest, w.Code,
"intercept_warmup_requests 传入非 bool 值应返回 400")
}
func TestBatchUpdateCredentials_InterceptWarmupRequests_ValidBool(t *testing.T) {
svc := &failingAdminService{stubAdminService: newStubAdminService()}
router, _ := setupAccountHandlerWithService(svc)
body, _ := json.Marshal(map[string]any{
"account_ids": []int64{1},
"field": "intercept_warmup_requests",
"value": true,
})
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code,
"intercept_warmup_requests 传入合法 bool 值应返回 200")
}
func TestBatchUpdateCredentials_AccountUUID_NonString(t *testing.T) {
svc := &failingAdminService{stubAdminService: newStubAdminService()}
router, _ := setupAccountHandlerWithService(svc)
// account_uuid 传入非 string 类型number应返回 400
body, _ := json.Marshal(map[string]any{
"account_ids": []int64{1},
"field": "account_uuid",
"value": 12345,
})
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(w, req)
require.Equal(t, http.StatusBadRequest, w.Code,
"account_uuid 传入非 string 值应返回 400")
}
func TestBatchUpdateCredentials_AccountUUID_NullValue(t *testing.T) {
svc := &failingAdminService{stubAdminService: newStubAdminService()}
router, _ := setupAccountHandlerWithService(svc)
// account_uuid 传入 null设置为空应正常通过
body, _ := json.Marshal(map[string]any{
"account_ids": []int64{1},
"field": "account_uuid",
"value": nil,
})
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code,
"account_uuid 传入 null 应返回 200")
}

View File

@@ -379,7 +379,7 @@ func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) {
return
}
stats, err := h.dashboardService.GetBatchUserUsageStats(c.Request.Context(), req.UserIDs)
stats, err := h.dashboardService.GetBatchUserUsageStats(c.Request.Context(), req.UserIDs, time.Time{}, time.Time{})
if err != nil {
response.Error(c, 500, "Failed to get user usage stats")
return
@@ -407,7 +407,7 @@ func (h *DashboardHandler) GetBatchAPIKeysUsage(c *gin.Context) {
return
}
stats, err := h.dashboardService.GetBatchAPIKeyUsageStats(c.Request.Context(), req.APIKeyIDs)
stats, err := h.dashboardService.GetBatchAPIKeyUsageStats(c.Request.Context(), req.APIKeyIDs, time.Time{}, time.Time{})
if err != nil {
response.Error(c, 500, "Failed to get API key usage stats")
return

View File

@@ -63,6 +63,43 @@ func (h *OpsHandler) GetConcurrencyStats(c *gin.Context) {
response.Success(c, payload)
}
// GetUserConcurrencyStats returns real-time concurrency usage for all active users.
// GET /api/v1/admin/ops/user-concurrency
func (h *OpsHandler) GetUserConcurrencyStats(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
if !h.opsService.IsRealtimeMonitoringEnabled(c.Request.Context()) {
response.Success(c, gin.H{
"enabled": false,
"user": map[int64]*service.UserConcurrencyInfo{},
"timestamp": time.Now().UTC(),
})
return
}
users, collectedAt, err := h.opsService.GetUserConcurrencyStats(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
payload := gin.H{
"enabled": true,
"user": users,
}
if collectedAt != nil {
payload["timestamp"] = collectedAt.UTC()
}
response.Success(c, payload)
}
// GetAccountAvailability returns account availability statistics.
// GET /api/v1/admin/ops/account-availability
//

View File

@@ -0,0 +1,239 @@
package admin
import (
"context"
"fmt"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// ExportData exports proxy-only data for migration.
func (h *ProxyHandler) ExportData(c *gin.Context) {
ctx := c.Request.Context()
selectedIDs, err := parseProxyIDs(c)
if err != nil {
response.BadRequest(c, err.Error())
return
}
var proxies []service.Proxy
if len(selectedIDs) > 0 {
proxies, err = h.getProxiesByIDs(ctx, selectedIDs)
if err != nil {
response.ErrorFrom(c, err)
return
}
} else {
protocol := c.Query("protocol")
status := c.Query("status")
search := strings.TrimSpace(c.Query("search"))
if len(search) > 100 {
search = search[:100]
}
proxies, err = h.listProxiesFiltered(ctx, protocol, status, search)
if err != nil {
response.ErrorFrom(c, err)
return
}
}
dataProxies := make([]DataProxy, 0, len(proxies))
for i := range proxies {
p := proxies[i]
key := buildProxyKey(p.Protocol, p.Host, p.Port, p.Username, p.Password)
dataProxies = append(dataProxies, DataProxy{
ProxyKey: key,
Name: p.Name,
Protocol: p.Protocol,
Host: p.Host,
Port: p.Port,
Username: p.Username,
Password: p.Password,
Status: p.Status,
})
}
payload := DataPayload{
ExportedAt: time.Now().UTC().Format(time.RFC3339),
Proxies: dataProxies,
Accounts: []DataAccount{},
}
response.Success(c, payload)
}
// ImportData imports proxy-only data for migration.
func (h *ProxyHandler) ImportData(c *gin.Context) {
type ProxyImportRequest struct {
Data DataPayload `json:"data"`
}
var req ProxyImportRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if err := validateDataHeader(req.Data); err != nil {
response.BadRequest(c, err.Error())
return
}
ctx := c.Request.Context()
result := DataImportResult{}
existingProxies, err := h.listProxiesFiltered(ctx, "", "", "")
if err != nil {
response.ErrorFrom(c, err)
return
}
proxyByKey := make(map[string]service.Proxy, len(existingProxies))
for i := range existingProxies {
p := existingProxies[i]
key := buildProxyKey(p.Protocol, p.Host, p.Port, p.Username, p.Password)
proxyByKey[key] = p
}
latencyProbeIDs := make([]int64, 0, len(req.Data.Proxies))
for i := range req.Data.Proxies {
item := req.Data.Proxies[i]
key := item.ProxyKey
if key == "" {
key = buildProxyKey(item.Protocol, item.Host, item.Port, item.Username, item.Password)
}
if err := validateDataProxy(item); err != nil {
result.ProxyFailed++
result.Errors = append(result.Errors, DataImportError{
Kind: "proxy",
Name: item.Name,
ProxyKey: key,
Message: err.Error(),
})
continue
}
normalizedStatus := normalizeProxyStatus(item.Status)
if existing, ok := proxyByKey[key]; ok {
result.ProxyReused++
if normalizedStatus != "" && normalizedStatus != existing.Status {
if _, err := h.adminService.UpdateProxy(ctx, existing.ID, &service.UpdateProxyInput{Status: normalizedStatus}); err != nil {
result.Errors = append(result.Errors, DataImportError{
Kind: "proxy",
Name: item.Name,
ProxyKey: key,
Message: "update status failed: " + err.Error(),
})
}
}
latencyProbeIDs = append(latencyProbeIDs, existing.ID)
continue
}
created, err := h.adminService.CreateProxy(ctx, &service.CreateProxyInput{
Name: defaultProxyName(item.Name),
Protocol: item.Protocol,
Host: item.Host,
Port: item.Port,
Username: item.Username,
Password: item.Password,
})
if err != nil {
result.ProxyFailed++
result.Errors = append(result.Errors, DataImportError{
Kind: "proxy",
Name: item.Name,
ProxyKey: key,
Message: err.Error(),
})
continue
}
result.ProxyCreated++
proxyByKey[key] = *created
if normalizedStatus != "" && normalizedStatus != created.Status {
if _, err := h.adminService.UpdateProxy(ctx, created.ID, &service.UpdateProxyInput{Status: normalizedStatus}); err != nil {
result.Errors = append(result.Errors, DataImportError{
Kind: "proxy",
Name: item.Name,
ProxyKey: key,
Message: "update status failed: " + err.Error(),
})
}
}
// CreateProxy already triggers a latency probe, avoid double probing here.
}
if len(latencyProbeIDs) > 0 {
ids := append([]int64(nil), latencyProbeIDs...)
go func() {
for _, id := range ids {
_, _ = h.adminService.TestProxy(context.Background(), id)
}
}()
}
response.Success(c, result)
}
func (h *ProxyHandler) getProxiesByIDs(ctx context.Context, ids []int64) ([]service.Proxy, error) {
if len(ids) == 0 {
return []service.Proxy{}, nil
}
return h.adminService.GetProxiesByIDs(ctx, ids)
}
func parseProxyIDs(c *gin.Context) ([]int64, error) {
values := c.QueryArray("ids")
if len(values) == 0 {
raw := strings.TrimSpace(c.Query("ids"))
if raw != "" {
values = []string{raw}
}
}
if len(values) == 0 {
return nil, nil
}
ids := make([]int64, 0, len(values))
for _, item := range values {
for _, part := range strings.Split(item, ",") {
part = strings.TrimSpace(part)
if part == "" {
continue
}
id, err := strconv.ParseInt(part, 10, 64)
if err != nil || id <= 0 {
return nil, fmt.Errorf("invalid proxy id: %s", part)
}
ids = append(ids, id)
}
}
return ids, nil
}
func (h *ProxyHandler) listProxiesFiltered(ctx context.Context, protocol, status, search string) ([]service.Proxy, error) {
page := 1
pageSize := dataPageCap
var out []service.Proxy
for {
items, total, err := h.adminService.ListProxies(ctx, page, pageSize, protocol, status, search)
if err != nil {
return nil, err
}
out = append(out, items...)
if len(out) >= int(total) || len(items) == 0 {
break
}
page++
}
return out, nil
}

View File

@@ -0,0 +1,188 @@
package admin
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
type proxyDataResponse struct {
Code int `json:"code"`
Data DataPayload `json:"data"`
}
type proxyImportResponse struct {
Code int `json:"code"`
Data DataImportResult `json:"data"`
}
func setupProxyDataRouter() (*gin.Engine, *stubAdminService) {
gin.SetMode(gin.TestMode)
router := gin.New()
adminSvc := newStubAdminService()
h := NewProxyHandler(adminSvc)
router.GET("/api/v1/admin/proxies/data", h.ExportData)
router.POST("/api/v1/admin/proxies/data", h.ImportData)
return router, adminSvc
}
func TestProxyExportDataRespectsFilters(t *testing.T) {
router, adminSvc := setupProxyDataRouter()
adminSvc.proxies = []service.Proxy{
{
ID: 1,
Name: "proxy-a",
Protocol: "http",
Host: "127.0.0.1",
Port: 8080,
Username: "user",
Password: "pass",
Status: service.StatusActive,
},
{
ID: 2,
Name: "proxy-b",
Protocol: "https",
Host: "10.0.0.2",
Port: 443,
Username: "u",
Password: "p",
Status: service.StatusDisabled,
},
}
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies/data?protocol=https", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
var resp proxyDataResponse
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, 0, resp.Code)
require.Empty(t, resp.Data.Type)
require.Equal(t, 0, resp.Data.Version)
require.Len(t, resp.Data.Proxies, 1)
require.Len(t, resp.Data.Accounts, 0)
require.Equal(t, "https", resp.Data.Proxies[0].Protocol)
}
func TestProxyExportDataWithSelectedIDs(t *testing.T) {
router, adminSvc := setupProxyDataRouter()
adminSvc.proxies = []service.Proxy{
{
ID: 1,
Name: "proxy-a",
Protocol: "http",
Host: "127.0.0.1",
Port: 8080,
Username: "user",
Password: "pass",
Status: service.StatusActive,
},
{
ID: 2,
Name: "proxy-b",
Protocol: "https",
Host: "10.0.0.2",
Port: 443,
Username: "u",
Password: "p",
Status: service.StatusDisabled,
},
}
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies/data?ids=2", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
var resp proxyDataResponse
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, 0, resp.Code)
require.Len(t, resp.Data.Proxies, 1)
require.Equal(t, "https", resp.Data.Proxies[0].Protocol)
require.Equal(t, "10.0.0.2", resp.Data.Proxies[0].Host)
}
func TestProxyImportDataReusesAndTriggersLatencyProbe(t *testing.T) {
router, adminSvc := setupProxyDataRouter()
adminSvc.proxies = []service.Proxy{
{
ID: 1,
Name: "proxy-a",
Protocol: "http",
Host: "127.0.0.1",
Port: 8080,
Username: "user",
Password: "pass",
Status: service.StatusActive,
},
}
payload := map[string]any{
"data": map[string]any{
"type": dataType,
"version": dataVersion,
"proxies": []map[string]any{
{
"proxy_key": "http|127.0.0.1|8080|user|pass",
"name": "proxy-a",
"protocol": "http",
"host": "127.0.0.1",
"port": 8080,
"username": "user",
"password": "pass",
"status": "inactive",
},
{
"proxy_key": "https|10.0.0.2|443|u|p",
"name": "proxy-b",
"protocol": "https",
"host": "10.0.0.2",
"port": 443,
"username": "u",
"password": "p",
"status": "active",
},
},
"accounts": []map[string]any{},
},
}
body, _ := json.Marshal(payload)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/proxies/data", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
var resp proxyImportResponse
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, 0, resp.Code)
require.Equal(t, 1, resp.Data.ProxyCreated)
require.Equal(t, 1, resp.Data.ProxyReused)
require.Equal(t, 0, resp.Data.ProxyFailed)
adminSvc.mu.Lock()
updatedIDs := append([]int64(nil), adminSvc.updatedProxyIDs...)
adminSvc.mu.Unlock()
require.Contains(t, updatedIDs, int64(1))
require.Eventually(t, func() bool {
adminSvc.mu.Lock()
defer adminSvc.mu.Unlock()
return len(adminSvc.testedProxyIDs) == 1
}, time.Second, 10*time.Millisecond)
}

View File

@@ -0,0 +1,97 @@
//go:build unit
package admin
import (
"testing"
"github.com/stretchr/testify/require"
)
// truncateSearchByRune 模拟 user_handler.go 中的 search 截断逻辑
func truncateSearchByRune(search string, maxRunes int) string {
if runes := []rune(search); len(runes) > maxRunes {
return string(runes[:maxRunes])
}
return search
}
func TestTruncateSearchByRune(t *testing.T) {
tests := []struct {
name string
input string
maxRunes int
wantLen int // 期望的 rune 长度
}{
{
name: "纯中文超长",
input: string(make([]rune, 150)),
maxRunes: 100,
wantLen: 100,
},
{
name: "纯 ASCII 超长",
input: string(make([]byte, 150)),
maxRunes: 100,
wantLen: 100,
},
{
name: "空字符串",
input: "",
maxRunes: 100,
wantLen: 0,
},
{
name: "恰好 100 个字符",
input: string(make([]rune, 100)),
maxRunes: 100,
wantLen: 100,
},
{
name: "不足 100 字符不截断",
input: "hello世界",
maxRunes: 100,
wantLen: 7,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
result := truncateSearchByRune(tc.input, tc.maxRunes)
require.Equal(t, tc.wantLen, len([]rune(result)))
})
}
}
func TestTruncateSearchByRune_PreservesMultibyte(t *testing.T) {
// 101 个中文字符,截断到 100 个后应该仍然是有效 UTF-8
input := ""
for i := 0; i < 101; i++ {
input += "中"
}
result := truncateSearchByRune(input, 100)
require.Equal(t, 100, len([]rune(result)))
// 验证截断结果是有效的 UTF-8每个中文字符 3 字节)
require.Equal(t, 300, len(result))
}
func TestTruncateSearchByRune_MixedASCIIAndMultibyte(t *testing.T) {
// 50 个 ASCII + 51 个中文 = 101 个 rune
input := ""
for i := 0; i < 50; i++ {
input += "a"
}
for i := 0; i < 51; i++ {
input += "中"
}
result := truncateSearchByRune(input, 100)
runes := []rune(result)
require.Equal(t, 100, len(runes))
// 前 50 个应该是 'a',后 50 个应该是 '中'
require.Equal(t, 'a', runes[0])
require.Equal(t, 'a', runes[49])
require.Equal(t, '中', runes[50])
require.Equal(t, '中', runes[99])
}

View File

@@ -70,8 +70,8 @@ func (h *UserHandler) List(c *gin.Context) {
search := c.Query("search")
// 标准化和验证 search 参数
search = strings.TrimSpace(search)
if len(search) > 100 {
search = search[:100]
if runes := []rune(search); len(runes) > 100 {
search = string(runes[:100])
}
filters := service.UserListFilters{

View File

@@ -2,6 +2,7 @@ package handler
import (
"log/slog"
"strings"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
@@ -448,17 +449,12 @@ func (h *AuthHandler) ForgotPassword(c *gin.Context) {
return
}
// Build frontend base URL from request
scheme := "https"
if c.Request.TLS == nil {
// Check X-Forwarded-Proto header (common in reverse proxy setups)
if proto := c.GetHeader("X-Forwarded-Proto"); proto != "" {
scheme = proto
} else {
scheme = "http"
}
frontendBaseURL := strings.TrimSpace(h.cfg.Server.FrontendURL)
if frontendBaseURL == "" {
slog.Error("server.frontend_url not configured; cannot build password reset link")
response.InternalError(c, "Password reset is not configured")
return
}
frontendBaseURL := scheme + "://" + c.Request.Host
// Request password reset (async)
// Note: This returns success even if email doesn't exist (to prevent enumeration)

View File

@@ -215,17 +215,6 @@ func AccountFromServiceShallow(a *service.Account) *Account {
}
}
if scopeLimits := a.GetAntigravityScopeRateLimits(); len(scopeLimits) > 0 {
out.ScopeRateLimits = make(map[string]ScopeRateLimitInfo, len(scopeLimits))
now := time.Now()
for scope, remainingSec := range scopeLimits {
out.ScopeRateLimits[scope] = ScopeRateLimitInfo{
ResetAt: now.Add(time.Duration(remainingSec) * time.Second),
RemainingSec: remainingSec,
}
}
}
return out
}

View File

@@ -2,6 +2,7 @@ package handler
import (
"context"
"crypto/rand"
"encoding/json"
"errors"
"fmt"
@@ -113,9 +114,6 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return
}
// 检查是否为 Claude Code 客户端,设置到 context 中
SetClaudeCodeClientContext(c, body)
setOpsRequestContext(c, "", false, body)
parsedReq, err := service.ParseGatewayRequest(body)
@@ -126,6 +124,20 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
reqModel := parsedReq.Model
reqStream := parsedReq.Stream
// 设置 max_tokens=1 + haiku 探测请求标识到 context 中
// 必须在 SetClaudeCodeClientContext 之前设置,因为 ClaudeCodeValidator 需要读取此标识进行绕过判断
if isMaxTokensOneHaikuRequest(reqModel, parsedReq.MaxTokens, reqStream) {
ctx := context.WithValue(c.Request.Context(), ctxkey.IsMaxTokensOneHaikuRequest, true)
c.Request = c.Request.WithContext(ctx)
}
// 检查是否为 Claude Code 客户端,设置到 context 中
SetClaudeCodeClientContext(c, body)
isClaudeCodeClient := service.IsClaudeCodeClient(c.Request.Context())
// 在请求上下文中记录 thinking 状态,供 Antigravity 最终模型 key 推导/模型维度限流使用
c.Request = c.Request.WithContext(context.WithValue(c.Request.Context(), ctxkey.ThinkingEnabled, parsedReq.ThinkingEnabled))
setOpsRequestContext(c, reqModel, reqStream, body)
// 验证 model 必填
@@ -137,6 +149,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// Track if we've started streaming (for error handling)
streamStarted := false
// 绑定错误透传服务,允许 service 层在非 failover 错误场景复用规则。
if h.errorPassthroughService != nil {
service.BindErrorPassthroughService(c, h.errorPassthroughService)
}
// 获取订阅信息可能为nil- 提前获取用于后续检查
subscription, _ := middleware2.GetSubscriptionFromContext(c)
@@ -202,17 +219,27 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
sessionKey = "gemini:" + sessionHash
}
// 查询粘性会话绑定的账号 ID
var sessionBoundAccountID int64
if sessionKey != "" {
sessionBoundAccountID, _ = h.gatewayService.GetCachedSessionAccountID(c.Request.Context(), apiKey.GroupID, sessionKey)
}
// 判断是否真的绑定了粘性会话:有 sessionKey 且已经绑定到某个账号
hasBoundSession := sessionKey != "" && sessionBoundAccountID > 0
if platform == service.PlatformGemini {
maxAccountSwitches := h.maxAccountSwitchesGemini
switchCount := 0
failedAccountIDs := make(map[int64]struct{})
var lastFailoverErr *service.UpstreamFailoverError
var forceCacheBilling bool // 粘性会话切换时的缓存计费标记
for {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs, "") // Gemini 不使用会话限制
if err != nil {
if len(failedAccountIDs) == 0 {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
log.Printf("[Gateway] SelectAccount failed: %v", err)
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
return
}
if lastFailoverErr != nil {
@@ -227,7 +254,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 检查请求拦截预热请求、SUGGESTION MODE等
if account.IsInterceptWarmupEnabled() {
interceptType := detectInterceptType(body)
interceptType := detectInterceptType(body, reqModel, parsedReq.MaxTokens, reqStream, isClaudeCodeClient)
if interceptType != InterceptTypeNone {
if selection.Acquired && selection.ReleaseFunc != nil {
selection.ReleaseFunc()
@@ -260,12 +287,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if err == nil && canWait {
accountWaitCounted = true
}
// Ensure the wait counter is decremented if we exit before acquiring the slot.
defer func() {
releaseWait := func() {
if accountWaitCounted {
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
accountWaitCounted = false
}
}()
}
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
c,
@@ -277,14 +304,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
)
if err != nil {
log.Printf("Account concurrency acquire failed: %v", err)
releaseWait()
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}
// Slot acquired: no longer waiting in queue.
if accountWaitCounted {
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
accountWaitCounted = false
}
releaseWait()
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionKey, account.ID); err != nil {
log.Printf("Bind sticky session failed: %v", err)
}
@@ -299,7 +324,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount)
}
if account.Platform == service.PlatformAntigravity {
result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, reqModel, "generateContent", reqStream, body)
result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, reqModel, "generateContent", reqStream, body, hasBoundSession)
} else {
result, err = h.geminiCompatService.Forward(requestCtx, c, account, body)
}
@@ -311,6 +336,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if errors.As(err, &failoverErr) {
failedAccountIDs[account.ID] = struct{}{}
lastFailoverErr = failoverErr
if failoverErr.ForceCacheBilling {
forceCacheBilling = true
}
if switchCount >= maxAccountSwitches {
h.handleFailoverExhausted(c, failoverErr, service.PlatformGemini, streamStarted)
return
@@ -329,22 +357,23 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
clientIP := ip.GetClientIP(c)
// 异步记录使用量subscription已在函数开头获取
go func(result *service.ForwardResult, usedAccount *service.Account, ua, clientIP string) {
go func(result *service.ForwardResult, usedAccount *service.Account, ua, clientIP string, fcb bool) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: usedAccount,
Subscription: subscription,
UserAgent: ua,
IPAddress: clientIP,
APIKeyService: h.apiKeyService,
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: usedAccount,
Subscription: subscription,
UserAgent: ua,
IPAddress: clientIP,
ForceCacheBilling: fcb,
APIKeyService: h.apiKeyService,
}); err != nil {
log.Printf("Record usage failed: %v", err)
}
}(result, account, userAgent, clientIP)
}(result, account, userAgent, clientIP, forceCacheBilling)
return
}
}
@@ -363,13 +392,15 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
failedAccountIDs := make(map[int64]struct{})
var lastFailoverErr *service.UpstreamFailoverError
retryWithFallback := false
var forceCacheBilling bool // 粘性会话切换时的缓存计费标记
for {
// 选择支持该模型的账号
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, failedAccountIDs, parsedReq.MetadataUserID)
if err != nil {
if len(failedAccountIDs) == 0 {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
log.Printf("[Gateway] SelectAccount failed: %v", err)
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
return
}
if lastFailoverErr != nil {
@@ -384,7 +415,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 检查请求拦截预热请求、SUGGESTION MODE等
if account.IsInterceptWarmupEnabled() {
interceptType := detectInterceptType(body)
interceptType := detectInterceptType(body, reqModel, parsedReq.MaxTokens, reqStream, isClaudeCodeClient)
if interceptType != InterceptTypeNone {
if selection.Acquired && selection.ReleaseFunc != nil {
selection.ReleaseFunc()
@@ -417,11 +448,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if err == nil && canWait {
accountWaitCounted = true
}
defer func() {
releaseWait := func() {
if accountWaitCounted {
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
accountWaitCounted = false
}
}()
}
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
c,
@@ -433,13 +465,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
)
if err != nil {
log.Printf("Account concurrency acquire failed: %v", err)
releaseWait()
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}
if accountWaitCounted {
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
accountWaitCounted = false
}
// Slot acquired: no longer waiting in queue.
releaseWait()
if err := h.gatewayService.BindStickySession(c.Request.Context(), currentAPIKey.GroupID, sessionKey, account.ID); err != nil {
log.Printf("Bind sticky session failed: %v", err)
}
@@ -454,7 +485,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount)
}
if account.Platform == service.PlatformAntigravity {
result, err = h.antigravityGatewayService.Forward(requestCtx, c, account, body)
result, err = h.antigravityGatewayService.Forward(requestCtx, c, account, body, hasBoundSession)
} else {
result, err = h.gatewayService.Forward(requestCtx, c, account, parsedReq)
}
@@ -501,6 +532,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if errors.As(err, &failoverErr) {
failedAccountIDs[account.ID] = struct{}{}
lastFailoverErr = failoverErr
if failoverErr.ForceCacheBilling {
forceCacheBilling = true
}
if switchCount >= maxAccountSwitches {
h.handleFailoverExhausted(c, failoverErr, account.Platform, streamStarted)
return
@@ -519,22 +553,23 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
clientIP := ip.GetClientIP(c)
// 异步记录使用量subscription已在函数开头获取
go func(result *service.ForwardResult, usedAccount *service.Account, ua, clientIP string) {
go func(result *service.ForwardResult, usedAccount *service.Account, ua, clientIP string, fcb bool) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
APIKey: currentAPIKey,
User: currentAPIKey.User,
Account: usedAccount,
Subscription: currentSubscription,
UserAgent: ua,
IPAddress: clientIP,
APIKeyService: h.apiKeyService,
Result: result,
APIKey: currentAPIKey,
User: currentAPIKey.User,
Account: usedAccount,
Subscription: currentSubscription,
UserAgent: ua,
IPAddress: clientIP,
ForceCacheBilling: fcb,
APIKeyService: h.apiKeyService,
}); err != nil {
log.Printf("Record usage failed: %v", err)
}
}(result, account, userAgent, clientIP)
}(result, account, userAgent, clientIP, forceCacheBilling)
return
}
if !retryWithFallback {
@@ -917,6 +952,8 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
return
}
// 在请求上下文中记录 thinking 状态,供 Antigravity 最终模型 key 推导/模型维度限流使用
c.Request = c.Request.WithContext(context.WithValue(c.Request.Context(), ctxkey.ThinkingEnabled, parsedReq.ThinkingEnabled))
// 验证 model 必填
if parsedReq.Model == "" {
@@ -943,7 +980,8 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
// 选择支持该模型的账号
account, err := h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, parsedReq.Model)
if err != nil {
h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error())
log.Printf("[Gateway] SelectAccountForModel failed: %v", err)
h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable")
return
}
setOpsSelectedAccount(c, account.ID)
@@ -960,13 +998,37 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
type InterceptType int
const (
InterceptTypeNone InterceptType = iota
InterceptTypeWarmup // 预热请求(返回 "New Conversation"
InterceptTypeSuggestionMode // SUGGESTION MODE返回空字符串
InterceptTypeNone InterceptType = iota
InterceptTypeWarmup // 预热请求(返回 "New Conversation"
InterceptTypeSuggestionMode // SUGGESTION MODE返回空字符串
InterceptTypeMaxTokensOneHaiku // max_tokens=1 + haiku 探测请求(返回 "#"
)
// isHaikuModel 检查模型名称是否包含 "haiku"(大小写不敏感)
func isHaikuModel(model string) bool {
return strings.Contains(strings.ToLower(model), "haiku")
}
// isMaxTokensOneHaikuRequest 检查是否为 max_tokens=1 + haiku 模型的探测请求
// 这类请求用于 Claude Code 验证 API 连通性
// 条件max_tokens == 1 且 model 包含 "haiku" 且非流式请求
func isMaxTokensOneHaikuRequest(model string, maxTokens int, isStream bool) bool {
return maxTokens == 1 && isHaikuModel(model) && !isStream
}
// detectInterceptType 检测请求是否需要拦截,返回拦截类型
func detectInterceptType(body []byte) InterceptType {
// 参数说明:
// - body: 请求体字节
// - model: 请求的模型名称
// - maxTokens: max_tokens 值
// - isStream: 是否为流式请求
// - isClaudeCodeClient: 是否已通过 Claude Code 客户端校验
func detectInterceptType(body []byte, model string, maxTokens int, isStream bool, isClaudeCodeClient bool) InterceptType {
// 优先检查 max_tokens=1 + haiku 探测请求(仅非流式)
if isClaudeCodeClient && isMaxTokensOneHaikuRequest(model, maxTokens, isStream) {
return InterceptTypeMaxTokensOneHaiku
}
// 快速检查:如果不包含任何关键字,直接返回
bodyStr := string(body)
hasSuggestionMode := strings.Contains(bodyStr, "[SUGGESTION MODE:")
@@ -1116,9 +1178,25 @@ func sendMockInterceptStream(c *gin.Context, model string, interceptType Interce
}
}
// generateRealisticMsgID 生成仿真的消息 IDmsg_bdrk_XXXXXXX 格式)
// 格式与 Claude API 真实响应一致24 位随机字母数字
func generateRealisticMsgID() string {
const charset = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
const idLen = 24
randomBytes := make([]byte, idLen)
if _, err := rand.Read(randomBytes); err != nil {
return fmt.Sprintf("msg_bdrk_%d", time.Now().UnixNano())
}
b := make([]byte, idLen)
for i := range b {
b[i] = charset[int(randomBytes[i])%len(charset)]
}
return "msg_bdrk_" + string(b)
}
// sendMockInterceptResponse 发送非流式 mock 响应(用于请求拦截)
func sendMockInterceptResponse(c *gin.Context, model string, interceptType InterceptType) {
var msgID, text string
var msgID, text, stopReason string
var outputTokens int
switch interceptType {
@@ -1126,24 +1204,42 @@ func sendMockInterceptResponse(c *gin.Context, model string, interceptType Inter
msgID = "msg_mock_suggestion"
text = ""
outputTokens = 1
stopReason = "end_turn"
case InterceptTypeMaxTokensOneHaiku:
msgID = generateRealisticMsgID()
text = "#"
outputTokens = 1
stopReason = "max_tokens" // max_tokens=1 探测请求的 stop_reason 应为 max_tokens
default: // InterceptTypeWarmup
msgID = "msg_mock_warmup"
text = "New Conversation"
outputTokens = 2
stopReason = "end_turn"
}
c.JSON(http.StatusOK, gin.H{
"id": msgID,
"type": "message",
"role": "assistant",
"model": model,
"content": []gin.H{{"type": "text", "text": text}},
"stop_reason": "end_turn",
// 构建完整的响应格式(与 Claude API 响应格式一致)
response := gin.H{
"model": model,
"id": msgID,
"type": "message",
"role": "assistant",
"content": []gin.H{{"type": "text", "text": text}},
"stop_reason": stopReason,
"stop_sequence": nil,
"usage": gin.H{
"input_tokens": 10,
"input_tokens": 10,
"cache_creation_input_tokens": 0,
"cache_read_input_tokens": 0,
"cache_creation": gin.H{
"ephemeral_5m_input_tokens": 0,
"ephemeral_1h_input_tokens": 0,
},
"output_tokens": outputTokens,
"total_tokens": 10 + outputTokens,
},
})
}
c.JSON(http.StatusOK, response)
}
func billingErrorDetails(err error) (status int, code, message string) {
@@ -1156,7 +1252,8 @@ func billingErrorDetails(err error) (status int, code, message string) {
}
msg := pkgerrors.Message(err)
if msg == "" {
msg = err.Error()
log.Printf("[Gateway] billing error details: %v", err)
msg = "Billing error"
}
return http.StatusForbidden, "billing_error", msg
}

View File

@@ -0,0 +1,65 @@
package handler
import (
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestDetectInterceptType_MaxTokensOneHaikuRequiresClaudeCodeClient(t *testing.T) {
body := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`)
notClaudeCode := detectInterceptType(body, "claude-haiku-4-5", 1, false, false)
require.Equal(t, InterceptTypeNone, notClaudeCode)
isClaudeCode := detectInterceptType(body, "claude-haiku-4-5", 1, false, true)
require.Equal(t, InterceptTypeMaxTokensOneHaiku, isClaudeCode)
}
func TestDetectInterceptType_SuggestionModeUnaffected(t *testing.T) {
body := []byte(`{
"messages":[{
"role":"user",
"content":[{"type":"text","text":"[SUGGESTION MODE:foo]"}]
}],
"system":[]
}`)
got := detectInterceptType(body, "claude-sonnet-4-5", 256, false, false)
require.Equal(t, InterceptTypeSuggestionMode, got)
}
func TestSendMockInterceptResponse_MaxTokensOneHaiku(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(rec)
sendMockInterceptResponse(ctx, "claude-haiku-4-5", InterceptTypeMaxTokensOneHaiku)
require.Equal(t, http.StatusOK, rec.Code)
var response map[string]any
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &response))
require.Equal(t, "max_tokens", response["stop_reason"])
id, ok := response["id"].(string)
require.True(t, ok)
require.True(t, strings.HasPrefix(id, "msg_bdrk_"))
content, ok := response["content"].([]any)
require.True(t, ok)
require.NotEmpty(t, content)
firstBlock, ok := content[0].(map[string]any)
require.True(t, ok)
require.Equal(t, "#", firstBlock["text"])
usage, ok := response["usage"].(map[string]any)
require.True(t, ok)
require.Equal(t, float64(1), usage["output_tokens"])
}

View File

@@ -120,3 +120,24 @@ func TestGeminiCLITmpDirRegex(t *testing.T) {
})
}
}
func TestSafeShortPrefix(t *testing.T) {
tests := []struct {
name string
input string
n int
want string
}{
{name: "空字符串", input: "", n: 8, want: ""},
{name: "长度小于截断值", input: "abc", n: 8, want: "abc"},
{name: "长度等于截断值", input: "12345678", n: 8, want: "12345678"},
{name: "长度大于截断值", input: "1234567890", n: 8, want: "12345678"},
{name: "截断值为0", input: "123456", n: 0, want: "123456"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
require.Equal(t, tt.want, safeShortPrefix(tt.input, tt.n))
})
}
}

View File

@@ -5,6 +5,7 @@ import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"io"
"log"
@@ -20,6 +21,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/google/uuid"
"github.com/gin-gonic/gin"
)
@@ -207,6 +209,9 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
// 1) user concurrency slot
streamStarted := false
if h.errorPassthroughService != nil {
service.BindErrorPassthroughService(c, h.errorPassthroughService)
}
userReleaseFunc, err := geminiConcurrency.AcquireUserSlotWithWait(c, authSubject.UserID, authSubject.Concurrency, stream, &streamStarted)
if err != nil {
googleError(c, http.StatusTooManyRequests, err.Error())
@@ -247,6 +252,70 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
if sessionKey != "" {
sessionBoundAccountID, _ = h.gatewayService.GetCachedSessionAccountID(c.Request.Context(), apiKey.GroupID, sessionKey)
}
// === Gemini 内容摘要会话 Fallback 逻辑 ===
// 当原有会话标识无效时sessionBoundAccountID == 0尝试基于内容摘要链匹配
var geminiDigestChain string
var geminiPrefixHash string
var geminiSessionUUID string
useDigestFallback := sessionBoundAccountID == 0
if useDigestFallback {
// 解析 Gemini 请求体
var geminiReq antigravity.GeminiRequest
if err := json.Unmarshal(body, &geminiReq); err == nil && len(geminiReq.Contents) > 0 {
// 生成摘要链
geminiDigestChain = service.BuildGeminiDigestChain(&geminiReq)
if geminiDigestChain != "" {
// 生成前缀 hash
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
platform := ""
if apiKey.Group != nil {
platform = apiKey.Group.Platform
}
geminiPrefixHash = service.GenerateGeminiPrefixHash(
authSubject.UserID,
apiKey.ID,
clientIP,
userAgent,
platform,
modelName,
)
// 查找会话
foundUUID, foundAccountID, found := h.gatewayService.FindGeminiSession(
c.Request.Context(),
derefGroupID(apiKey.GroupID),
geminiPrefixHash,
geminiDigestChain,
)
if found {
sessionBoundAccountID = foundAccountID
geminiSessionUUID = foundUUID
log.Printf("[Gemini] Digest fallback matched: uuid=%s, accountID=%d, chain=%s",
safeShortPrefix(foundUUID, 8), foundAccountID, truncateDigestChain(geminiDigestChain))
// 关键:如果原 sessionKey 为空,使用 prefixHash + uuid 作为 sessionKey
// 这样 SelectAccountWithLoadAwareness 的粘性会话逻辑会优先使用匹配到的账号
if sessionKey == "" {
sessionKey = service.GenerateGeminiDigestSessionKey(geminiPrefixHash, foundUUID)
}
_ = h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionKey, foundAccountID)
} else {
// 生成新的会话 UUID
geminiSessionUUID = uuid.New().String()
// 为新会话也生成 sessionKey用于后续请求的粘性会话
if sessionKey == "" {
sessionKey = service.GenerateGeminiDigestSessionKey(geminiPrefixHash, geminiSessionUUID)
}
}
}
}
}
// 判断是否真的绑定了粘性会话:有 sessionKey 且已经绑定到某个账号
hasBoundSession := sessionKey != "" && sessionBoundAccountID > 0
isCLI := isGeminiCLIRequest(c, body)
cleanedForUnknownBinding := false
@@ -254,6 +323,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
switchCount := 0
failedAccountIDs := make(map[int64]struct{})
var lastFailoverErr *service.UpstreamFailoverError
var forceCacheBilling bool // 粘性会话切换时的缓存计费标记
for {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, failedAccountIDs, "") // Gemini 不使用会话限制
@@ -341,7 +411,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount)
}
if account.Platform == service.PlatformAntigravity {
result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, modelName, action, stream, body)
result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, modelName, action, stream, body, hasBoundSession)
} else {
result, err = h.geminiCompatService.ForwardNative(requestCtx, c, account, modelName, action, stream, body)
}
@@ -352,6 +422,9 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
failedAccountIDs[account.ID] = struct{}{}
if failoverErr.ForceCacheBilling {
forceCacheBilling = true
}
if switchCount >= maxAccountSwitches {
lastFailoverErr = failoverErr
h.handleGeminiFailoverExhausted(c, lastFailoverErr)
@@ -371,8 +444,22 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
// 保存 Gemini 内容摘要会话(用于 Fallback 匹配)
if useDigestFallback && geminiDigestChain != "" && geminiPrefixHash != "" {
if err := h.gatewayService.SaveGeminiSession(
c.Request.Context(),
derefGroupID(apiKey.GroupID),
geminiPrefixHash,
geminiDigestChain,
geminiSessionUUID,
account.ID,
); err != nil {
log.Printf("[Gemini] Failed to save digest session: %v", err)
}
}
// 6) record usage async (Gemini 使用长上下文双倍计费)
go func(result *service.ForwardResult, usedAccount *service.Account, ua, ip string) {
go func(result *service.ForwardResult, usedAccount *service.Account, ua, ip string, fcb bool) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
@@ -386,11 +473,12 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
IPAddress: ip,
LongContextThreshold: 200000, // Gemini 200K 阈值
LongContextMultiplier: 2.0, // 超出部分双倍计费
ForceCacheBilling: fcb,
APIKeyService: h.apiKeyService,
}); err != nil {
log.Printf("Record usage failed: %v", err)
}
}(result, account, userAgent, clientIP)
}(result, account, userAgent, clientIP, forceCacheBilling)
return
}
}
@@ -553,3 +641,28 @@ func extractGeminiCLISessionHash(c *gin.Context, body []byte) string {
// 如果没有 privileged-user-id直接使用 tmp 目录哈希
return tmpDirHash
}
// truncateDigestChain 截断摘要链用于日志显示
func truncateDigestChain(chain string) string {
if len(chain) <= 50 {
return chain
}
return chain[:50] + "..."
}
// safeShortPrefix 返回字符串前 n 个字符;长度不足时返回原字符串。
// 用于日志展示,避免切片越界。
func safeShortPrefix(value string, n int) string {
if n <= 0 || len(value) <= n {
return value
}
return value[:n]
}
// derefGroupID 安全解引用 *int64nil 返回 0
func derefGroupID(groupID *int64) int64 {
if groupID == nil {
return 0
}
return *groupID
}

View File

@@ -28,6 +28,7 @@ type OpenAIGatewayHandler struct {
errorPassthroughService *service.ErrorPassthroughService
concurrencyHelper *ConcurrencyHelper
maxAccountSwitches int
cfg *config.Config
}
// NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler
@@ -54,6 +55,7 @@ func NewOpenAIGatewayHandler(
errorPassthroughService: errorPassthroughService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
maxAccountSwitches: maxAccountSwitches,
cfg: cfg,
}
}
@@ -109,7 +111,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
}
userAgent := c.GetHeader("User-Agent")
if !openai.IsCodexCLIRequest(userAgent) {
isCodexCLI := openai.IsCodexCLIRequest(userAgent) || (h.cfg != nil && h.cfg.Gateway.ForceCodexCLI)
if !isCodexCLI {
existingInstructions, _ := reqBody["instructions"].(string)
if strings.TrimSpace(existingInstructions) == "" {
if instructions := strings.TrimSpace(service.GetOpenCodeInstructions()); instructions != "" {
@@ -149,6 +152,11 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
// Track if we've started streaming (for error handling)
streamStarted := false
// 绑定错误透传服务,允许 service 层在非 failover 错误场景复用规则。
if h.errorPassthroughService != nil {
service.BindErrorPassthroughService(c, h.errorPassthroughService)
}
// Get subscription info (may be nil)
subscription, _ := middleware2.GetSubscriptionFromContext(c)
@@ -213,7 +221,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
if err != nil {
log.Printf("[OpenAI Handler] SelectAccount failed: %v", err)
if len(failedAccountIDs) == 0 {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
log.Printf("[OpenAI Gateway] SelectAccount failed: %v", err)
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
return
}
if lastFailoverErr != nil {
@@ -246,11 +255,12 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
if err == nil && canWait {
accountWaitCounted = true
}
defer func() {
releaseWait := func() {
if accountWaitCounted {
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
accountWaitCounted = false
}
}()
}
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
c,
@@ -262,13 +272,12 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
)
if err != nil {
log.Printf("Account concurrency acquire failed: %v", err)
releaseWait()
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}
if accountWaitCounted {
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
accountWaitCounted = false
}
// Slot acquired: no longer waiting in queue.
releaseWait()
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionHash, account.ID); err != nil {
log.Printf("Bind sticky session failed: %v", err)
}

View File

@@ -392,7 +392,7 @@ func (h *UsageHandler) DashboardAPIKeysUsage(c *gin.Context) {
return
}
stats, err := h.usageService.GetBatchAPIKeyUsageStats(c.Request.Context(), validAPIKeyIDs)
stats, err := h.usageService.GetBatchAPIKeyUsageStats(c.Request.Context(), validAPIKeyIDs, time.Time{}, time.Time{})
if err != nil {
response.ErrorFrom(c, err)
return