Merge remote-tracking branch 'upstream/main'
This commit is contained in:
@@ -52,6 +52,15 @@ type Config struct {
|
||||
RunMode string `mapstructure:"run_mode" yaml:"run_mode"`
|
||||
Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC"
|
||||
Gemini GeminiConfig `mapstructure:"gemini"`
|
||||
Update UpdateConfig `mapstructure:"update"`
|
||||
}
|
||||
|
||||
// UpdateConfig 在线更新相关配置
|
||||
type UpdateConfig struct {
|
||||
// ProxyURL 用于访问 GitHub 的代理地址
|
||||
// 支持 http/https/socks5/socks5h 协议
|
||||
// 例如: "http://127.0.0.1:7890", "socks5://127.0.0.1:1080"
|
||||
ProxyURL string `mapstructure:"proxy_url"`
|
||||
}
|
||||
|
||||
type GeminiConfig struct {
|
||||
@@ -148,7 +157,7 @@ type CSPConfig struct {
|
||||
}
|
||||
|
||||
type ProxyProbeConfig struct {
|
||||
InsecureSkipVerify bool `mapstructure:"insecure_skip_verify"`
|
||||
InsecureSkipVerify bool `mapstructure:"insecure_skip_verify"` // 已禁用:禁止跳过 TLS 证书验证
|
||||
}
|
||||
|
||||
type BillingConfig struct {
|
||||
@@ -448,8 +457,8 @@ func setDefaults() {
|
||||
"raw.githubusercontent.com",
|
||||
})
|
||||
viper.SetDefault("security.url_allowlist.crs_hosts", []string{})
|
||||
viper.SetDefault("security.url_allowlist.allow_private_hosts", false)
|
||||
viper.SetDefault("security.url_allowlist.allow_insecure_http", false)
|
||||
viper.SetDefault("security.url_allowlist.allow_private_hosts", true)
|
||||
viper.SetDefault("security.url_allowlist.allow_insecure_http", true)
|
||||
viper.SetDefault("security.response_headers.enabled", false)
|
||||
viper.SetDefault("security.response_headers.additional_allowed", []string{})
|
||||
viper.SetDefault("security.response_headers.force_remove", []string{})
|
||||
@@ -558,6 +567,10 @@ func setDefaults() {
|
||||
viper.SetDefault("gemini.oauth.client_secret", "")
|
||||
viper.SetDefault("gemini.oauth.scopes", "")
|
||||
viper.SetDefault("gemini.quota.policy", "")
|
||||
|
||||
// Update - 在线更新配置
|
||||
// 代理地址为空表示直连 GitHub(适用于海外服务器)
|
||||
viper.SetDefault("update.proxy_url", "")
|
||||
}
|
||||
|
||||
func (c *Config) Validate() error {
|
||||
|
||||
@@ -80,8 +80,11 @@ func TestLoadDefaultSecurityToggles(t *testing.T) {
|
||||
if cfg.Security.URLAllowlist.Enabled {
|
||||
t.Fatalf("URLAllowlist.Enabled = true, want false")
|
||||
}
|
||||
if cfg.Security.URLAllowlist.AllowInsecureHTTP {
|
||||
t.Fatalf("URLAllowlist.AllowInsecureHTTP = true, want false")
|
||||
if !cfg.Security.URLAllowlist.AllowInsecureHTTP {
|
||||
t.Fatalf("URLAllowlist.AllowInsecureHTTP = false, want true")
|
||||
}
|
||||
if !cfg.Security.URLAllowlist.AllowPrivateHosts {
|
||||
t.Fatalf("URLAllowlist.AllowPrivateHosts = false, want true")
|
||||
}
|
||||
if cfg.Security.ResponseHeaders.Enabled {
|
||||
t.Fatalf("ResponseHeaders.Enabled = true, want false")
|
||||
|
||||
@@ -85,6 +85,8 @@ type CreateAccountRequest struct {
|
||||
Concurrency int `json:"concurrency"`
|
||||
Priority int `json:"priority"`
|
||||
GroupIDs []int64 `json:"group_ids"`
|
||||
ExpiresAt *int64 `json:"expires_at"`
|
||||
AutoPauseOnExpired *bool `json:"auto_pause_on_expired"`
|
||||
ConfirmMixedChannelRisk *bool `json:"confirm_mixed_channel_risk"` // 用户确认混合渠道风险
|
||||
}
|
||||
|
||||
@@ -101,6 +103,8 @@ type UpdateAccountRequest struct {
|
||||
Priority *int `json:"priority"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
|
||||
GroupIDs *[]int64 `json:"group_ids"`
|
||||
ExpiresAt *int64 `json:"expires_at"`
|
||||
AutoPauseOnExpired *bool `json:"auto_pause_on_expired"`
|
||||
ConfirmMixedChannelRisk *bool `json:"confirm_mixed_channel_risk"` // 用户确认混合渠道风险
|
||||
}
|
||||
|
||||
@@ -204,6 +208,8 @@ func (h *AccountHandler) Create(c *gin.Context) {
|
||||
Concurrency: req.Concurrency,
|
||||
Priority: req.Priority,
|
||||
GroupIDs: req.GroupIDs,
|
||||
ExpiresAt: req.ExpiresAt,
|
||||
AutoPauseOnExpired: req.AutoPauseOnExpired,
|
||||
SkipMixedChannelCheck: skipCheck,
|
||||
})
|
||||
if err != nil {
|
||||
@@ -261,6 +267,8 @@ func (h *AccountHandler) Update(c *gin.Context) {
|
||||
Priority: req.Priority, // 指针类型,nil 表示未提供
|
||||
Status: req.Status,
|
||||
GroupIDs: req.GroupIDs,
|
||||
ExpiresAt: req.ExpiresAt,
|
||||
AutoPauseOnExpired: req.AutoPauseOnExpired,
|
||||
SkipMixedChannelCheck: skipCheck,
|
||||
})
|
||||
if err != nil {
|
||||
|
||||
@@ -26,31 +26,33 @@ func NewDashboardHandler(dashboardService *service.DashboardService) *DashboardH
|
||||
}
|
||||
|
||||
// parseTimeRange parses start_date, end_date query parameters
|
||||
// Uses user's timezone if provided, otherwise falls back to server timezone
|
||||
func parseTimeRange(c *gin.Context) (time.Time, time.Time) {
|
||||
now := timezone.Now()
|
||||
userTZ := c.Query("timezone") // Get user's timezone from request
|
||||
now := timezone.NowInUserLocation(userTZ)
|
||||
startDate := c.Query("start_date")
|
||||
endDate := c.Query("end_date")
|
||||
|
||||
var startTime, endTime time.Time
|
||||
|
||||
if startDate != "" {
|
||||
if t, err := timezone.ParseInLocation("2006-01-02", startDate); err == nil {
|
||||
if t, err := timezone.ParseInUserLocation("2006-01-02", startDate, userTZ); err == nil {
|
||||
startTime = t
|
||||
} else {
|
||||
startTime = timezone.StartOfDay(now.AddDate(0, 0, -7))
|
||||
startTime = timezone.StartOfDayInUserLocation(now.AddDate(0, 0, -7), userTZ)
|
||||
}
|
||||
} else {
|
||||
startTime = timezone.StartOfDay(now.AddDate(0, 0, -7))
|
||||
startTime = timezone.StartOfDayInUserLocation(now.AddDate(0, 0, -7), userTZ)
|
||||
}
|
||||
|
||||
if endDate != "" {
|
||||
if t, err := timezone.ParseInLocation("2006-01-02", endDate); err == nil {
|
||||
if t, err := timezone.ParseInUserLocation("2006-01-02", endDate, userTZ); err == nil {
|
||||
endTime = t.Add(24 * time.Hour) // Include the end date
|
||||
} else {
|
||||
endTime = timezone.StartOfDay(now.AddDate(0, 0, 1))
|
||||
endTime = timezone.StartOfDayInUserLocation(now.AddDate(0, 0, 1), userTZ)
|
||||
}
|
||||
} else {
|
||||
endTime = timezone.StartOfDay(now.AddDate(0, 0, 1))
|
||||
endTime = timezone.StartOfDayInUserLocation(now.AddDate(0, 0, 1), userTZ)
|
||||
}
|
||||
|
||||
return startTime, endTime
|
||||
|
||||
@@ -102,8 +102,9 @@ func (h *UsageHandler) List(c *gin.Context) {
|
||||
|
||||
// Parse date range
|
||||
var startTime, endTime *time.Time
|
||||
userTZ := c.Query("timezone") // Get user's timezone from request
|
||||
if startDateStr := c.Query("start_date"); startDateStr != "" {
|
||||
t, err := timezone.ParseInLocation("2006-01-02", startDateStr)
|
||||
t, err := timezone.ParseInUserLocation("2006-01-02", startDateStr, userTZ)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid start_date format, use YYYY-MM-DD")
|
||||
return
|
||||
@@ -112,7 +113,7 @@ func (h *UsageHandler) List(c *gin.Context) {
|
||||
}
|
||||
|
||||
if endDateStr := c.Query("end_date"); endDateStr != "" {
|
||||
t, err := timezone.ParseInLocation("2006-01-02", endDateStr)
|
||||
t, err := timezone.ParseInUserLocation("2006-01-02", endDateStr, userTZ)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
|
||||
return
|
||||
@@ -151,8 +152,8 @@ func (h *UsageHandler) List(c *gin.Context) {
|
||||
// Stats handles getting usage statistics with filters
|
||||
// GET /api/v1/admin/usage/stats
|
||||
func (h *UsageHandler) Stats(c *gin.Context) {
|
||||
// Parse filters
|
||||
var userID, apiKeyID int64
|
||||
// Parse filters - same as List endpoint
|
||||
var userID, apiKeyID, accountID, groupID int64
|
||||
if userIDStr := c.Query("user_id"); userIDStr != "" {
|
||||
id, err := strconv.ParseInt(userIDStr, 10, 64)
|
||||
if err != nil {
|
||||
@@ -171,8 +172,50 @@ func (h *UsageHandler) Stats(c *gin.Context) {
|
||||
apiKeyID = id
|
||||
}
|
||||
|
||||
if accountIDStr := c.Query("account_id"); accountIDStr != "" {
|
||||
id, err := strconv.ParseInt(accountIDStr, 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid account_id")
|
||||
return
|
||||
}
|
||||
accountID = id
|
||||
}
|
||||
|
||||
if groupIDStr := c.Query("group_id"); groupIDStr != "" {
|
||||
id, err := strconv.ParseInt(groupIDStr, 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid group_id")
|
||||
return
|
||||
}
|
||||
groupID = id
|
||||
}
|
||||
|
||||
model := c.Query("model")
|
||||
|
||||
var stream *bool
|
||||
if streamStr := c.Query("stream"); streamStr != "" {
|
||||
val, err := strconv.ParseBool(streamStr)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid stream value, use true or false")
|
||||
return
|
||||
}
|
||||
stream = &val
|
||||
}
|
||||
|
||||
var billingType *int8
|
||||
if billingTypeStr := c.Query("billing_type"); billingTypeStr != "" {
|
||||
val, err := strconv.ParseInt(billingTypeStr, 10, 8)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid billing_type")
|
||||
return
|
||||
}
|
||||
bt := int8(val)
|
||||
billingType = &bt
|
||||
}
|
||||
|
||||
// Parse date range
|
||||
now := timezone.Now()
|
||||
userTZ := c.Query("timezone")
|
||||
now := timezone.NowInUserLocation(userTZ)
|
||||
var startTime, endTime time.Time
|
||||
|
||||
startDateStr := c.Query("start_date")
|
||||
@@ -180,12 +223,12 @@ func (h *UsageHandler) Stats(c *gin.Context) {
|
||||
|
||||
if startDateStr != "" && endDateStr != "" {
|
||||
var err error
|
||||
startTime, err = timezone.ParseInLocation("2006-01-02", startDateStr)
|
||||
startTime, err = timezone.ParseInUserLocation("2006-01-02", startDateStr, userTZ)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid start_date format, use YYYY-MM-DD")
|
||||
return
|
||||
}
|
||||
endTime, err = timezone.ParseInLocation("2006-01-02", endDateStr)
|
||||
endTime, err = timezone.ParseInUserLocation("2006-01-02", endDateStr, userTZ)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
|
||||
return
|
||||
@@ -195,39 +238,31 @@ func (h *UsageHandler) Stats(c *gin.Context) {
|
||||
period := c.DefaultQuery("period", "today")
|
||||
switch period {
|
||||
case "today":
|
||||
startTime = timezone.StartOfDay(now)
|
||||
startTime = timezone.StartOfDayInUserLocation(now, userTZ)
|
||||
case "week":
|
||||
startTime = now.AddDate(0, 0, -7)
|
||||
case "month":
|
||||
startTime = now.AddDate(0, -1, 0)
|
||||
default:
|
||||
startTime = timezone.StartOfDay(now)
|
||||
startTime = timezone.StartOfDayInUserLocation(now, userTZ)
|
||||
}
|
||||
endTime = now
|
||||
}
|
||||
|
||||
if apiKeyID > 0 {
|
||||
stats, err := h.usageService.GetStatsByAPIKey(c.Request.Context(), apiKeyID, startTime, endTime)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, stats)
|
||||
return
|
||||
// Build filters and call GetStatsWithFilters
|
||||
filters := usagestats.UsageLogFilters{
|
||||
UserID: userID,
|
||||
APIKeyID: apiKeyID,
|
||||
AccountID: accountID,
|
||||
GroupID: groupID,
|
||||
Model: model,
|
||||
Stream: stream,
|
||||
BillingType: billingType,
|
||||
StartTime: &startTime,
|
||||
EndTime: &endTime,
|
||||
}
|
||||
|
||||
if userID > 0 {
|
||||
stats, err := h.usageService.GetStatsByUser(c.Request.Context(), userID, startTime, endTime)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, stats)
|
||||
return
|
||||
}
|
||||
|
||||
// Get global stats
|
||||
stats, err := h.usageService.GetGlobalStats(c.Request.Context(), startTime, endTime)
|
||||
stats, err := h.usageService.GetStatsWithFilters(c.Request.Context(), filters)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
|
||||
@@ -1,7 +1,11 @@
|
||||
// Package dto provides data transfer objects for HTTP handlers.
|
||||
package dto
|
||||
|
||||
import "github.com/Wei-Shaw/sub2api/internal/service"
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
func UserFromServiceShallow(u *service.User) *User {
|
||||
if u == nil {
|
||||
@@ -120,6 +124,8 @@ func AccountFromServiceShallow(a *service.Account) *Account {
|
||||
Status: a.Status,
|
||||
ErrorMessage: a.ErrorMessage,
|
||||
LastUsedAt: a.LastUsedAt,
|
||||
ExpiresAt: timeToUnixSeconds(a.ExpiresAt),
|
||||
AutoPauseOnExpired: a.AutoPauseOnExpired,
|
||||
CreatedAt: a.CreatedAt,
|
||||
UpdatedAt: a.UpdatedAt,
|
||||
Schedulable: a.Schedulable,
|
||||
@@ -157,6 +163,14 @@ func AccountFromService(a *service.Account) *Account {
|
||||
return out
|
||||
}
|
||||
|
||||
func timeToUnixSeconds(value *time.Time) *int64 {
|
||||
if value == nil {
|
||||
return nil
|
||||
}
|
||||
ts := value.Unix()
|
||||
return &ts
|
||||
}
|
||||
|
||||
func AccountGroupFromService(ag *service.AccountGroup) *AccountGroup {
|
||||
if ag == nil {
|
||||
return nil
|
||||
|
||||
@@ -60,21 +60,23 @@ type Group struct {
|
||||
}
|
||||
|
||||
type Account struct {
|
||||
ID int64 `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Notes *string `json:"notes"`
|
||||
Platform string `json:"platform"`
|
||||
Type string `json:"type"`
|
||||
Credentials map[string]any `json:"credentials"`
|
||||
Extra map[string]any `json:"extra"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
Concurrency int `json:"concurrency"`
|
||||
Priority int `json:"priority"`
|
||||
Status string `json:"status"`
|
||||
ErrorMessage string `json:"error_message"`
|
||||
LastUsedAt *time.Time `json:"last_used_at"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
ID int64 `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Notes *string `json:"notes"`
|
||||
Platform string `json:"platform"`
|
||||
Type string `json:"type"`
|
||||
Credentials map[string]any `json:"credentials"`
|
||||
Extra map[string]any `json:"extra"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
Concurrency int `json:"concurrency"`
|
||||
Priority int `json:"priority"`
|
||||
Status string `json:"status"`
|
||||
ErrorMessage string `json:"error_message"`
|
||||
LastUsedAt *time.Time `json:"last_used_at"`
|
||||
ExpiresAt *int64 `json:"expires_at"`
|
||||
AutoPauseOnExpired bool `json:"auto_pause_on_expired"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
|
||||
Schedulable bool `json:"schedulable"`
|
||||
|
||||
|
||||
@@ -108,6 +108,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
// 获取订阅信息(可能为nil)- 提前获取用于后续检查
|
||||
subscription, _ := middleware2.GetSubscriptionFromContext(c)
|
||||
|
||||
// 获取 User-Agent
|
||||
userAgent := c.Request.UserAgent()
|
||||
|
||||
// 0. 检查wait队列是否已满
|
||||
maxWait := service.CalculateMaxWait(subject.Concurrency)
|
||||
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
|
||||
@@ -267,7 +270,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 异步记录使用量(subscription已在函数开头获取)
|
||||
go func(result *service.ForwardResult, usedAccount *service.Account) {
|
||||
go func(result *service.ForwardResult, usedAccount *service.Account, ua string) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
@@ -276,10 +279,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
User: apiKey.User,
|
||||
Account: usedAccount,
|
||||
Subscription: subscription,
|
||||
UserAgent: ua,
|
||||
}); err != nil {
|
||||
log.Printf("Record usage failed: %v", err)
|
||||
}
|
||||
}(result, account)
|
||||
}(result, account, userAgent)
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -394,7 +398,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 异步记录使用量(subscription已在函数开头获取)
|
||||
go func(result *service.ForwardResult, usedAccount *service.Account) {
|
||||
go func(result *service.ForwardResult, usedAccount *service.Account, ua string) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
@@ -403,10 +407,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
User: apiKey.User,
|
||||
Account: usedAccount,
|
||||
Subscription: subscription,
|
||||
UserAgent: ua,
|
||||
}); err != nil {
|
||||
log.Printf("Record usage failed: %v", err)
|
||||
}
|
||||
}(result, account)
|
||||
}(result, account, userAgent)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -83,19 +83,33 @@ func NewConcurrencyHelper(concurrencyService *service.ConcurrencyService, pingFo
|
||||
|
||||
// wrapReleaseOnDone ensures release runs at most once and still triggers on context cancellation.
|
||||
// 用于避免客户端断开或上游超时导致的并发槽位泄漏。
|
||||
// 修复:添加 quit channel 确保 goroutine 及时退出,避免泄露
|
||||
func wrapReleaseOnDone(ctx context.Context, releaseFunc func()) func() {
|
||||
if releaseFunc == nil {
|
||||
return nil
|
||||
}
|
||||
var once sync.Once
|
||||
wrapped := func() {
|
||||
once.Do(releaseFunc)
|
||||
quit := make(chan struct{})
|
||||
|
||||
release := func() {
|
||||
once.Do(func() {
|
||||
releaseFunc()
|
||||
close(quit) // 通知监听 goroutine 退出
|
||||
})
|
||||
}
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
wrapped()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// Context 取消时释放资源
|
||||
release()
|
||||
case <-quit:
|
||||
// 正常释放已完成,goroutine 退出
|
||||
return
|
||||
}
|
||||
}()
|
||||
return wrapped
|
||||
|
||||
return release
|
||||
}
|
||||
|
||||
// IncrementWaitCount increments the wait count for a user
|
||||
|
||||
141
backend/internal/handler/gateway_helper_test.go
Normal file
141
backend/internal/handler/gateway_helper_test.go
Normal file
@@ -0,0 +1,141 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"runtime"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestWrapReleaseOnDone_NoGoroutineLeak 验证 wrapReleaseOnDone 修复后不会泄露 goroutine
|
||||
func TestWrapReleaseOnDone_NoGoroutineLeak(t *testing.T) {
|
||||
// 记录测试开始时的 goroutine 数量
|
||||
runtime.GC()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
initialGoroutines := runtime.NumGoroutine()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
var releaseCount int32
|
||||
release := wrapReleaseOnDone(ctx, func() {
|
||||
atomic.AddInt32(&releaseCount, 1)
|
||||
})
|
||||
|
||||
// 正常释放
|
||||
release()
|
||||
|
||||
// 等待足够时间确保 goroutine 退出
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
// 验证只释放一次
|
||||
if count := atomic.LoadInt32(&releaseCount); count != 1 {
|
||||
t.Errorf("expected release count to be 1, got %d", count)
|
||||
}
|
||||
|
||||
// 强制 GC,清理已退出的 goroutine
|
||||
runtime.GC()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// 验证 goroutine 数量没有增加(允许±2的误差,考虑到测试框架本身可能创建的 goroutine)
|
||||
finalGoroutines := runtime.NumGoroutine()
|
||||
if finalGoroutines > initialGoroutines+2 {
|
||||
t.Errorf("goroutine leak detected: initial=%d, final=%d, leaked=%d",
|
||||
initialGoroutines, finalGoroutines, finalGoroutines-initialGoroutines)
|
||||
}
|
||||
}
|
||||
|
||||
// TestWrapReleaseOnDone_ContextCancellation 验证 context 取消时也能正确释放
|
||||
func TestWrapReleaseOnDone_ContextCancellation(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
var releaseCount int32
|
||||
_ = wrapReleaseOnDone(ctx, func() {
|
||||
atomic.AddInt32(&releaseCount, 1)
|
||||
})
|
||||
|
||||
// 取消 context,应该触发释放
|
||||
cancel()
|
||||
|
||||
// 等待释放完成
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// 验证释放被调用
|
||||
if count := atomic.LoadInt32(&releaseCount); count != 1 {
|
||||
t.Errorf("expected release count to be 1, got %d", count)
|
||||
}
|
||||
}
|
||||
|
||||
// TestWrapReleaseOnDone_MultipleCallsOnlyReleaseOnce 验证多次调用 release 只释放一次
|
||||
func TestWrapReleaseOnDone_MultipleCallsOnlyReleaseOnce(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
var releaseCount int32
|
||||
release := wrapReleaseOnDone(ctx, func() {
|
||||
atomic.AddInt32(&releaseCount, 1)
|
||||
})
|
||||
|
||||
// 调用多次
|
||||
release()
|
||||
release()
|
||||
release()
|
||||
|
||||
// 等待执行完成
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// 验证只释放一次
|
||||
if count := atomic.LoadInt32(&releaseCount); count != 1 {
|
||||
t.Errorf("expected release count to be 1, got %d", count)
|
||||
}
|
||||
}
|
||||
|
||||
// TestWrapReleaseOnDone_NilReleaseFunc 验证 nil releaseFunc 不会 panic
|
||||
func TestWrapReleaseOnDone_NilReleaseFunc(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
release := wrapReleaseOnDone(ctx, nil)
|
||||
|
||||
if release != nil {
|
||||
t.Error("expected nil release function when releaseFunc is nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWrapReleaseOnDone_ConcurrentCalls 验证并发调用的安全性
|
||||
func TestWrapReleaseOnDone_ConcurrentCalls(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
var releaseCount int32
|
||||
release := wrapReleaseOnDone(ctx, func() {
|
||||
atomic.AddInt32(&releaseCount, 1)
|
||||
})
|
||||
|
||||
// 并发调用 release
|
||||
const numGoroutines = 10
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
go release()
|
||||
}
|
||||
|
||||
// 等待所有 goroutine 完成
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
// 验证只释放一次
|
||||
if count := atomic.LoadInt32(&releaseCount); count != 1 {
|
||||
t.Errorf("expected release count to be 1, got %d", count)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkWrapReleaseOnDone 性能基准测试
|
||||
func BenchmarkWrapReleaseOnDone(b *testing.B) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
release := wrapReleaseOnDone(ctx, func() {})
|
||||
release()
|
||||
}
|
||||
}
|
||||
@@ -164,6 +164,9 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
// Get subscription (may be nil)
|
||||
subscription, _ := middleware.GetSubscriptionFromContext(c)
|
||||
|
||||
// 获取 User-Agent
|
||||
userAgent := c.Request.UserAgent()
|
||||
|
||||
// For Gemini native API, do not send Claude-style ping frames.
|
||||
geminiConcurrency := NewConcurrencyHelper(h.concurrencyHelper.concurrencyService, SSEPingFormatNone, 0)
|
||||
|
||||
@@ -300,7 +303,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 6) record usage async
|
||||
go func(result *service.ForwardResult, usedAccount *service.Account) {
|
||||
go func(result *service.ForwardResult, usedAccount *service.Account, ua string) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
@@ -309,10 +312,11 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
User: apiKey.User,
|
||||
Account: usedAccount,
|
||||
Subscription: subscription,
|
||||
UserAgent: ua,
|
||||
}); err != nil {
|
||||
log.Printf("Record usage failed: %v", err)
|
||||
}
|
||||
}(result, account)
|
||||
}(result, account, userAgent)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -242,7 +242,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
}
|
||||
|
||||
// Async record usage
|
||||
go func(result *service.OpenAIForwardResult, usedAccount *service.Account) {
|
||||
go func(result *service.OpenAIForwardResult, usedAccount *service.Account, ua string) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||
@@ -251,10 +251,11 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
User: apiKey.User,
|
||||
Account: usedAccount,
|
||||
Subscription: subscription,
|
||||
UserAgent: ua,
|
||||
}); err != nil {
|
||||
log.Printf("Record usage failed: %v", err)
|
||||
}
|
||||
}(result, account)
|
||||
}(result, account, userAgent)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -88,8 +88,9 @@ func (h *UsageHandler) List(c *gin.Context) {
|
||||
|
||||
// Parse date range
|
||||
var startTime, endTime *time.Time
|
||||
userTZ := c.Query("timezone") // Get user's timezone from request
|
||||
if startDateStr := c.Query("start_date"); startDateStr != "" {
|
||||
t, err := timezone.ParseInLocation("2006-01-02", startDateStr)
|
||||
t, err := timezone.ParseInUserLocation("2006-01-02", startDateStr, userTZ)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid start_date format, use YYYY-MM-DD")
|
||||
return
|
||||
@@ -98,7 +99,7 @@ func (h *UsageHandler) List(c *gin.Context) {
|
||||
}
|
||||
|
||||
if endDateStr := c.Query("end_date"); endDateStr != "" {
|
||||
t, err := timezone.ParseInLocation("2006-01-02", endDateStr)
|
||||
t, err := timezone.ParseInUserLocation("2006-01-02", endDateStr, userTZ)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
|
||||
return
|
||||
@@ -194,7 +195,8 @@ func (h *UsageHandler) Stats(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 获取时间范围参数
|
||||
now := timezone.Now()
|
||||
userTZ := c.Query("timezone") // Get user's timezone from request
|
||||
now := timezone.NowInUserLocation(userTZ)
|
||||
var startTime, endTime time.Time
|
||||
|
||||
// 优先使用 start_date 和 end_date 参数
|
||||
@@ -204,12 +206,12 @@ func (h *UsageHandler) Stats(c *gin.Context) {
|
||||
if startDateStr != "" && endDateStr != "" {
|
||||
// 使用自定义日期范围
|
||||
var err error
|
||||
startTime, err = timezone.ParseInLocation("2006-01-02", startDateStr)
|
||||
startTime, err = timezone.ParseInUserLocation("2006-01-02", startDateStr, userTZ)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid start_date format, use YYYY-MM-DD")
|
||||
return
|
||||
}
|
||||
endTime, err = timezone.ParseInLocation("2006-01-02", endDateStr)
|
||||
endTime, err = timezone.ParseInUserLocation("2006-01-02", endDateStr, userTZ)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
|
||||
return
|
||||
@@ -221,13 +223,13 @@ func (h *UsageHandler) Stats(c *gin.Context) {
|
||||
period := c.DefaultQuery("period", "today")
|
||||
switch period {
|
||||
case "today":
|
||||
startTime = timezone.StartOfDay(now)
|
||||
startTime = timezone.StartOfDayInUserLocation(now, userTZ)
|
||||
case "week":
|
||||
startTime = now.AddDate(0, 0, -7)
|
||||
case "month":
|
||||
startTime = now.AddDate(0, -1, 0)
|
||||
default:
|
||||
startTime = timezone.StartOfDay(now)
|
||||
startTime = timezone.StartOfDayInUserLocation(now, userTZ)
|
||||
}
|
||||
endTime = now
|
||||
}
|
||||
@@ -248,31 +250,33 @@ func (h *UsageHandler) Stats(c *gin.Context) {
|
||||
}
|
||||
|
||||
// parseUserTimeRange parses start_date, end_date query parameters for user dashboard
|
||||
// Uses user's timezone if provided, otherwise falls back to server timezone
|
||||
func parseUserTimeRange(c *gin.Context) (time.Time, time.Time) {
|
||||
now := timezone.Now()
|
||||
userTZ := c.Query("timezone") // Get user's timezone from request
|
||||
now := timezone.NowInUserLocation(userTZ)
|
||||
startDate := c.Query("start_date")
|
||||
endDate := c.Query("end_date")
|
||||
|
||||
var startTime, endTime time.Time
|
||||
|
||||
if startDate != "" {
|
||||
if t, err := timezone.ParseInLocation("2006-01-02", startDate); err == nil {
|
||||
if t, err := timezone.ParseInUserLocation("2006-01-02", startDate, userTZ); err == nil {
|
||||
startTime = t
|
||||
} else {
|
||||
startTime = timezone.StartOfDay(now.AddDate(0, 0, -7))
|
||||
startTime = timezone.StartOfDayInUserLocation(now.AddDate(0, 0, -7), userTZ)
|
||||
}
|
||||
} else {
|
||||
startTime = timezone.StartOfDay(now.AddDate(0, 0, -7))
|
||||
startTime = timezone.StartOfDayInUserLocation(now.AddDate(0, 0, -7), userTZ)
|
||||
}
|
||||
|
||||
if endDate != "" {
|
||||
if t, err := timezone.ParseInLocation("2006-01-02", endDate); err == nil {
|
||||
if t, err := timezone.ParseInUserLocation("2006-01-02", endDate, userTZ); err == nil {
|
||||
endTime = t.Add(24 * time.Hour) // Include the end date
|
||||
} else {
|
||||
endTime = timezone.StartOfDay(now.AddDate(0, 0, 1))
|
||||
endTime = timezone.StartOfDayInUserLocation(now.AddDate(0, 0, 1), userTZ)
|
||||
}
|
||||
} else {
|
||||
endTime = timezone.StartOfDay(now.AddDate(0, 0, 1))
|
||||
endTime = timezone.StartOfDayInUserLocation(now.AddDate(0, 0, 1), userTZ)
|
||||
}
|
||||
|
||||
return startTime, endTime
|
||||
|
||||
@@ -16,7 +16,6 @@
|
||||
package httpclient
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
@@ -40,7 +39,7 @@ type Options struct {
|
||||
ProxyURL string // 代理 URL(支持 http/https/socks5/socks5h)
|
||||
Timeout time.Duration // 请求总超时时间
|
||||
ResponseHeaderTimeout time.Duration // 等待响应头超时时间
|
||||
InsecureSkipVerify bool // 是否跳过 TLS 证书验证
|
||||
InsecureSkipVerify bool // 是否跳过 TLS 证书验证(已禁用,不允许设置为 true)
|
||||
ProxyStrict bool // 严格代理模式:代理失败时返回错误而非回退
|
||||
ValidateResolvedIP bool // 是否校验解析后的 IP(防止 DNS Rebinding)
|
||||
AllowPrivateHosts bool // 允许私有地址解析(与 ValidateResolvedIP 一起使用)
|
||||
@@ -113,7 +112,8 @@ func buildTransport(opts Options) (*http.Transport, error) {
|
||||
}
|
||||
|
||||
if opts.InsecureSkipVerify {
|
||||
transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
|
||||
// 安全要求:禁止跳过证书验证,避免中间人攻击。
|
||||
return nil, fmt.Errorf("insecure_skip_verify is not allowed; install a trusted certificate instead")
|
||||
}
|
||||
|
||||
proxyURL := strings.TrimSpace(opts.ProxyURL)
|
||||
|
||||
@@ -122,3 +122,40 @@ func StartOfMonth(t time.Time) time.Time {
|
||||
func ParseInLocation(layout, value string) (time.Time, error) {
|
||||
return time.ParseInLocation(layout, value, Location())
|
||||
}
|
||||
|
||||
// ParseInUserLocation parses a time string in the user's timezone.
|
||||
// If userTZ is empty or invalid, falls back to the configured server timezone.
|
||||
func ParseInUserLocation(layout, value, userTZ string) (time.Time, error) {
|
||||
loc := Location() // default to server timezone
|
||||
if userTZ != "" {
|
||||
if userLoc, err := time.LoadLocation(userTZ); err == nil {
|
||||
loc = userLoc
|
||||
}
|
||||
}
|
||||
return time.ParseInLocation(layout, value, loc)
|
||||
}
|
||||
|
||||
// NowInUserLocation returns the current time in the user's timezone.
|
||||
// If userTZ is empty or invalid, falls back to the configured server timezone.
|
||||
func NowInUserLocation(userTZ string) time.Time {
|
||||
if userTZ == "" {
|
||||
return Now()
|
||||
}
|
||||
if userLoc, err := time.LoadLocation(userTZ); err == nil {
|
||||
return time.Now().In(userLoc)
|
||||
}
|
||||
return Now()
|
||||
}
|
||||
|
||||
// StartOfDayInUserLocation returns the start of the given day in the user's timezone.
|
||||
// If userTZ is empty or invalid, falls back to the configured server timezone.
|
||||
func StartOfDayInUserLocation(t time.Time, userTZ string) time.Time {
|
||||
loc := Location()
|
||||
if userTZ != "" {
|
||||
if userLoc, err := time.LoadLocation(userTZ); err == nil {
|
||||
loc = userLoc
|
||||
}
|
||||
}
|
||||
t = t.In(loc)
|
||||
return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, loc)
|
||||
}
|
||||
|
||||
@@ -76,7 +76,8 @@ func (r *accountRepository) Create(ctx context.Context, account *service.Account
|
||||
SetPriority(account.Priority).
|
||||
SetStatus(account.Status).
|
||||
SetErrorMessage(account.ErrorMessage).
|
||||
SetSchedulable(account.Schedulable)
|
||||
SetSchedulable(account.Schedulable).
|
||||
SetAutoPauseOnExpired(account.AutoPauseOnExpired)
|
||||
|
||||
if account.ProxyID != nil {
|
||||
builder.SetProxyID(*account.ProxyID)
|
||||
@@ -84,6 +85,9 @@ func (r *accountRepository) Create(ctx context.Context, account *service.Account
|
||||
if account.LastUsedAt != nil {
|
||||
builder.SetLastUsedAt(*account.LastUsedAt)
|
||||
}
|
||||
if account.ExpiresAt != nil {
|
||||
builder.SetExpiresAt(*account.ExpiresAt)
|
||||
}
|
||||
if account.RateLimitedAt != nil {
|
||||
builder.SetRateLimitedAt(*account.RateLimitedAt)
|
||||
}
|
||||
@@ -280,7 +284,8 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
|
||||
SetPriority(account.Priority).
|
||||
SetStatus(account.Status).
|
||||
SetErrorMessage(account.ErrorMessage).
|
||||
SetSchedulable(account.Schedulable)
|
||||
SetSchedulable(account.Schedulable).
|
||||
SetAutoPauseOnExpired(account.AutoPauseOnExpired)
|
||||
|
||||
if account.ProxyID != nil {
|
||||
builder.SetProxyID(*account.ProxyID)
|
||||
@@ -292,6 +297,11 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
|
||||
} else {
|
||||
builder.ClearLastUsedAt()
|
||||
}
|
||||
if account.ExpiresAt != nil {
|
||||
builder.SetExpiresAt(*account.ExpiresAt)
|
||||
} else {
|
||||
builder.ClearExpiresAt()
|
||||
}
|
||||
if account.RateLimitedAt != nil {
|
||||
builder.SetRateLimitedAt(*account.RateLimitedAt)
|
||||
} else {
|
||||
@@ -570,6 +580,7 @@ func (r *accountRepository) ListSchedulable(ctx context.Context) ([]service.Acco
|
||||
dbaccount.StatusEQ(service.StatusActive),
|
||||
dbaccount.SchedulableEQ(true),
|
||||
tempUnschedulablePredicate(),
|
||||
notExpiredPredicate(now),
|
||||
dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)),
|
||||
dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)),
|
||||
).
|
||||
@@ -596,6 +607,7 @@ func (r *accountRepository) ListSchedulableByPlatform(ctx context.Context, platf
|
||||
dbaccount.StatusEQ(service.StatusActive),
|
||||
dbaccount.SchedulableEQ(true),
|
||||
tempUnschedulablePredicate(),
|
||||
notExpiredPredicate(now),
|
||||
dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)),
|
||||
dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)),
|
||||
).
|
||||
@@ -629,6 +641,7 @@ func (r *accountRepository) ListSchedulableByPlatforms(ctx context.Context, plat
|
||||
dbaccount.StatusEQ(service.StatusActive),
|
||||
dbaccount.SchedulableEQ(true),
|
||||
tempUnschedulablePredicate(),
|
||||
notExpiredPredicate(now),
|
||||
dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)),
|
||||
dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)),
|
||||
).
|
||||
@@ -727,6 +740,27 @@ func (r *accountRepository) SetSchedulable(ctx context.Context, id int64, schedu
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *accountRepository) AutoPauseExpiredAccounts(ctx context.Context, now time.Time) (int64, error) {
|
||||
result, err := r.sql.ExecContext(ctx, `
|
||||
UPDATE accounts
|
||||
SET schedulable = FALSE,
|
||||
updated_at = NOW()
|
||||
WHERE deleted_at IS NULL
|
||||
AND schedulable = TRUE
|
||||
AND auto_pause_on_expired = TRUE
|
||||
AND expires_at IS NOT NULL
|
||||
AND expires_at <= $1
|
||||
`, now)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
rows, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return rows, nil
|
||||
}
|
||||
|
||||
func (r *accountRepository) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error {
|
||||
if len(updates) == 0 {
|
||||
return nil
|
||||
@@ -861,6 +895,7 @@ func (r *accountRepository) queryAccountsByGroup(ctx context.Context, groupID in
|
||||
preds = append(preds,
|
||||
dbaccount.SchedulableEQ(true),
|
||||
tempUnschedulablePredicate(),
|
||||
notExpiredPredicate(now),
|
||||
dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)),
|
||||
dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)),
|
||||
)
|
||||
@@ -971,6 +1006,14 @@ func tempUnschedulablePredicate() dbpredicate.Account {
|
||||
})
|
||||
}
|
||||
|
||||
func notExpiredPredicate(now time.Time) dbpredicate.Account {
|
||||
return dbaccount.Or(
|
||||
dbaccount.ExpiresAtIsNil(),
|
||||
dbaccount.ExpiresAtGT(now),
|
||||
dbaccount.AutoPauseOnExpiredEQ(false),
|
||||
)
|
||||
}
|
||||
|
||||
func (r *accountRepository) loadTempUnschedStates(ctx context.Context, accountIDs []int64) (map[int64]tempUnschedSnapshot, error) {
|
||||
out := make(map[int64]tempUnschedSnapshot)
|
||||
if len(accountIDs) == 0 {
|
||||
@@ -1086,6 +1129,8 @@ func accountEntityToService(m *dbent.Account) *service.Account {
|
||||
Status: m.Status,
|
||||
ErrorMessage: derefString(m.ErrorMessage),
|
||||
LastUsedAt: m.LastUsedAt,
|
||||
ExpiresAt: m.ExpiresAt,
|
||||
AutoPauseOnExpired: m.AutoPauseOnExpired,
|
||||
CreatedAt: m.CreatedAt,
|
||||
UpdatedAt: m.UpdatedAt,
|
||||
Schedulable: m.Schedulable,
|
||||
|
||||
@@ -14,23 +14,33 @@ import (
|
||||
)
|
||||
|
||||
type githubReleaseClient struct {
|
||||
httpClient *http.Client
|
||||
allowPrivateHosts bool
|
||||
httpClient *http.Client
|
||||
downloadHTTPClient *http.Client
|
||||
}
|
||||
|
||||
func NewGitHubReleaseClient() service.GitHubReleaseClient {
|
||||
allowPrivate := false
|
||||
// NewGitHubReleaseClient 创建 GitHub Release 客户端
|
||||
// proxyURL 为空时直连 GitHub,支持 http/https/socks5/socks5h 协议
|
||||
func NewGitHubReleaseClient(proxyURL string) service.GitHubReleaseClient {
|
||||
sharedClient, err := httpclient.GetClient(httpclient.Options{
|
||||
Timeout: 30 * time.Second,
|
||||
ValidateResolvedIP: true,
|
||||
AllowPrivateHosts: allowPrivate,
|
||||
Timeout: 30 * time.Second,
|
||||
ProxyURL: proxyURL,
|
||||
})
|
||||
if err != nil {
|
||||
sharedClient = &http.Client{Timeout: 30 * time.Second}
|
||||
}
|
||||
|
||||
// 下载客户端需要更长的超时时间
|
||||
downloadClient, err := httpclient.GetClient(httpclient.Options{
|
||||
Timeout: 10 * time.Minute,
|
||||
ProxyURL: proxyURL,
|
||||
})
|
||||
if err != nil {
|
||||
downloadClient = &http.Client{Timeout: 10 * time.Minute}
|
||||
}
|
||||
|
||||
return &githubReleaseClient{
|
||||
httpClient: sharedClient,
|
||||
allowPrivateHosts: allowPrivate,
|
||||
httpClient: sharedClient,
|
||||
downloadHTTPClient: downloadClient,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -68,15 +78,8 @@ func (c *githubReleaseClient) DownloadFile(ctx context.Context, url, dest string
|
||||
return err
|
||||
}
|
||||
|
||||
downloadClient, err := httpclient.GetClient(httpclient.Options{
|
||||
Timeout: 10 * time.Minute,
|
||||
ValidateResolvedIP: true,
|
||||
AllowPrivateHosts: c.allowPrivateHosts,
|
||||
})
|
||||
if err != nil {
|
||||
downloadClient = &http.Client{Timeout: 10 * time.Minute}
|
||||
}
|
||||
resp, err := downloadClient.Do(req)
|
||||
// 使用预配置的下载客户端(已包含代理配置)
|
||||
resp, err := c.downloadHTTPClient.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -39,8 +39,8 @@ func (t *testTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
|
||||
func newTestGitHubReleaseClient() *githubReleaseClient {
|
||||
return &githubReleaseClient{
|
||||
httpClient: &http.Client{},
|
||||
allowPrivateHosts: true,
|
||||
httpClient: &http.Client{},
|
||||
downloadHTTPClient: &http.Client{},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -234,7 +234,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Success() {
|
||||
httpClient: &http.Client{
|
||||
Transport: &testTransport{testServerURL: s.srv.URL},
|
||||
},
|
||||
allowPrivateHosts: true,
|
||||
downloadHTTPClient: &http.Client{},
|
||||
}
|
||||
|
||||
release, err := s.client.FetchLatestRelease(context.Background(), "test/repo")
|
||||
@@ -254,7 +254,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Non200() {
|
||||
httpClient: &http.Client{
|
||||
Transport: &testTransport{testServerURL: s.srv.URL},
|
||||
},
|
||||
allowPrivateHosts: true,
|
||||
downloadHTTPClient: &http.Client{},
|
||||
}
|
||||
|
||||
_, err := s.client.FetchLatestRelease(context.Background(), "test/repo")
|
||||
@@ -272,7 +272,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_InvalidJSON() {
|
||||
httpClient: &http.Client{
|
||||
Transport: &testTransport{testServerURL: s.srv.URL},
|
||||
},
|
||||
allowPrivateHosts: true,
|
||||
downloadHTTPClient: &http.Client{},
|
||||
}
|
||||
|
||||
_, err := s.client.FetchLatestRelease(context.Background(), "test/repo")
|
||||
@@ -288,7 +288,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_ContextCancel() {
|
||||
httpClient: &http.Client{
|
||||
Transport: &testTransport{testServerURL: s.srv.URL},
|
||||
},
|
||||
allowPrivateHosts: true,
|
||||
downloadHTTPClient: &http.Client{},
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
@@ -17,17 +16,12 @@ type pricingRemoteClient struct {
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
func NewPricingRemoteClient(cfg *config.Config) service.PricingRemoteClient {
|
||||
allowPrivate := false
|
||||
validateResolvedIP := true
|
||||
if cfg != nil {
|
||||
allowPrivate = cfg.Security.URLAllowlist.AllowPrivateHosts
|
||||
validateResolvedIP = cfg.Security.URLAllowlist.Enabled
|
||||
}
|
||||
// NewPricingRemoteClient 创建定价数据远程客户端
|
||||
// proxyURL 为空时直连,支持 http/https/socks5/socks5h 协议
|
||||
func NewPricingRemoteClient(proxyURL string) service.PricingRemoteClient {
|
||||
sharedClient, err := httpclient.GetClient(httpclient.Options{
|
||||
Timeout: 30 * time.Second,
|
||||
ValidateResolvedIP: validateResolvedIP,
|
||||
AllowPrivateHosts: allowPrivate,
|
||||
Timeout: 30 * time.Second,
|
||||
ProxyURL: proxyURL,
|
||||
})
|
||||
if err != nil {
|
||||
sharedClient = &http.Client{Timeout: 30 * time.Second}
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
@@ -20,13 +19,7 @@ type PricingServiceSuite struct {
|
||||
|
||||
func (s *PricingServiceSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
client, ok := NewPricingRemoteClient(&config.Config{
|
||||
Security: config.SecurityConfig{
|
||||
URLAllowlist: config.URLAllowlistConfig{
|
||||
AllowPrivateHosts: true,
|
||||
},
|
||||
},
|
||||
}).(*pricingRemoteClient)
|
||||
client, ok := NewPricingRemoteClient("").(*pricingRemoteClient)
|
||||
require.True(s.T(), ok, "type assertion failed")
|
||||
s.client = client
|
||||
}
|
||||
|
||||
@@ -24,7 +24,7 @@ func NewProxyExitInfoProber(cfg *config.Config) service.ProxyExitInfoProber {
|
||||
validateResolvedIP = cfg.Security.URLAllowlist.Enabled
|
||||
}
|
||||
if insecure {
|
||||
log.Printf("[ProxyProbe] Warning: TLS verification is disabled for proxy probing.")
|
||||
log.Printf("[ProxyProbe] Warning: insecure_skip_verify is not allowed and will cause probe failure.")
|
||||
}
|
||||
return &proxyProbeService{
|
||||
ipInfoURL: defaultIPInfoURL,
|
||||
|
||||
@@ -22,7 +22,7 @@ import (
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
||||
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, billing_type, stream, duration_ms, first_token_ms, image_count, image_size, created_at"
|
||||
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, image_count, image_size, created_at"
|
||||
|
||||
type usageLogRepository struct {
|
||||
client *dbent.Client
|
||||
@@ -109,6 +109,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
||||
stream,
|
||||
duration_ms,
|
||||
first_token_ms,
|
||||
user_agent,
|
||||
image_count,
|
||||
image_size,
|
||||
created_at
|
||||
@@ -118,8 +119,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
||||
$8, $9, $10, $11,
|
||||
$12, $13,
|
||||
$14, $15, $16, $17, $18, $19,
|
||||
$20, $21, $22, $23, $24,
|
||||
$25, $26, $27
|
||||
$20, $21, $22, $23, $24, $25, $26, $27, $28
|
||||
)
|
||||
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||||
RETURNING id, created_at
|
||||
@@ -129,6 +129,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
||||
subscriptionID := nullInt64(log.SubscriptionID)
|
||||
duration := nullInt(log.DurationMs)
|
||||
firstToken := nullInt(log.FirstTokenMs)
|
||||
userAgent := nullString(log.UserAgent)
|
||||
imageSize := nullString(log.ImageSize)
|
||||
|
||||
var requestIDArg any
|
||||
@@ -161,6 +162,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
||||
log.Stream,
|
||||
duration,
|
||||
firstToken,
|
||||
userAgent,
|
||||
log.ImageCount,
|
||||
imageSize,
|
||||
createdAt,
|
||||
@@ -1388,6 +1390,81 @@ func (r *usageLogRepository) GetGlobalStats(ctx context.Context, startTime, endT
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
// GetStatsWithFilters gets usage statistics with optional filters
|
||||
func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters UsageLogFilters) (*UsageStats, error) {
|
||||
conditions := make([]string, 0, 9)
|
||||
args := make([]any, 0, 9)
|
||||
|
||||
if filters.UserID > 0 {
|
||||
conditions = append(conditions, fmt.Sprintf("user_id = $%d", len(args)+1))
|
||||
args = append(args, filters.UserID)
|
||||
}
|
||||
if filters.APIKeyID > 0 {
|
||||
conditions = append(conditions, fmt.Sprintf("api_key_id = $%d", len(args)+1))
|
||||
args = append(args, filters.APIKeyID)
|
||||
}
|
||||
if filters.AccountID > 0 {
|
||||
conditions = append(conditions, fmt.Sprintf("account_id = $%d", len(args)+1))
|
||||
args = append(args, filters.AccountID)
|
||||
}
|
||||
if filters.GroupID > 0 {
|
||||
conditions = append(conditions, fmt.Sprintf("group_id = $%d", len(args)+1))
|
||||
args = append(args, filters.GroupID)
|
||||
}
|
||||
if filters.Model != "" {
|
||||
conditions = append(conditions, fmt.Sprintf("model = $%d", len(args)+1))
|
||||
args = append(args, filters.Model)
|
||||
}
|
||||
if filters.Stream != nil {
|
||||
conditions = append(conditions, fmt.Sprintf("stream = $%d", len(args)+1))
|
||||
args = append(args, *filters.Stream)
|
||||
}
|
||||
if filters.BillingType != nil {
|
||||
conditions = append(conditions, fmt.Sprintf("billing_type = $%d", len(args)+1))
|
||||
args = append(args, int16(*filters.BillingType))
|
||||
}
|
||||
if filters.StartTime != nil {
|
||||
conditions = append(conditions, fmt.Sprintf("created_at >= $%d", len(args)+1))
|
||||
args = append(args, *filters.StartTime)
|
||||
}
|
||||
if filters.EndTime != nil {
|
||||
conditions = append(conditions, fmt.Sprintf("created_at <= $%d", len(args)+1))
|
||||
args = append(args, *filters.EndTime)
|
||||
}
|
||||
|
||||
query := fmt.Sprintf(`
|
||||
SELECT
|
||||
COUNT(*) as total_requests,
|
||||
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
|
||||
COALESCE(SUM(output_tokens), 0) as total_output_tokens,
|
||||
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens,
|
||||
COALESCE(SUM(total_cost), 0) as total_cost,
|
||||
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
|
||||
COALESCE(AVG(duration_ms), 0) as avg_duration_ms
|
||||
FROM usage_logs
|
||||
%s
|
||||
`, buildWhere(conditions))
|
||||
|
||||
stats := &UsageStats{}
|
||||
if err := scanSingleRow(
|
||||
ctx,
|
||||
r.sql,
|
||||
query,
|
||||
args,
|
||||
&stats.TotalRequests,
|
||||
&stats.TotalInputTokens,
|
||||
&stats.TotalOutputTokens,
|
||||
&stats.TotalCacheTokens,
|
||||
&stats.TotalCost,
|
||||
&stats.TotalActualCost,
|
||||
&stats.AverageDurationMs,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
// AccountUsageHistory represents daily usage history for an account
|
||||
type AccountUsageHistory = usagestats.AccountUsageHistory
|
||||
|
||||
@@ -1795,6 +1872,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
||||
stream bool
|
||||
durationMs sql.NullInt64
|
||||
firstTokenMs sql.NullInt64
|
||||
userAgent sql.NullString
|
||||
imageCount int
|
||||
imageSize sql.NullString
|
||||
createdAt time.Time
|
||||
@@ -1826,6 +1904,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
||||
&stream,
|
||||
&durationMs,
|
||||
&firstTokenMs,
|
||||
&userAgent,
|
||||
&imageCount,
|
||||
&imageSize,
|
||||
&createdAt,
|
||||
@@ -1877,6 +1956,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
||||
value := int(firstTokenMs.Int64)
|
||||
log.FirstTokenMs = &value
|
||||
}
|
||||
if userAgent.Valid {
|
||||
log.UserAgent = &userAgent.String
|
||||
}
|
||||
if imageSize.Valid {
|
||||
log.ImageSize = &imageSize.String
|
||||
}
|
||||
|
||||
@@ -25,6 +25,18 @@ func ProvideConcurrencyCache(rdb *redis.Client, cfg *config.Config) service.Conc
|
||||
return NewConcurrencyCache(rdb, cfg.Gateway.ConcurrencySlotTTLMinutes, waitTTLSeconds)
|
||||
}
|
||||
|
||||
// ProvideGitHubReleaseClient 创建 GitHub Release 客户端
|
||||
// 从配置中读取代理设置,支持国内服务器通过代理访问 GitHub
|
||||
func ProvideGitHubReleaseClient(cfg *config.Config) service.GitHubReleaseClient {
|
||||
return NewGitHubReleaseClient(cfg.Update.ProxyURL)
|
||||
}
|
||||
|
||||
// ProvidePricingRemoteClient 创建定价数据远程客户端
|
||||
// 从配置中读取代理设置,支持国内服务器通过代理访问 GitHub 上的定价数据
|
||||
func ProvidePricingRemoteClient(cfg *config.Config) service.PricingRemoteClient {
|
||||
return NewPricingRemoteClient(cfg.Update.ProxyURL)
|
||||
}
|
||||
|
||||
// ProviderSet is the Wire provider set for all repositories
|
||||
var ProviderSet = wire.NewSet(
|
||||
NewUserRepository,
|
||||
@@ -53,8 +65,8 @@ var ProviderSet = wire.NewSet(
|
||||
|
||||
// HTTP service ports (DI Strategy A: return interface directly)
|
||||
NewTurnstileVerifier,
|
||||
NewPricingRemoteClient,
|
||||
NewGitHubReleaseClient,
|
||||
ProvidePricingRemoteClient,
|
||||
ProvideGitHubReleaseClient,
|
||||
NewProxyExitInfoProber,
|
||||
NewClaudeUsageFetcher,
|
||||
NewClaudeOAuthClient,
|
||||
|
||||
@@ -1065,6 +1065,10 @@ func (r *stubUsageLogRepo) GetAccountUsageStats(ctx context.Context, accountID i
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUsageLogRepo) GetStatsWithFilters(ctx context.Context, filters usagestats.UsageLogFilters) (*usagestats.UsageStats, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
type stubSettingRepo struct {
|
||||
all map[string]string
|
||||
}
|
||||
|
||||
@@ -9,21 +9,23 @@ import (
|
||||
)
|
||||
|
||||
type Account struct {
|
||||
ID int64
|
||||
Name string
|
||||
Notes *string
|
||||
Platform string
|
||||
Type string
|
||||
Credentials map[string]any
|
||||
Extra map[string]any
|
||||
ProxyID *int64
|
||||
Concurrency int
|
||||
Priority int
|
||||
Status string
|
||||
ErrorMessage string
|
||||
LastUsedAt *time.Time
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
ID int64
|
||||
Name string
|
||||
Notes *string
|
||||
Platform string
|
||||
Type string
|
||||
Credentials map[string]any
|
||||
Extra map[string]any
|
||||
ProxyID *int64
|
||||
Concurrency int
|
||||
Priority int
|
||||
Status string
|
||||
ErrorMessage string
|
||||
LastUsedAt *time.Time
|
||||
ExpiresAt *time.Time
|
||||
AutoPauseOnExpired bool
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
|
||||
Schedulable bool
|
||||
|
||||
@@ -60,6 +62,9 @@ func (a *Account) IsSchedulable() bool {
|
||||
return false
|
||||
}
|
||||
now := time.Now()
|
||||
if a.AutoPauseOnExpired && a.ExpiresAt != nil && !now.Before(*a.ExpiresAt) {
|
||||
return false
|
||||
}
|
||||
if a.OverloadUntil != nil && now.Before(*a.OverloadUntil) {
|
||||
return false
|
||||
}
|
||||
|
||||
71
backend/internal/service/account_expiry_service.go
Normal file
71
backend/internal/service/account_expiry_service.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// AccountExpiryService periodically pauses expired accounts when auto-pause is enabled.
|
||||
type AccountExpiryService struct {
|
||||
accountRepo AccountRepository
|
||||
interval time.Duration
|
||||
stopCh chan struct{}
|
||||
stopOnce sync.Once
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
func NewAccountExpiryService(accountRepo AccountRepository, interval time.Duration) *AccountExpiryService {
|
||||
return &AccountExpiryService{
|
||||
accountRepo: accountRepo,
|
||||
interval: interval,
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AccountExpiryService) Start() {
|
||||
if s == nil || s.accountRepo == nil || s.interval <= 0 {
|
||||
return
|
||||
}
|
||||
s.wg.Add(1)
|
||||
go func() {
|
||||
defer s.wg.Done()
|
||||
ticker := time.NewTicker(s.interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
s.runOnce()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
s.runOnce()
|
||||
case <-s.stopCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *AccountExpiryService) Stop() {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.stopOnce.Do(func() {
|
||||
close(s.stopCh)
|
||||
})
|
||||
s.wg.Wait()
|
||||
}
|
||||
|
||||
func (s *AccountExpiryService) runOnce() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
updated, err := s.accountRepo.AutoPauseExpiredAccounts(ctx, time.Now())
|
||||
if err != nil {
|
||||
log.Printf("[AccountExpiry] Auto pause expired accounts failed: %v", err)
|
||||
return
|
||||
}
|
||||
if updated > 0 {
|
||||
log.Printf("[AccountExpiry] Auto paused %d expired accounts", updated)
|
||||
}
|
||||
}
|
||||
@@ -38,6 +38,7 @@ type AccountRepository interface {
|
||||
BatchUpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error
|
||||
SetError(ctx context.Context, id int64, errorMsg string) error
|
||||
SetSchedulable(ctx context.Context, id int64, schedulable bool) error
|
||||
AutoPauseExpiredAccounts(ctx context.Context, now time.Time) (int64, error)
|
||||
BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error
|
||||
|
||||
ListSchedulable(ctx context.Context) ([]Account, error)
|
||||
@@ -71,29 +72,33 @@ type AccountBulkUpdate struct {
|
||||
|
||||
// CreateAccountRequest 创建账号请求
|
||||
type CreateAccountRequest struct {
|
||||
Name string `json:"name"`
|
||||
Notes *string `json:"notes"`
|
||||
Platform string `json:"platform"`
|
||||
Type string `json:"type"`
|
||||
Credentials map[string]any `json:"credentials"`
|
||||
Extra map[string]any `json:"extra"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
Concurrency int `json:"concurrency"`
|
||||
Priority int `json:"priority"`
|
||||
GroupIDs []int64 `json:"group_ids"`
|
||||
Name string `json:"name"`
|
||||
Notes *string `json:"notes"`
|
||||
Platform string `json:"platform"`
|
||||
Type string `json:"type"`
|
||||
Credentials map[string]any `json:"credentials"`
|
||||
Extra map[string]any `json:"extra"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
Concurrency int `json:"concurrency"`
|
||||
Priority int `json:"priority"`
|
||||
GroupIDs []int64 `json:"group_ids"`
|
||||
ExpiresAt *time.Time `json:"expires_at"`
|
||||
AutoPauseOnExpired *bool `json:"auto_pause_on_expired"`
|
||||
}
|
||||
|
||||
// UpdateAccountRequest 更新账号请求
|
||||
type UpdateAccountRequest struct {
|
||||
Name *string `json:"name"`
|
||||
Notes *string `json:"notes"`
|
||||
Credentials *map[string]any `json:"credentials"`
|
||||
Extra *map[string]any `json:"extra"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
Concurrency *int `json:"concurrency"`
|
||||
Priority *int `json:"priority"`
|
||||
Status *string `json:"status"`
|
||||
GroupIDs *[]int64 `json:"group_ids"`
|
||||
Name *string `json:"name"`
|
||||
Notes *string `json:"notes"`
|
||||
Credentials *map[string]any `json:"credentials"`
|
||||
Extra *map[string]any `json:"extra"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
Concurrency *int `json:"concurrency"`
|
||||
Priority *int `json:"priority"`
|
||||
Status *string `json:"status"`
|
||||
GroupIDs *[]int64 `json:"group_ids"`
|
||||
ExpiresAt *time.Time `json:"expires_at"`
|
||||
AutoPauseOnExpired *bool `json:"auto_pause_on_expired"`
|
||||
}
|
||||
|
||||
// AccountService 账号管理服务
|
||||
@@ -134,6 +139,12 @@ func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) (
|
||||
Concurrency: req.Concurrency,
|
||||
Priority: req.Priority,
|
||||
Status: StatusActive,
|
||||
ExpiresAt: req.ExpiresAt,
|
||||
}
|
||||
if req.AutoPauseOnExpired != nil {
|
||||
account.AutoPauseOnExpired = *req.AutoPauseOnExpired
|
||||
} else {
|
||||
account.AutoPauseOnExpired = true
|
||||
}
|
||||
|
||||
if err := s.accountRepo.Create(ctx, account); err != nil {
|
||||
@@ -224,6 +235,12 @@ func (s *AccountService) Update(ctx context.Context, id int64, req UpdateAccount
|
||||
if req.Status != nil {
|
||||
account.Status = *req.Status
|
||||
}
|
||||
if req.ExpiresAt != nil {
|
||||
account.ExpiresAt = req.ExpiresAt
|
||||
}
|
||||
if req.AutoPauseOnExpired != nil {
|
||||
account.AutoPauseOnExpired = *req.AutoPauseOnExpired
|
||||
}
|
||||
|
||||
// 先验证分组是否存在(在任何写操作之前)
|
||||
if req.GroupIDs != nil {
|
||||
|
||||
@@ -103,6 +103,10 @@ func (s *accountRepoStub) SetSchedulable(ctx context.Context, id int64, schedula
|
||||
panic("unexpected SetSchedulable call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) AutoPauseExpiredAccounts(ctx context.Context, now time.Time) (int64, error) {
|
||||
panic("unexpected AutoPauseExpiredAccounts call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error {
|
||||
panic("unexpected BindGroups call")
|
||||
}
|
||||
|
||||
@@ -47,6 +47,7 @@ type UsageLogRepository interface {
|
||||
// Admin usage listing/stats
|
||||
ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters usagestats.UsageLogFilters) ([]UsageLog, *pagination.PaginationResult, error)
|
||||
GetGlobalStats(ctx context.Context, startTime, endTime time.Time) (*usagestats.UsageStats, error)
|
||||
GetStatsWithFilters(ctx context.Context, filters usagestats.UsageLogFilters) (*usagestats.UsageStats, error)
|
||||
|
||||
// Account stats
|
||||
GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.AccountUsageStatsResponse, error)
|
||||
|
||||
@@ -122,16 +122,18 @@ type UpdateGroupInput struct {
|
||||
}
|
||||
|
||||
type CreateAccountInput struct {
|
||||
Name string
|
||||
Notes *string
|
||||
Platform string
|
||||
Type string
|
||||
Credentials map[string]any
|
||||
Extra map[string]any
|
||||
ProxyID *int64
|
||||
Concurrency int
|
||||
Priority int
|
||||
GroupIDs []int64
|
||||
Name string
|
||||
Notes *string
|
||||
Platform string
|
||||
Type string
|
||||
Credentials map[string]any
|
||||
Extra map[string]any
|
||||
ProxyID *int64
|
||||
Concurrency int
|
||||
Priority int
|
||||
GroupIDs []int64
|
||||
ExpiresAt *int64
|
||||
AutoPauseOnExpired *bool
|
||||
// SkipMixedChannelCheck skips the mixed channel risk check when binding groups.
|
||||
// This should only be set when the caller has explicitly confirmed the risk.
|
||||
SkipMixedChannelCheck bool
|
||||
@@ -148,6 +150,8 @@ type UpdateAccountInput struct {
|
||||
Priority *int // 使用指针区分"未提供"和"设置为0"
|
||||
Status string
|
||||
GroupIDs *[]int64
|
||||
ExpiresAt *int64
|
||||
AutoPauseOnExpired *bool
|
||||
SkipMixedChannelCheck bool // 跳过混合渠道检查(用户已确认风险)
|
||||
}
|
||||
|
||||
@@ -700,6 +704,15 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
}
|
||||
if input.ExpiresAt != nil && *input.ExpiresAt > 0 {
|
||||
expiresAt := time.Unix(*input.ExpiresAt, 0)
|
||||
account.ExpiresAt = &expiresAt
|
||||
}
|
||||
if input.AutoPauseOnExpired != nil {
|
||||
account.AutoPauseOnExpired = *input.AutoPauseOnExpired
|
||||
} else {
|
||||
account.AutoPauseOnExpired = true
|
||||
}
|
||||
if err := s.accountRepo.Create(ctx, account); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -755,6 +768,17 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
|
||||
if input.Status != "" {
|
||||
account.Status = input.Status
|
||||
}
|
||||
if input.ExpiresAt != nil {
|
||||
if *input.ExpiresAt <= 0 {
|
||||
account.ExpiresAt = nil
|
||||
} else {
|
||||
expiresAt := time.Unix(*input.ExpiresAt, 0)
|
||||
account.ExpiresAt = &expiresAt
|
||||
}
|
||||
}
|
||||
if input.AutoPauseOnExpired != nil {
|
||||
account.AutoPauseOnExpired = *input.AutoPauseOnExpired
|
||||
}
|
||||
|
||||
// 先验证分组是否存在(在任何写操作之前)
|
||||
if input.GroupIDs != nil {
|
||||
|
||||
@@ -20,12 +20,16 @@ var (
|
||||
ErrEmailExists = infraerrors.Conflict("EMAIL_EXISTS", "email already exists")
|
||||
ErrInvalidToken = infraerrors.Unauthorized("INVALID_TOKEN", "invalid token")
|
||||
ErrTokenExpired = infraerrors.Unauthorized("TOKEN_EXPIRED", "token has expired")
|
||||
ErrTokenTooLarge = infraerrors.BadRequest("TOKEN_TOO_LARGE", "token too large")
|
||||
ErrTokenRevoked = infraerrors.Unauthorized("TOKEN_REVOKED", "token has been revoked")
|
||||
ErrEmailVerifyRequired = infraerrors.BadRequest("EMAIL_VERIFY_REQUIRED", "email verification is required")
|
||||
ErrRegDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled")
|
||||
ErrServiceUnavailable = infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "service temporarily unavailable")
|
||||
)
|
||||
|
||||
// maxTokenLength 限制 token 大小,避免超长 header 触发解析时的异常内存分配。
|
||||
const maxTokenLength = 8192
|
||||
|
||||
// JWTClaims JWT载荷数据
|
||||
type JWTClaims struct {
|
||||
UserID int64 `json:"user_id"`
|
||||
@@ -309,7 +313,20 @@ func (s *AuthService) Login(ctx context.Context, email, password string) (string
|
||||
|
||||
// ValidateToken 验证JWT token并返回用户声明
|
||||
func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) {
|
||||
token, err := jwt.ParseWithClaims(tokenString, &JWTClaims{}, func(token *jwt.Token) (any, error) {
|
||||
// 先做长度校验,尽早拒绝异常超长 token,降低 DoS 风险。
|
||||
if len(tokenString) > maxTokenLength {
|
||||
return nil, ErrTokenTooLarge
|
||||
}
|
||||
|
||||
// 使用解析器并限制可接受的签名算法,防止算法混淆。
|
||||
parser := jwt.NewParser(jwt.WithValidMethods([]string{
|
||||
jwt.SigningMethodHS256.Name,
|
||||
jwt.SigningMethodHS384.Name,
|
||||
jwt.SigningMethodHS512.Name,
|
||||
}))
|
||||
|
||||
// 保留默认 claims 校验(exp/nbf),避免放行过期或未生效的 token。
|
||||
token, err := parser.ParseWithClaims(tokenString, &JWTClaims{}, func(token *jwt.Token) (any, error) {
|
||||
// 验证签名方法
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
|
||||
@@ -140,6 +140,8 @@ func (s *EmailService) SendEmailWithConfig(config *SMTPConfig, to, subject, body
|
||||
func (s *EmailService) sendMailTLS(addr string, auth smtp.Auth, from, to string, msg []byte, host string) error {
|
||||
tlsConfig := &tls.Config{
|
||||
ServerName: host,
|
||||
// 强制 TLS 1.2+,避免协议降级导致的弱加密风险。
|
||||
MinVersion: tls.VersionTLS12,
|
||||
}
|
||||
|
||||
conn, err := tls.Dial("tcp", addr, tlsConfig)
|
||||
@@ -311,7 +313,11 @@ func (s *EmailService) TestSMTPConnectionWithConfig(config *SMTPConfig) error {
|
||||
addr := fmt.Sprintf("%s:%d", config.Host, config.Port)
|
||||
|
||||
if config.UseTLS {
|
||||
tlsConfig := &tls.Config{ServerName: config.Host}
|
||||
tlsConfig := &tls.Config{
|
||||
ServerName: config.Host,
|
||||
// 与发送逻辑一致,显式要求 TLS 1.2+。
|
||||
MinVersion: tls.VersionTLS12,
|
||||
}
|
||||
conn, err := tls.Dial("tcp", addr, tlsConfig)
|
||||
if err != nil {
|
||||
return fmt.Errorf("tls connection failed: %w", err)
|
||||
|
||||
@@ -105,6 +105,9 @@ func (m *mockAccountRepoForPlatform) SetError(ctx context.Context, id int64, err
|
||||
func (m *mockAccountRepoForPlatform) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForPlatform) AutoPauseExpiredAccounts(ctx context.Context, now time.Time) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
func (m *mockAccountRepoForPlatform) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -35,6 +35,7 @@ const (
|
||||
stickySessionTTL = time.Hour // 粘性会话TTL
|
||||
defaultMaxLineSize = 10 * 1024 * 1024
|
||||
claudeCodeSystemPrompt = "You are Claude Code, Anthropic's official CLI for Claude."
|
||||
maxCacheControlBlocks = 4 // Anthropic API 允许的最大 cache_control 块数量
|
||||
)
|
||||
|
||||
// sseDataRe matches SSE data lines with optional whitespace after colon.
|
||||
@@ -43,6 +44,16 @@ var (
|
||||
sseDataRe = regexp.MustCompile(`^data:\s*`)
|
||||
sessionIDRegex = regexp.MustCompile(`session_([a-f0-9-]{36})`)
|
||||
claudeCliUserAgentRe = regexp.MustCompile(`^claude-cli/\d+\.\d+\.\d+`)
|
||||
|
||||
// claudeCodePromptPrefixes 用于检测 Claude Code 系统提示词的前缀列表
|
||||
// 支持多种变体:标准版、Agent SDK 版、Explore Agent 版、Compact 版等
|
||||
// 注意:前缀之间不应存在包含关系,否则会导致冗余匹配
|
||||
claudeCodePromptPrefixes = []string{
|
||||
"You are Claude Code, Anthropic's official CLI for Claude", // 标准版 & Agent SDK 版(含 running within...)
|
||||
"You are a Claude agent, built on Anthropic's Claude Agent SDK", // Agent SDK 变体
|
||||
"You are a file search specialist for Claude Code", // Explore Agent 版
|
||||
"You are a helpful AI assistant tasked with summarizing conversations", // Compact 版
|
||||
}
|
||||
)
|
||||
|
||||
// allowedHeaders 白名单headers(参考CRS项目)
|
||||
@@ -98,12 +109,13 @@ type ClaudeUsage struct {
|
||||
|
||||
// ForwardResult 转发结果
|
||||
type ForwardResult struct {
|
||||
RequestID string
|
||||
Usage ClaudeUsage
|
||||
Model string
|
||||
Stream bool
|
||||
Duration time.Duration
|
||||
FirstTokenMs *int // 首字时间(流式请求)
|
||||
RequestID string
|
||||
Usage ClaudeUsage
|
||||
Model string
|
||||
Stream bool
|
||||
Duration time.Duration
|
||||
FirstTokenMs *int // 首字时间(流式请求)
|
||||
ClientDisconnect bool // 客户端是否在流式传输过程中断开
|
||||
|
||||
// 图片生成计费字段(仅 gemini-3-pro-image 使用)
|
||||
ImageCount int // 生成的图片数量
|
||||
@@ -355,17 +367,8 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
|
||||
return s.selectAccountWithMixedScheduling(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform)
|
||||
}
|
||||
|
||||
// 强制平台模式:优先按分组查找,找不到再查全部该平台账户
|
||||
if hasForcePlatform && groupID != nil {
|
||||
account, err := s.selectAccountForModelWithPlatform(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform)
|
||||
if err == nil {
|
||||
return account, nil
|
||||
}
|
||||
// 分组中找不到,回退查询全部该平台账户
|
||||
groupID = nil
|
||||
}
|
||||
|
||||
// antigravity 分组、强制平台模式或无分组使用单平台选择
|
||||
// 注意:强制平台模式也必须遵守分组限制,不再回退到全平台查询
|
||||
return s.selectAccountForModelWithPlatform(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform)
|
||||
}
|
||||
|
||||
@@ -443,7 +446,8 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash)
|
||||
if err == nil && accountID > 0 && !isExcluded(accountID) {
|
||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||
if err == nil && s.isAccountAllowedForPlatform(account, platform, useMixed) &&
|
||||
if err == nil && s.isAccountInGroup(account, groupID) &&
|
||||
s.isAccountAllowedForPlatform(account, platform, useMixed) &&
|
||||
account.IsSchedulable() &&
|
||||
(requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
||||
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
|
||||
@@ -660,9 +664,7 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i
|
||||
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform)
|
||||
} else if groupID != nil {
|
||||
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, platform)
|
||||
if err == nil && len(accounts) == 0 && hasForcePlatform {
|
||||
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform)
|
||||
}
|
||||
// 分组内无账号则返回空列表,由上层处理错误,不再回退到全平台查询
|
||||
} else {
|
||||
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform)
|
||||
}
|
||||
@@ -685,6 +687,23 @@ func (s *GatewayService) isAccountAllowedForPlatform(account *Account, platform
|
||||
return account.Platform == platform
|
||||
}
|
||||
|
||||
// isAccountInGroup checks if the account belongs to the specified group.
|
||||
// Returns true if groupID is nil (no group restriction) or account belongs to the group.
|
||||
func (s *GatewayService) isAccountInGroup(account *Account, groupID *int64) bool {
|
||||
if groupID == nil {
|
||||
return true // 无分组限制
|
||||
}
|
||||
if account == nil {
|
||||
return false
|
||||
}
|
||||
for _, ag := range account.AccountGroups {
|
||||
if ag.GroupID == *groupID {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *GatewayService) tryAcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (*AcquireResult, error) {
|
||||
if s.concurrencyService == nil {
|
||||
return &AcquireResult{Acquired: true, ReleaseFunc: func() {}}, nil
|
||||
@@ -723,8 +742,8 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
||||
if err == nil && accountID > 0 {
|
||||
if _, excluded := excludedIDs[accountID]; !excluded {
|
||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||
// 检查账号平台是否匹配(确保粘性会话不会跨平台)
|
||||
if err == nil && account.Platform == platform && account.IsSchedulable() && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
||||
// 检查账号分组归属和平台匹配(确保粘性会话不会跨分组或跨平台)
|
||||
if err == nil && s.isAccountInGroup(account, groupID) && account.Platform == platform && account.IsSchedulable() && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
||||
if err := s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL); err != nil {
|
||||
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
|
||||
}
|
||||
@@ -812,8 +831,8 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
||||
if err == nil && accountID > 0 {
|
||||
if _, excluded := excludedIDs[accountID]; !excluded {
|
||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||
// 检查账号是否有效:原生平台直接匹配,antigravity 需要启用混合调度
|
||||
if err == nil && account.IsSchedulable() && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
||||
// 检查账号分组归属和有效性:原生平台直接匹配,antigravity 需要启用混合调度
|
||||
if err == nil && s.isAccountInGroup(account, groupID) && account.IsSchedulable() && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
||||
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
|
||||
if err := s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL); err != nil {
|
||||
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
|
||||
@@ -1013,15 +1032,15 @@ func isClaudeCodeClient(userAgent string, metadataUserID string) bool {
|
||||
}
|
||||
|
||||
// systemIncludesClaudeCodePrompt 检查 system 中是否已包含 Claude Code 提示词
|
||||
// 支持 string 和 []any 两种格式
|
||||
// 使用前缀匹配支持多种变体(标准版、Agent SDK 版等)
|
||||
func systemIncludesClaudeCodePrompt(system any) bool {
|
||||
switch v := system.(type) {
|
||||
case string:
|
||||
return v == claudeCodeSystemPrompt
|
||||
return hasClaudeCodePrefix(v)
|
||||
case []any:
|
||||
for _, item := range v {
|
||||
if m, ok := item.(map[string]any); ok {
|
||||
if text, ok := m["text"].(string); ok && text == claudeCodeSystemPrompt {
|
||||
if text, ok := m["text"].(string); ok && hasClaudeCodePrefix(text) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
@@ -1030,6 +1049,16 @@ func systemIncludesClaudeCodePrompt(system any) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// hasClaudeCodePrefix 检查文本是否以 Claude Code 提示词的特征前缀开头
|
||||
func hasClaudeCodePrefix(text string) bool {
|
||||
for _, prefix := range claudeCodePromptPrefixes {
|
||||
if strings.HasPrefix(text, prefix) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// injectClaudeCodePrompt 在 system 开头注入 Claude Code 提示词
|
||||
// 处理 null、字符串、数组三种格式
|
||||
func injectClaudeCodePrompt(body []byte, system any) []byte {
|
||||
@@ -1073,6 +1102,124 @@ func injectClaudeCodePrompt(body []byte, system any) []byte {
|
||||
return result
|
||||
}
|
||||
|
||||
// enforceCacheControlLimit 强制执行 cache_control 块数量限制(最多 4 个)
|
||||
// 超限时优先从 messages 中移除 cache_control,保护 system 中的缓存控制
|
||||
func enforceCacheControlLimit(body []byte) []byte {
|
||||
var data map[string]any
|
||||
if err := json.Unmarshal(body, &data); err != nil {
|
||||
return body
|
||||
}
|
||||
|
||||
// 计算当前 cache_control 块数量
|
||||
count := countCacheControlBlocks(data)
|
||||
if count <= maxCacheControlBlocks {
|
||||
return body
|
||||
}
|
||||
|
||||
// 超限:优先从 messages 中移除,再从 system 中移除
|
||||
for count > maxCacheControlBlocks {
|
||||
if removeCacheControlFromMessages(data) {
|
||||
count--
|
||||
continue
|
||||
}
|
||||
if removeCacheControlFromSystem(data) {
|
||||
count--
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
result, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return body
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// countCacheControlBlocks 统计 system 和 messages 中的 cache_control 块数量
|
||||
func countCacheControlBlocks(data map[string]any) int {
|
||||
count := 0
|
||||
|
||||
// 统计 system 中的块
|
||||
if system, ok := data["system"].([]any); ok {
|
||||
for _, item := range system {
|
||||
if m, ok := item.(map[string]any); ok {
|
||||
if _, has := m["cache_control"]; has {
|
||||
count++
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 统计 messages 中的块
|
||||
if messages, ok := data["messages"].([]any); ok {
|
||||
for _, msg := range messages {
|
||||
if msgMap, ok := msg.(map[string]any); ok {
|
||||
if content, ok := msgMap["content"].([]any); ok {
|
||||
for _, item := range content {
|
||||
if m, ok := item.(map[string]any); ok {
|
||||
if _, has := m["cache_control"]; has {
|
||||
count++
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return count
|
||||
}
|
||||
|
||||
// removeCacheControlFromMessages 从 messages 中移除一个 cache_control(从头开始)
|
||||
// 返回 true 表示成功移除,false 表示没有可移除的
|
||||
func removeCacheControlFromMessages(data map[string]any) bool {
|
||||
messages, ok := data["messages"].([]any)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, msg := range messages {
|
||||
msgMap, ok := msg.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
content, ok := msgMap["content"].([]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
for _, item := range content {
|
||||
if m, ok := item.(map[string]any); ok {
|
||||
if _, has := m["cache_control"]; has {
|
||||
delete(m, "cache_control")
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// removeCacheControlFromSystem 从 system 中移除一个 cache_control(从尾部开始,保护注入的 prompt)
|
||||
// 返回 true 表示成功移除,false 表示没有可移除的
|
||||
func removeCacheControlFromSystem(data map[string]any) bool {
|
||||
system, ok := data["system"].([]any)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
// 从尾部开始移除,保护开头注入的 Claude Code prompt
|
||||
for i := len(system) - 1; i >= 0; i-- {
|
||||
if m, ok := system[i].(map[string]any); ok {
|
||||
if _, has := m["cache_control"]; has {
|
||||
delete(m, "cache_control")
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Forward 转发请求到Claude API
|
||||
func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) (*ForwardResult, error) {
|
||||
startTime := time.Now()
|
||||
@@ -1093,6 +1240,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
body = injectClaudeCodePrompt(body, parsed.System)
|
||||
}
|
||||
|
||||
// 强制执行 cache_control 块数量限制(最多 4 个)
|
||||
body = enforceCacheControlLimit(body)
|
||||
|
||||
// 应用模型映射(仅对apikey类型账号)
|
||||
originalModel := reqModel
|
||||
if account.Type == AccountTypeAPIKey {
|
||||
@@ -1316,6 +1466,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
// 处理正常响应
|
||||
var usage *ClaudeUsage
|
||||
var firstTokenMs *int
|
||||
var clientDisconnect bool
|
||||
if reqStream {
|
||||
streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, reqModel)
|
||||
if err != nil {
|
||||
@@ -1328,6 +1479,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
}
|
||||
usage = streamResult.usage
|
||||
firstTokenMs = streamResult.firstTokenMs
|
||||
clientDisconnect = streamResult.clientDisconnect
|
||||
} else {
|
||||
usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, reqModel)
|
||||
if err != nil {
|
||||
@@ -1336,12 +1488,13 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
}
|
||||
|
||||
return &ForwardResult{
|
||||
RequestID: resp.Header.Get("x-request-id"),
|
||||
Usage: *usage,
|
||||
Model: originalModel, // 使用原始模型用于计费和日志
|
||||
Stream: reqStream,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: firstTokenMs,
|
||||
RequestID: resp.Header.Get("x-request-id"),
|
||||
Usage: *usage,
|
||||
Model: originalModel, // 使用原始模型用于计费和日志
|
||||
Stream: reqStream,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: firstTokenMs,
|
||||
ClientDisconnect: clientDisconnect,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -1696,8 +1849,9 @@ func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *ht
|
||||
|
||||
// streamingResult 流式响应结果
|
||||
type streamingResult struct {
|
||||
usage *ClaudeUsage
|
||||
firstTokenMs *int
|
||||
usage *ClaudeUsage
|
||||
firstTokenMs *int
|
||||
clientDisconnect bool // 客户端是否在流式传输过程中断开
|
||||
}
|
||||
|
||||
func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string) (*streamingResult, error) {
|
||||
@@ -1793,14 +1947,27 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
||||
}
|
||||
|
||||
needModelReplace := originalModel != mappedModel
|
||||
clientDisconnected := false // 客户端断开标志,断开后继续读取上游以获取完整usage
|
||||
|
||||
for {
|
||||
select {
|
||||
case ev, ok := <-events:
|
||||
if !ok {
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
|
||||
// 上游完成,返回结果
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, nil
|
||||
}
|
||||
if ev.err != nil {
|
||||
// 检测 context 取消(客户端断开会导致 context 取消,进而影响上游读取)
|
||||
if errors.Is(ev.err, context.Canceled) || errors.Is(ev.err, context.DeadlineExceeded) {
|
||||
log.Printf("Context canceled during streaming, returning collected usage")
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
|
||||
}
|
||||
// 客户端已通过写入失败检测到断开,上游也出错了,返回已收集的 usage
|
||||
if clientDisconnected {
|
||||
log.Printf("Upstream read error after client disconnect: %v, returning collected usage", ev.err)
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
|
||||
}
|
||||
// 客户端未断开,正常的错误处理
|
||||
if errors.Is(ev.err, bufio.ErrTooLong) {
|
||||
log.Printf("SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err)
|
||||
sendErrorEvent("response_too_large")
|
||||
@@ -1811,38 +1978,40 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
||||
}
|
||||
line := ev.line
|
||||
if line == "event: error" {
|
||||
// 上游返回错误事件,如果客户端已断开仍返回已收集的 usage
|
||||
if clientDisconnected {
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
|
||||
}
|
||||
return nil, errors.New("have error in stream")
|
||||
}
|
||||
|
||||
// Extract data from SSE line (supports both "data: " and "data:" formats)
|
||||
var data string
|
||||
if sseDataRe.MatchString(line) {
|
||||
data := sseDataRe.ReplaceAllString(line, "")
|
||||
|
||||
data = sseDataRe.ReplaceAllString(line, "")
|
||||
// 如果有模型映射,替换响应中的model字段
|
||||
if needModelReplace {
|
||||
line = s.replaceModelInSSELine(line, mappedModel, originalModel)
|
||||
}
|
||||
}
|
||||
|
||||
// 转发行
|
||||
// 写入客户端(统一处理 data 行和非 data 行)
|
||||
if !clientDisconnected {
|
||||
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
|
||||
sendErrorEvent("write_failed")
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
|
||||
clientDisconnected = true
|
||||
log.Printf("Client disconnected during streaming, continuing to drain upstream for billing")
|
||||
} else {
|
||||
flusher.Flush()
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
// 记录首字时间:第一个有效的 content_block_delta 或 message_start
|
||||
if firstTokenMs == nil && data != "" && data != "[DONE]" {
|
||||
// 无论客户端是否断开,都解析 usage(仅对 data 行)
|
||||
if data != "" {
|
||||
if firstTokenMs == nil && data != "[DONE]" {
|
||||
ms := int(time.Since(startTime).Milliseconds())
|
||||
firstTokenMs = &ms
|
||||
}
|
||||
s.parseSSEUsage(data, usage)
|
||||
} else {
|
||||
// 非 data 行直接转发
|
||||
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
|
||||
sendErrorEvent("write_failed")
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
case <-intervalCh:
|
||||
@@ -1850,6 +2019,11 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
||||
if time.Since(lastRead) < streamInterval {
|
||||
continue
|
||||
}
|
||||
if clientDisconnected {
|
||||
// 客户端已断开,上游也超时了,返回已收集的 usage
|
||||
log.Printf("Upstream timeout after client disconnect, returning collected usage")
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
|
||||
}
|
||||
log.Printf("Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval)
|
||||
sendErrorEvent("stream_timeout")
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
|
||||
@@ -2003,6 +2177,7 @@ type RecordUsageInput struct {
|
||||
User *User
|
||||
Account *Account
|
||||
Subscription *UserSubscription // 可选:订阅信息
|
||||
UserAgent string // 请求的 User-Agent
|
||||
}
|
||||
|
||||
// RecordUsage 记录使用量并扣费(或更新订阅用量)
|
||||
@@ -2088,6 +2263,11 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
// 添加 UserAgent
|
||||
if input.UserAgent != "" {
|
||||
usageLog.UserAgent = &input.UserAgent
|
||||
}
|
||||
|
||||
// 添加分组和订阅关联
|
||||
if apiKey.GroupID != nil {
|
||||
usageLog.GroupID = apiKey.GroupID
|
||||
|
||||
@@ -90,6 +90,9 @@ func (m *mockAccountRepoForGemini) SetError(ctx context.Context, id int64, error
|
||||
func (m *mockAccountRepoForGemini) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) AutoPauseExpiredAccounts(ctx context.Context, now time.Time) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1092,6 +1092,7 @@ type OpenAIRecordUsageInput struct {
|
||||
User *User
|
||||
Account *Account
|
||||
Subscription *UserSubscription
|
||||
UserAgent string // 请求的 User-Agent
|
||||
}
|
||||
|
||||
// RecordUsage records usage and deducts balance
|
||||
@@ -1161,6 +1162,11 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
// 添加 UserAgent
|
||||
if input.UserAgent != "" {
|
||||
usageLog.UserAgent = &input.UserAgent
|
||||
}
|
||||
|
||||
if apiKey.GroupID != nil {
|
||||
usageLog.GroupID = apiKey.GroupID
|
||||
}
|
||||
|
||||
@@ -38,6 +38,7 @@ type UsageLog struct {
|
||||
Stream bool
|
||||
DurationMs *int
|
||||
FirstTokenMs *int
|
||||
UserAgent *string
|
||||
|
||||
// 图片生成字段
|
||||
ImageCount int
|
||||
|
||||
@@ -319,3 +319,12 @@ func (s *UsageService) GetGlobalStats(ctx context.Context, startTime, endTime ti
|
||||
}
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
// GetStatsWithFilters returns usage stats with optional filters.
|
||||
func (s *UsageService) GetStatsWithFilters(ctx context.Context, filters usagestats.UsageLogFilters) (*usagestats.UsageStats, error) {
|
||||
stats, err := s.usageRepo.GetStatsWithFilters(ctx, filters)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get usage stats with filters: %w", err)
|
||||
}
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
@@ -47,6 +47,13 @@ func ProvideTokenRefreshService(
|
||||
return svc
|
||||
}
|
||||
|
||||
// ProvideAccountExpiryService creates and starts AccountExpiryService.
|
||||
func ProvideAccountExpiryService(accountRepo AccountRepository) *AccountExpiryService {
|
||||
svc := NewAccountExpiryService(accountRepo, time.Minute)
|
||||
svc.Start()
|
||||
return svc
|
||||
}
|
||||
|
||||
// ProvideTimingWheelService creates and starts TimingWheelService
|
||||
func ProvideTimingWheelService() *TimingWheelService {
|
||||
svc := NewTimingWheelService()
|
||||
@@ -110,6 +117,7 @@ var ProviderSet = wire.NewSet(
|
||||
NewCRSSyncService,
|
||||
ProvideUpdateService,
|
||||
ProvideTokenRefreshService,
|
||||
ProvideAccountExpiryService,
|
||||
ProvideTimingWheelService,
|
||||
ProvideDeferredService,
|
||||
NewAntigravityQuotaFetcher,
|
||||
|
||||
Reference in New Issue
Block a user