Merge remote-tracking branch 'upstream/main'

This commit is contained in:
song
2026-01-08 12:17:56 +08:00
99 changed files with 3678 additions and 512 deletions

View File

@@ -52,6 +52,15 @@ type Config struct {
RunMode string `mapstructure:"run_mode" yaml:"run_mode"`
Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC"
Gemini GeminiConfig `mapstructure:"gemini"`
Update UpdateConfig `mapstructure:"update"`
}
// UpdateConfig 在线更新相关配置
type UpdateConfig struct {
// ProxyURL 用于访问 GitHub 的代理地址
// 支持 http/https/socks5/socks5h 协议
// 例如: "http://127.0.0.1:7890", "socks5://127.0.0.1:1080"
ProxyURL string `mapstructure:"proxy_url"`
}
type GeminiConfig struct {
@@ -148,7 +157,7 @@ type CSPConfig struct {
}
type ProxyProbeConfig struct {
InsecureSkipVerify bool `mapstructure:"insecure_skip_verify"`
InsecureSkipVerify bool `mapstructure:"insecure_skip_verify"` // 已禁用:禁止跳过 TLS 证书验证
}
type BillingConfig struct {
@@ -448,8 +457,8 @@ func setDefaults() {
"raw.githubusercontent.com",
})
viper.SetDefault("security.url_allowlist.crs_hosts", []string{})
viper.SetDefault("security.url_allowlist.allow_private_hosts", false)
viper.SetDefault("security.url_allowlist.allow_insecure_http", false)
viper.SetDefault("security.url_allowlist.allow_private_hosts", true)
viper.SetDefault("security.url_allowlist.allow_insecure_http", true)
viper.SetDefault("security.response_headers.enabled", false)
viper.SetDefault("security.response_headers.additional_allowed", []string{})
viper.SetDefault("security.response_headers.force_remove", []string{})
@@ -558,6 +567,10 @@ func setDefaults() {
viper.SetDefault("gemini.oauth.client_secret", "")
viper.SetDefault("gemini.oauth.scopes", "")
viper.SetDefault("gemini.quota.policy", "")
// Update - 在线更新配置
// 代理地址为空表示直连 GitHub适用于海外服务器
viper.SetDefault("update.proxy_url", "")
}
func (c *Config) Validate() error {

View File

@@ -80,8 +80,11 @@ func TestLoadDefaultSecurityToggles(t *testing.T) {
if cfg.Security.URLAllowlist.Enabled {
t.Fatalf("URLAllowlist.Enabled = true, want false")
}
if cfg.Security.URLAllowlist.AllowInsecureHTTP {
t.Fatalf("URLAllowlist.AllowInsecureHTTP = true, want false")
if !cfg.Security.URLAllowlist.AllowInsecureHTTP {
t.Fatalf("URLAllowlist.AllowInsecureHTTP = false, want true")
}
if !cfg.Security.URLAllowlist.AllowPrivateHosts {
t.Fatalf("URLAllowlist.AllowPrivateHosts = false, want true")
}
if cfg.Security.ResponseHeaders.Enabled {
t.Fatalf("ResponseHeaders.Enabled = true, want false")

View File

@@ -85,6 +85,8 @@ type CreateAccountRequest struct {
Concurrency int `json:"concurrency"`
Priority int `json:"priority"`
GroupIDs []int64 `json:"group_ids"`
ExpiresAt *int64 `json:"expires_at"`
AutoPauseOnExpired *bool `json:"auto_pause_on_expired"`
ConfirmMixedChannelRisk *bool `json:"confirm_mixed_channel_risk"` // 用户确认混合渠道风险
}
@@ -101,6 +103,8 @@ type UpdateAccountRequest struct {
Priority *int `json:"priority"`
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
GroupIDs *[]int64 `json:"group_ids"`
ExpiresAt *int64 `json:"expires_at"`
AutoPauseOnExpired *bool `json:"auto_pause_on_expired"`
ConfirmMixedChannelRisk *bool `json:"confirm_mixed_channel_risk"` // 用户确认混合渠道风险
}
@@ -204,6 +208,8 @@ func (h *AccountHandler) Create(c *gin.Context) {
Concurrency: req.Concurrency,
Priority: req.Priority,
GroupIDs: req.GroupIDs,
ExpiresAt: req.ExpiresAt,
AutoPauseOnExpired: req.AutoPauseOnExpired,
SkipMixedChannelCheck: skipCheck,
})
if err != nil {
@@ -261,6 +267,8 @@ func (h *AccountHandler) Update(c *gin.Context) {
Priority: req.Priority, // 指针类型nil 表示未提供
Status: req.Status,
GroupIDs: req.GroupIDs,
ExpiresAt: req.ExpiresAt,
AutoPauseOnExpired: req.AutoPauseOnExpired,
SkipMixedChannelCheck: skipCheck,
})
if err != nil {

View File

@@ -26,31 +26,33 @@ func NewDashboardHandler(dashboardService *service.DashboardService) *DashboardH
}
// parseTimeRange parses start_date, end_date query parameters
// Uses user's timezone if provided, otherwise falls back to server timezone
func parseTimeRange(c *gin.Context) (time.Time, time.Time) {
now := timezone.Now()
userTZ := c.Query("timezone") // Get user's timezone from request
now := timezone.NowInUserLocation(userTZ)
startDate := c.Query("start_date")
endDate := c.Query("end_date")
var startTime, endTime time.Time
if startDate != "" {
if t, err := timezone.ParseInLocation("2006-01-02", startDate); err == nil {
if t, err := timezone.ParseInUserLocation("2006-01-02", startDate, userTZ); err == nil {
startTime = t
} else {
startTime = timezone.StartOfDay(now.AddDate(0, 0, -7))
startTime = timezone.StartOfDayInUserLocation(now.AddDate(0, 0, -7), userTZ)
}
} else {
startTime = timezone.StartOfDay(now.AddDate(0, 0, -7))
startTime = timezone.StartOfDayInUserLocation(now.AddDate(0, 0, -7), userTZ)
}
if endDate != "" {
if t, err := timezone.ParseInLocation("2006-01-02", endDate); err == nil {
if t, err := timezone.ParseInUserLocation("2006-01-02", endDate, userTZ); err == nil {
endTime = t.Add(24 * time.Hour) // Include the end date
} else {
endTime = timezone.StartOfDay(now.AddDate(0, 0, 1))
endTime = timezone.StartOfDayInUserLocation(now.AddDate(0, 0, 1), userTZ)
}
} else {
endTime = timezone.StartOfDay(now.AddDate(0, 0, 1))
endTime = timezone.StartOfDayInUserLocation(now.AddDate(0, 0, 1), userTZ)
}
return startTime, endTime

View File

@@ -102,8 +102,9 @@ func (h *UsageHandler) List(c *gin.Context) {
// Parse date range
var startTime, endTime *time.Time
userTZ := c.Query("timezone") // Get user's timezone from request
if startDateStr := c.Query("start_date"); startDateStr != "" {
t, err := timezone.ParseInLocation("2006-01-02", startDateStr)
t, err := timezone.ParseInUserLocation("2006-01-02", startDateStr, userTZ)
if err != nil {
response.BadRequest(c, "Invalid start_date format, use YYYY-MM-DD")
return
@@ -112,7 +113,7 @@ func (h *UsageHandler) List(c *gin.Context) {
}
if endDateStr := c.Query("end_date"); endDateStr != "" {
t, err := timezone.ParseInLocation("2006-01-02", endDateStr)
t, err := timezone.ParseInUserLocation("2006-01-02", endDateStr, userTZ)
if err != nil {
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
return
@@ -151,8 +152,8 @@ func (h *UsageHandler) List(c *gin.Context) {
// Stats handles getting usage statistics with filters
// GET /api/v1/admin/usage/stats
func (h *UsageHandler) Stats(c *gin.Context) {
// Parse filters
var userID, apiKeyID int64
// Parse filters - same as List endpoint
var userID, apiKeyID, accountID, groupID int64
if userIDStr := c.Query("user_id"); userIDStr != "" {
id, err := strconv.ParseInt(userIDStr, 10, 64)
if err != nil {
@@ -171,8 +172,50 @@ func (h *UsageHandler) Stats(c *gin.Context) {
apiKeyID = id
}
if accountIDStr := c.Query("account_id"); accountIDStr != "" {
id, err := strconv.ParseInt(accountIDStr, 10, 64)
if err != nil {
response.BadRequest(c, "Invalid account_id")
return
}
accountID = id
}
if groupIDStr := c.Query("group_id"); groupIDStr != "" {
id, err := strconv.ParseInt(groupIDStr, 10, 64)
if err != nil {
response.BadRequest(c, "Invalid group_id")
return
}
groupID = id
}
model := c.Query("model")
var stream *bool
if streamStr := c.Query("stream"); streamStr != "" {
val, err := strconv.ParseBool(streamStr)
if err != nil {
response.BadRequest(c, "Invalid stream value, use true or false")
return
}
stream = &val
}
var billingType *int8
if billingTypeStr := c.Query("billing_type"); billingTypeStr != "" {
val, err := strconv.ParseInt(billingTypeStr, 10, 8)
if err != nil {
response.BadRequest(c, "Invalid billing_type")
return
}
bt := int8(val)
billingType = &bt
}
// Parse date range
now := timezone.Now()
userTZ := c.Query("timezone")
now := timezone.NowInUserLocation(userTZ)
var startTime, endTime time.Time
startDateStr := c.Query("start_date")
@@ -180,12 +223,12 @@ func (h *UsageHandler) Stats(c *gin.Context) {
if startDateStr != "" && endDateStr != "" {
var err error
startTime, err = timezone.ParseInLocation("2006-01-02", startDateStr)
startTime, err = timezone.ParseInUserLocation("2006-01-02", startDateStr, userTZ)
if err != nil {
response.BadRequest(c, "Invalid start_date format, use YYYY-MM-DD")
return
}
endTime, err = timezone.ParseInLocation("2006-01-02", endDateStr)
endTime, err = timezone.ParseInUserLocation("2006-01-02", endDateStr, userTZ)
if err != nil {
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
return
@@ -195,39 +238,31 @@ func (h *UsageHandler) Stats(c *gin.Context) {
period := c.DefaultQuery("period", "today")
switch period {
case "today":
startTime = timezone.StartOfDay(now)
startTime = timezone.StartOfDayInUserLocation(now, userTZ)
case "week":
startTime = now.AddDate(0, 0, -7)
case "month":
startTime = now.AddDate(0, -1, 0)
default:
startTime = timezone.StartOfDay(now)
startTime = timezone.StartOfDayInUserLocation(now, userTZ)
}
endTime = now
}
if apiKeyID > 0 {
stats, err := h.usageService.GetStatsByAPIKey(c.Request.Context(), apiKeyID, startTime, endTime)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, stats)
return
// Build filters and call GetStatsWithFilters
filters := usagestats.UsageLogFilters{
UserID: userID,
APIKeyID: apiKeyID,
AccountID: accountID,
GroupID: groupID,
Model: model,
Stream: stream,
BillingType: billingType,
StartTime: &startTime,
EndTime: &endTime,
}
if userID > 0 {
stats, err := h.usageService.GetStatsByUser(c.Request.Context(), userID, startTime, endTime)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, stats)
return
}
// Get global stats
stats, err := h.usageService.GetGlobalStats(c.Request.Context(), startTime, endTime)
stats, err := h.usageService.GetStatsWithFilters(c.Request.Context(), filters)
if err != nil {
response.ErrorFrom(c, err)
return

View File

@@ -1,7 +1,11 @@
// Package dto provides data transfer objects for HTTP handlers.
package dto
import "github.com/Wei-Shaw/sub2api/internal/service"
import (
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
)
func UserFromServiceShallow(u *service.User) *User {
if u == nil {
@@ -120,6 +124,8 @@ func AccountFromServiceShallow(a *service.Account) *Account {
Status: a.Status,
ErrorMessage: a.ErrorMessage,
LastUsedAt: a.LastUsedAt,
ExpiresAt: timeToUnixSeconds(a.ExpiresAt),
AutoPauseOnExpired: a.AutoPauseOnExpired,
CreatedAt: a.CreatedAt,
UpdatedAt: a.UpdatedAt,
Schedulable: a.Schedulable,
@@ -157,6 +163,14 @@ func AccountFromService(a *service.Account) *Account {
return out
}
func timeToUnixSeconds(value *time.Time) *int64 {
if value == nil {
return nil
}
ts := value.Unix()
return &ts
}
func AccountGroupFromService(ag *service.AccountGroup) *AccountGroup {
if ag == nil {
return nil

View File

@@ -60,21 +60,23 @@ type Group struct {
}
type Account struct {
ID int64 `json:"id"`
Name string `json:"name"`
Notes *string `json:"notes"`
Platform string `json:"platform"`
Type string `json:"type"`
Credentials map[string]any `json:"credentials"`
Extra map[string]any `json:"extra"`
ProxyID *int64 `json:"proxy_id"`
Concurrency int `json:"concurrency"`
Priority int `json:"priority"`
Status string `json:"status"`
ErrorMessage string `json:"error_message"`
LastUsedAt *time.Time `json:"last_used_at"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
ID int64 `json:"id"`
Name string `json:"name"`
Notes *string `json:"notes"`
Platform string `json:"platform"`
Type string `json:"type"`
Credentials map[string]any `json:"credentials"`
Extra map[string]any `json:"extra"`
ProxyID *int64 `json:"proxy_id"`
Concurrency int `json:"concurrency"`
Priority int `json:"priority"`
Status string `json:"status"`
ErrorMessage string `json:"error_message"`
LastUsedAt *time.Time `json:"last_used_at"`
ExpiresAt *int64 `json:"expires_at"`
AutoPauseOnExpired bool `json:"auto_pause_on_expired"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
Schedulable bool `json:"schedulable"`

View File

@@ -108,6 +108,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 获取订阅信息可能为nil- 提前获取用于后续检查
subscription, _ := middleware2.GetSubscriptionFromContext(c)
// 获取 User-Agent
userAgent := c.Request.UserAgent()
// 0. 检查wait队列是否已满
maxWait := service.CalculateMaxWait(subject.Concurrency)
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
@@ -267,7 +270,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
// 异步记录使用量subscription已在函数开头获取
go func(result *service.ForwardResult, usedAccount *service.Account) {
go func(result *service.ForwardResult, usedAccount *service.Account, ua string) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
@@ -276,10 +279,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
User: apiKey.User,
Account: usedAccount,
Subscription: subscription,
UserAgent: ua,
}); err != nil {
log.Printf("Record usage failed: %v", err)
}
}(result, account)
}(result, account, userAgent)
return
}
}
@@ -394,7 +398,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
// 异步记录使用量subscription已在函数开头获取
go func(result *service.ForwardResult, usedAccount *service.Account) {
go func(result *service.ForwardResult, usedAccount *service.Account, ua string) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
@@ -403,10 +407,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
User: apiKey.User,
Account: usedAccount,
Subscription: subscription,
UserAgent: ua,
}); err != nil {
log.Printf("Record usage failed: %v", err)
}
}(result, account)
}(result, account, userAgent)
return
}
}

View File

@@ -83,19 +83,33 @@ func NewConcurrencyHelper(concurrencyService *service.ConcurrencyService, pingFo
// wrapReleaseOnDone ensures release runs at most once and still triggers on context cancellation.
// 用于避免客户端断开或上游超时导致的并发槽位泄漏。
// 修复:添加 quit channel 确保 goroutine 及时退出,避免泄露
func wrapReleaseOnDone(ctx context.Context, releaseFunc func()) func() {
if releaseFunc == nil {
return nil
}
var once sync.Once
wrapped := func() {
once.Do(releaseFunc)
quit := make(chan struct{})
release := func() {
once.Do(func() {
releaseFunc()
close(quit) // 通知监听 goroutine 退出
})
}
go func() {
<-ctx.Done()
wrapped()
select {
case <-ctx.Done():
// Context 取消时释放资源
release()
case <-quit:
// 正常释放已完成goroutine 退出
return
}
}()
return wrapped
return release
}
// IncrementWaitCount increments the wait count for a user

View File

@@ -0,0 +1,141 @@
package handler
import (
"context"
"runtime"
"sync/atomic"
"testing"
"time"
)
// TestWrapReleaseOnDone_NoGoroutineLeak 验证 wrapReleaseOnDone 修复后不会泄露 goroutine
func TestWrapReleaseOnDone_NoGoroutineLeak(t *testing.T) {
// 记录测试开始时的 goroutine 数量
runtime.GC()
time.Sleep(100 * time.Millisecond)
initialGoroutines := runtime.NumGoroutine()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
var releaseCount int32
release := wrapReleaseOnDone(ctx, func() {
atomic.AddInt32(&releaseCount, 1)
})
// 正常释放
release()
// 等待足够时间确保 goroutine 退出
time.Sleep(200 * time.Millisecond)
// 验证只释放一次
if count := atomic.LoadInt32(&releaseCount); count != 1 {
t.Errorf("expected release count to be 1, got %d", count)
}
// 强制 GC清理已退出的 goroutine
runtime.GC()
time.Sleep(100 * time.Millisecond)
// 验证 goroutine 数量没有增加允许±2的误差考虑到测试框架本身可能创建的 goroutine
finalGoroutines := runtime.NumGoroutine()
if finalGoroutines > initialGoroutines+2 {
t.Errorf("goroutine leak detected: initial=%d, final=%d, leaked=%d",
initialGoroutines, finalGoroutines, finalGoroutines-initialGoroutines)
}
}
// TestWrapReleaseOnDone_ContextCancellation 验证 context 取消时也能正确释放
func TestWrapReleaseOnDone_ContextCancellation(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
var releaseCount int32
_ = wrapReleaseOnDone(ctx, func() {
atomic.AddInt32(&releaseCount, 1)
})
// 取消 context应该触发释放
cancel()
// 等待释放完成
time.Sleep(100 * time.Millisecond)
// 验证释放被调用
if count := atomic.LoadInt32(&releaseCount); count != 1 {
t.Errorf("expected release count to be 1, got %d", count)
}
}
// TestWrapReleaseOnDone_MultipleCallsOnlyReleaseOnce 验证多次调用 release 只释放一次
func TestWrapReleaseOnDone_MultipleCallsOnlyReleaseOnce(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
var releaseCount int32
release := wrapReleaseOnDone(ctx, func() {
atomic.AddInt32(&releaseCount, 1)
})
// 调用多次
release()
release()
release()
// 等待执行完成
time.Sleep(100 * time.Millisecond)
// 验证只释放一次
if count := atomic.LoadInt32(&releaseCount); count != 1 {
t.Errorf("expected release count to be 1, got %d", count)
}
}
// TestWrapReleaseOnDone_NilReleaseFunc 验证 nil releaseFunc 不会 panic
func TestWrapReleaseOnDone_NilReleaseFunc(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
release := wrapReleaseOnDone(ctx, nil)
if release != nil {
t.Error("expected nil release function when releaseFunc is nil")
}
}
// TestWrapReleaseOnDone_ConcurrentCalls 验证并发调用的安全性
func TestWrapReleaseOnDone_ConcurrentCalls(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
var releaseCount int32
release := wrapReleaseOnDone(ctx, func() {
atomic.AddInt32(&releaseCount, 1)
})
// 并发调用 release
const numGoroutines = 10
for i := 0; i < numGoroutines; i++ {
go release()
}
// 等待所有 goroutine 完成
time.Sleep(200 * time.Millisecond)
// 验证只释放一次
if count := atomic.LoadInt32(&releaseCount); count != 1 {
t.Errorf("expected release count to be 1, got %d", count)
}
}
// BenchmarkWrapReleaseOnDone 性能基准测试
func BenchmarkWrapReleaseOnDone(b *testing.B) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
b.ResetTimer()
for i := 0; i < b.N; i++ {
release := wrapReleaseOnDone(ctx, func() {})
release()
}
}

View File

@@ -164,6 +164,9 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
// Get subscription (may be nil)
subscription, _ := middleware.GetSubscriptionFromContext(c)
// 获取 User-Agent
userAgent := c.Request.UserAgent()
// For Gemini native API, do not send Claude-style ping frames.
geminiConcurrency := NewConcurrencyHelper(h.concurrencyHelper.concurrencyService, SSEPingFormatNone, 0)
@@ -300,7 +303,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
}
// 6) record usage async
go func(result *service.ForwardResult, usedAccount *service.Account) {
go func(result *service.ForwardResult, usedAccount *service.Account, ua string) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
@@ -309,10 +312,11 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
User: apiKey.User,
Account: usedAccount,
Subscription: subscription,
UserAgent: ua,
}); err != nil {
log.Printf("Record usage failed: %v", err)
}
}(result, account)
}(result, account, userAgent)
return
}
}

View File

@@ -242,7 +242,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
}
// Async record usage
go func(result *service.OpenAIForwardResult, usedAccount *service.Account) {
go func(result *service.OpenAIForwardResult, usedAccount *service.Account, ua string) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
@@ -251,10 +251,11 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
User: apiKey.User,
Account: usedAccount,
Subscription: subscription,
UserAgent: ua,
}); err != nil {
log.Printf("Record usage failed: %v", err)
}
}(result, account)
}(result, account, userAgent)
return
}
}

View File

@@ -88,8 +88,9 @@ func (h *UsageHandler) List(c *gin.Context) {
// Parse date range
var startTime, endTime *time.Time
userTZ := c.Query("timezone") // Get user's timezone from request
if startDateStr := c.Query("start_date"); startDateStr != "" {
t, err := timezone.ParseInLocation("2006-01-02", startDateStr)
t, err := timezone.ParseInUserLocation("2006-01-02", startDateStr, userTZ)
if err != nil {
response.BadRequest(c, "Invalid start_date format, use YYYY-MM-DD")
return
@@ -98,7 +99,7 @@ func (h *UsageHandler) List(c *gin.Context) {
}
if endDateStr := c.Query("end_date"); endDateStr != "" {
t, err := timezone.ParseInLocation("2006-01-02", endDateStr)
t, err := timezone.ParseInUserLocation("2006-01-02", endDateStr, userTZ)
if err != nil {
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
return
@@ -194,7 +195,8 @@ func (h *UsageHandler) Stats(c *gin.Context) {
}
// 获取时间范围参数
now := timezone.Now()
userTZ := c.Query("timezone") // Get user's timezone from request
now := timezone.NowInUserLocation(userTZ)
var startTime, endTime time.Time
// 优先使用 start_date 和 end_date 参数
@@ -204,12 +206,12 @@ func (h *UsageHandler) Stats(c *gin.Context) {
if startDateStr != "" && endDateStr != "" {
// 使用自定义日期范围
var err error
startTime, err = timezone.ParseInLocation("2006-01-02", startDateStr)
startTime, err = timezone.ParseInUserLocation("2006-01-02", startDateStr, userTZ)
if err != nil {
response.BadRequest(c, "Invalid start_date format, use YYYY-MM-DD")
return
}
endTime, err = timezone.ParseInLocation("2006-01-02", endDateStr)
endTime, err = timezone.ParseInUserLocation("2006-01-02", endDateStr, userTZ)
if err != nil {
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
return
@@ -221,13 +223,13 @@ func (h *UsageHandler) Stats(c *gin.Context) {
period := c.DefaultQuery("period", "today")
switch period {
case "today":
startTime = timezone.StartOfDay(now)
startTime = timezone.StartOfDayInUserLocation(now, userTZ)
case "week":
startTime = now.AddDate(0, 0, -7)
case "month":
startTime = now.AddDate(0, -1, 0)
default:
startTime = timezone.StartOfDay(now)
startTime = timezone.StartOfDayInUserLocation(now, userTZ)
}
endTime = now
}
@@ -248,31 +250,33 @@ func (h *UsageHandler) Stats(c *gin.Context) {
}
// parseUserTimeRange parses start_date, end_date query parameters for user dashboard
// Uses user's timezone if provided, otherwise falls back to server timezone
func parseUserTimeRange(c *gin.Context) (time.Time, time.Time) {
now := timezone.Now()
userTZ := c.Query("timezone") // Get user's timezone from request
now := timezone.NowInUserLocation(userTZ)
startDate := c.Query("start_date")
endDate := c.Query("end_date")
var startTime, endTime time.Time
if startDate != "" {
if t, err := timezone.ParseInLocation("2006-01-02", startDate); err == nil {
if t, err := timezone.ParseInUserLocation("2006-01-02", startDate, userTZ); err == nil {
startTime = t
} else {
startTime = timezone.StartOfDay(now.AddDate(0, 0, -7))
startTime = timezone.StartOfDayInUserLocation(now.AddDate(0, 0, -7), userTZ)
}
} else {
startTime = timezone.StartOfDay(now.AddDate(0, 0, -7))
startTime = timezone.StartOfDayInUserLocation(now.AddDate(0, 0, -7), userTZ)
}
if endDate != "" {
if t, err := timezone.ParseInLocation("2006-01-02", endDate); err == nil {
if t, err := timezone.ParseInUserLocation("2006-01-02", endDate, userTZ); err == nil {
endTime = t.Add(24 * time.Hour) // Include the end date
} else {
endTime = timezone.StartOfDay(now.AddDate(0, 0, 1))
endTime = timezone.StartOfDayInUserLocation(now.AddDate(0, 0, 1), userTZ)
}
} else {
endTime = timezone.StartOfDay(now.AddDate(0, 0, 1))
endTime = timezone.StartOfDayInUserLocation(now.AddDate(0, 0, 1), userTZ)
}
return startTime, endTime

View File

@@ -16,7 +16,6 @@
package httpclient
import (
"crypto/tls"
"fmt"
"net/http"
"net/url"
@@ -40,7 +39,7 @@ type Options struct {
ProxyURL string // 代理 URL支持 http/https/socks5/socks5h
Timeout time.Duration // 请求总超时时间
ResponseHeaderTimeout time.Duration // 等待响应头超时时间
InsecureSkipVerify bool // 是否跳过 TLS 证书验证
InsecureSkipVerify bool // 是否跳过 TLS 证书验证(已禁用,不允许设置为 true
ProxyStrict bool // 严格代理模式:代理失败时返回错误而非回退
ValidateResolvedIP bool // 是否校验解析后的 IP防止 DNS Rebinding
AllowPrivateHosts bool // 允许私有地址解析(与 ValidateResolvedIP 一起使用)
@@ -113,7 +112,8 @@ func buildTransport(opts Options) (*http.Transport, error) {
}
if opts.InsecureSkipVerify {
transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
// 安全要求:禁止跳过证书验证,避免中间人攻击。
return nil, fmt.Errorf("insecure_skip_verify is not allowed; install a trusted certificate instead")
}
proxyURL := strings.TrimSpace(opts.ProxyURL)

View File

@@ -122,3 +122,40 @@ func StartOfMonth(t time.Time) time.Time {
func ParseInLocation(layout, value string) (time.Time, error) {
return time.ParseInLocation(layout, value, Location())
}
// ParseInUserLocation parses a time string in the user's timezone.
// If userTZ is empty or invalid, falls back to the configured server timezone.
func ParseInUserLocation(layout, value, userTZ string) (time.Time, error) {
loc := Location() // default to server timezone
if userTZ != "" {
if userLoc, err := time.LoadLocation(userTZ); err == nil {
loc = userLoc
}
}
return time.ParseInLocation(layout, value, loc)
}
// NowInUserLocation returns the current time in the user's timezone.
// If userTZ is empty or invalid, falls back to the configured server timezone.
func NowInUserLocation(userTZ string) time.Time {
if userTZ == "" {
return Now()
}
if userLoc, err := time.LoadLocation(userTZ); err == nil {
return time.Now().In(userLoc)
}
return Now()
}
// StartOfDayInUserLocation returns the start of the given day in the user's timezone.
// If userTZ is empty or invalid, falls back to the configured server timezone.
func StartOfDayInUserLocation(t time.Time, userTZ string) time.Time {
loc := Location()
if userTZ != "" {
if userLoc, err := time.LoadLocation(userTZ); err == nil {
loc = userLoc
}
}
t = t.In(loc)
return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, loc)
}

View File

@@ -76,7 +76,8 @@ func (r *accountRepository) Create(ctx context.Context, account *service.Account
SetPriority(account.Priority).
SetStatus(account.Status).
SetErrorMessage(account.ErrorMessage).
SetSchedulable(account.Schedulable)
SetSchedulable(account.Schedulable).
SetAutoPauseOnExpired(account.AutoPauseOnExpired)
if account.ProxyID != nil {
builder.SetProxyID(*account.ProxyID)
@@ -84,6 +85,9 @@ func (r *accountRepository) Create(ctx context.Context, account *service.Account
if account.LastUsedAt != nil {
builder.SetLastUsedAt(*account.LastUsedAt)
}
if account.ExpiresAt != nil {
builder.SetExpiresAt(*account.ExpiresAt)
}
if account.RateLimitedAt != nil {
builder.SetRateLimitedAt(*account.RateLimitedAt)
}
@@ -280,7 +284,8 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
SetPriority(account.Priority).
SetStatus(account.Status).
SetErrorMessage(account.ErrorMessage).
SetSchedulable(account.Schedulable)
SetSchedulable(account.Schedulable).
SetAutoPauseOnExpired(account.AutoPauseOnExpired)
if account.ProxyID != nil {
builder.SetProxyID(*account.ProxyID)
@@ -292,6 +297,11 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
} else {
builder.ClearLastUsedAt()
}
if account.ExpiresAt != nil {
builder.SetExpiresAt(*account.ExpiresAt)
} else {
builder.ClearExpiresAt()
}
if account.RateLimitedAt != nil {
builder.SetRateLimitedAt(*account.RateLimitedAt)
} else {
@@ -570,6 +580,7 @@ func (r *accountRepository) ListSchedulable(ctx context.Context) ([]service.Acco
dbaccount.StatusEQ(service.StatusActive),
dbaccount.SchedulableEQ(true),
tempUnschedulablePredicate(),
notExpiredPredicate(now),
dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)),
dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)),
).
@@ -596,6 +607,7 @@ func (r *accountRepository) ListSchedulableByPlatform(ctx context.Context, platf
dbaccount.StatusEQ(service.StatusActive),
dbaccount.SchedulableEQ(true),
tempUnschedulablePredicate(),
notExpiredPredicate(now),
dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)),
dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)),
).
@@ -629,6 +641,7 @@ func (r *accountRepository) ListSchedulableByPlatforms(ctx context.Context, plat
dbaccount.StatusEQ(service.StatusActive),
dbaccount.SchedulableEQ(true),
tempUnschedulablePredicate(),
notExpiredPredicate(now),
dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)),
dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)),
).
@@ -727,6 +740,27 @@ func (r *accountRepository) SetSchedulable(ctx context.Context, id int64, schedu
return err
}
func (r *accountRepository) AutoPauseExpiredAccounts(ctx context.Context, now time.Time) (int64, error) {
result, err := r.sql.ExecContext(ctx, `
UPDATE accounts
SET schedulable = FALSE,
updated_at = NOW()
WHERE deleted_at IS NULL
AND schedulable = TRUE
AND auto_pause_on_expired = TRUE
AND expires_at IS NOT NULL
AND expires_at <= $1
`, now)
if err != nil {
return 0, err
}
rows, err := result.RowsAffected()
if err != nil {
return 0, err
}
return rows, nil
}
func (r *accountRepository) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error {
if len(updates) == 0 {
return nil
@@ -861,6 +895,7 @@ func (r *accountRepository) queryAccountsByGroup(ctx context.Context, groupID in
preds = append(preds,
dbaccount.SchedulableEQ(true),
tempUnschedulablePredicate(),
notExpiredPredicate(now),
dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)),
dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)),
)
@@ -971,6 +1006,14 @@ func tempUnschedulablePredicate() dbpredicate.Account {
})
}
func notExpiredPredicate(now time.Time) dbpredicate.Account {
return dbaccount.Or(
dbaccount.ExpiresAtIsNil(),
dbaccount.ExpiresAtGT(now),
dbaccount.AutoPauseOnExpiredEQ(false),
)
}
func (r *accountRepository) loadTempUnschedStates(ctx context.Context, accountIDs []int64) (map[int64]tempUnschedSnapshot, error) {
out := make(map[int64]tempUnschedSnapshot)
if len(accountIDs) == 0 {
@@ -1086,6 +1129,8 @@ func accountEntityToService(m *dbent.Account) *service.Account {
Status: m.Status,
ErrorMessage: derefString(m.ErrorMessage),
LastUsedAt: m.LastUsedAt,
ExpiresAt: m.ExpiresAt,
AutoPauseOnExpired: m.AutoPauseOnExpired,
CreatedAt: m.CreatedAt,
UpdatedAt: m.UpdatedAt,
Schedulable: m.Schedulable,

View File

@@ -14,23 +14,33 @@ import (
)
type githubReleaseClient struct {
httpClient *http.Client
allowPrivateHosts bool
httpClient *http.Client
downloadHTTPClient *http.Client
}
func NewGitHubReleaseClient() service.GitHubReleaseClient {
allowPrivate := false
// NewGitHubReleaseClient 创建 GitHub Release 客户端
// proxyURL 为空时直连 GitHub支持 http/https/socks5/socks5h 协议
func NewGitHubReleaseClient(proxyURL string) service.GitHubReleaseClient {
sharedClient, err := httpclient.GetClient(httpclient.Options{
Timeout: 30 * time.Second,
ValidateResolvedIP: true,
AllowPrivateHosts: allowPrivate,
Timeout: 30 * time.Second,
ProxyURL: proxyURL,
})
if err != nil {
sharedClient = &http.Client{Timeout: 30 * time.Second}
}
// 下载客户端需要更长的超时时间
downloadClient, err := httpclient.GetClient(httpclient.Options{
Timeout: 10 * time.Minute,
ProxyURL: proxyURL,
})
if err != nil {
downloadClient = &http.Client{Timeout: 10 * time.Minute}
}
return &githubReleaseClient{
httpClient: sharedClient,
allowPrivateHosts: allowPrivate,
httpClient: sharedClient,
downloadHTTPClient: downloadClient,
}
}
@@ -68,15 +78,8 @@ func (c *githubReleaseClient) DownloadFile(ctx context.Context, url, dest string
return err
}
downloadClient, err := httpclient.GetClient(httpclient.Options{
Timeout: 10 * time.Minute,
ValidateResolvedIP: true,
AllowPrivateHosts: c.allowPrivateHosts,
})
if err != nil {
downloadClient = &http.Client{Timeout: 10 * time.Minute}
}
resp, err := downloadClient.Do(req)
// 使用预配置的下载客户端(已包含代理配置)
resp, err := c.downloadHTTPClient.Do(req)
if err != nil {
return err
}

View File

@@ -39,8 +39,8 @@ func (t *testTransport) RoundTrip(req *http.Request) (*http.Response, error) {
func newTestGitHubReleaseClient() *githubReleaseClient {
return &githubReleaseClient{
httpClient: &http.Client{},
allowPrivateHosts: true,
httpClient: &http.Client{},
downloadHTTPClient: &http.Client{},
}
}
@@ -234,7 +234,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Success() {
httpClient: &http.Client{
Transport: &testTransport{testServerURL: s.srv.URL},
},
allowPrivateHosts: true,
downloadHTTPClient: &http.Client{},
}
release, err := s.client.FetchLatestRelease(context.Background(), "test/repo")
@@ -254,7 +254,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Non200() {
httpClient: &http.Client{
Transport: &testTransport{testServerURL: s.srv.URL},
},
allowPrivateHosts: true,
downloadHTTPClient: &http.Client{},
}
_, err := s.client.FetchLatestRelease(context.Background(), "test/repo")
@@ -272,7 +272,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_InvalidJSON() {
httpClient: &http.Client{
Transport: &testTransport{testServerURL: s.srv.URL},
},
allowPrivateHosts: true,
downloadHTTPClient: &http.Client{},
}
_, err := s.client.FetchLatestRelease(context.Background(), "test/repo")
@@ -288,7 +288,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_ContextCancel() {
httpClient: &http.Client{
Transport: &testTransport{testServerURL: s.srv.URL},
},
allowPrivateHosts: true,
downloadHTTPClient: &http.Client{},
}
ctx, cancel := context.WithCancel(context.Background())

View File

@@ -8,7 +8,6 @@ import (
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/service"
)
@@ -17,17 +16,12 @@ type pricingRemoteClient struct {
httpClient *http.Client
}
func NewPricingRemoteClient(cfg *config.Config) service.PricingRemoteClient {
allowPrivate := false
validateResolvedIP := true
if cfg != nil {
allowPrivate = cfg.Security.URLAllowlist.AllowPrivateHosts
validateResolvedIP = cfg.Security.URLAllowlist.Enabled
}
// NewPricingRemoteClient 创建定价数据远程客户端
// proxyURL 为空时直连,支持 http/https/socks5/socks5h 协议
func NewPricingRemoteClient(proxyURL string) service.PricingRemoteClient {
sharedClient, err := httpclient.GetClient(httpclient.Options{
Timeout: 30 * time.Second,
ValidateResolvedIP: validateResolvedIP,
AllowPrivateHosts: allowPrivate,
Timeout: 30 * time.Second,
ProxyURL: proxyURL,
})
if err != nil {
sharedClient = &http.Client{Timeout: 30 * time.Second}

View File

@@ -6,7 +6,6 @@ import (
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
@@ -20,13 +19,7 @@ type PricingServiceSuite struct {
func (s *PricingServiceSuite) SetupTest() {
s.ctx = context.Background()
client, ok := NewPricingRemoteClient(&config.Config{
Security: config.SecurityConfig{
URLAllowlist: config.URLAllowlistConfig{
AllowPrivateHosts: true,
},
},
}).(*pricingRemoteClient)
client, ok := NewPricingRemoteClient("").(*pricingRemoteClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
}

View File

@@ -24,7 +24,7 @@ func NewProxyExitInfoProber(cfg *config.Config) service.ProxyExitInfoProber {
validateResolvedIP = cfg.Security.URLAllowlist.Enabled
}
if insecure {
log.Printf("[ProxyProbe] Warning: TLS verification is disabled for proxy probing.")
log.Printf("[ProxyProbe] Warning: insecure_skip_verify is not allowed and will cause probe failure.")
}
return &proxyProbeService{
ipInfoURL: defaultIPInfoURL,

View File

@@ -22,7 +22,7 @@ import (
"github.com/lib/pq"
)
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, billing_type, stream, duration_ms, first_token_ms, image_count, image_size, created_at"
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, image_count, image_size, created_at"
type usageLogRepository struct {
client *dbent.Client
@@ -109,6 +109,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
stream,
duration_ms,
first_token_ms,
user_agent,
image_count,
image_size,
created_at
@@ -118,8 +119,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
$8, $9, $10, $11,
$12, $13,
$14, $15, $16, $17, $18, $19,
$20, $21, $22, $23, $24,
$25, $26, $27
$20, $21, $22, $23, $24, $25, $26, $27, $28
)
ON CONFLICT (request_id, api_key_id) DO NOTHING
RETURNING id, created_at
@@ -129,6 +129,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
subscriptionID := nullInt64(log.SubscriptionID)
duration := nullInt(log.DurationMs)
firstToken := nullInt(log.FirstTokenMs)
userAgent := nullString(log.UserAgent)
imageSize := nullString(log.ImageSize)
var requestIDArg any
@@ -161,6 +162,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
log.Stream,
duration,
firstToken,
userAgent,
log.ImageCount,
imageSize,
createdAt,
@@ -1388,6 +1390,81 @@ func (r *usageLogRepository) GetGlobalStats(ctx context.Context, startTime, endT
return stats, nil
}
// GetStatsWithFilters gets usage statistics with optional filters
func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters UsageLogFilters) (*UsageStats, error) {
conditions := make([]string, 0, 9)
args := make([]any, 0, 9)
if filters.UserID > 0 {
conditions = append(conditions, fmt.Sprintf("user_id = $%d", len(args)+1))
args = append(args, filters.UserID)
}
if filters.APIKeyID > 0 {
conditions = append(conditions, fmt.Sprintf("api_key_id = $%d", len(args)+1))
args = append(args, filters.APIKeyID)
}
if filters.AccountID > 0 {
conditions = append(conditions, fmt.Sprintf("account_id = $%d", len(args)+1))
args = append(args, filters.AccountID)
}
if filters.GroupID > 0 {
conditions = append(conditions, fmt.Sprintf("group_id = $%d", len(args)+1))
args = append(args, filters.GroupID)
}
if filters.Model != "" {
conditions = append(conditions, fmt.Sprintf("model = $%d", len(args)+1))
args = append(args, filters.Model)
}
if filters.Stream != nil {
conditions = append(conditions, fmt.Sprintf("stream = $%d", len(args)+1))
args = append(args, *filters.Stream)
}
if filters.BillingType != nil {
conditions = append(conditions, fmt.Sprintf("billing_type = $%d", len(args)+1))
args = append(args, int16(*filters.BillingType))
}
if filters.StartTime != nil {
conditions = append(conditions, fmt.Sprintf("created_at >= $%d", len(args)+1))
args = append(args, *filters.StartTime)
}
if filters.EndTime != nil {
conditions = append(conditions, fmt.Sprintf("created_at <= $%d", len(args)+1))
args = append(args, *filters.EndTime)
}
query := fmt.Sprintf(`
SELECT
COUNT(*) as total_requests,
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
COALESCE(SUM(output_tokens), 0) as total_output_tokens,
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens,
COALESCE(SUM(total_cost), 0) as total_cost,
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
COALESCE(AVG(duration_ms), 0) as avg_duration_ms
FROM usage_logs
%s
`, buildWhere(conditions))
stats := &UsageStats{}
if err := scanSingleRow(
ctx,
r.sql,
query,
args,
&stats.TotalRequests,
&stats.TotalInputTokens,
&stats.TotalOutputTokens,
&stats.TotalCacheTokens,
&stats.TotalCost,
&stats.TotalActualCost,
&stats.AverageDurationMs,
); err != nil {
return nil, err
}
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens
return stats, nil
}
// AccountUsageHistory represents daily usage history for an account
type AccountUsageHistory = usagestats.AccountUsageHistory
@@ -1795,6 +1872,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
stream bool
durationMs sql.NullInt64
firstTokenMs sql.NullInt64
userAgent sql.NullString
imageCount int
imageSize sql.NullString
createdAt time.Time
@@ -1826,6 +1904,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
&stream,
&durationMs,
&firstTokenMs,
&userAgent,
&imageCount,
&imageSize,
&createdAt,
@@ -1877,6 +1956,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
value := int(firstTokenMs.Int64)
log.FirstTokenMs = &value
}
if userAgent.Valid {
log.UserAgent = &userAgent.String
}
if imageSize.Valid {
log.ImageSize = &imageSize.String
}

View File

@@ -25,6 +25,18 @@ func ProvideConcurrencyCache(rdb *redis.Client, cfg *config.Config) service.Conc
return NewConcurrencyCache(rdb, cfg.Gateway.ConcurrencySlotTTLMinutes, waitTTLSeconds)
}
// ProvideGitHubReleaseClient 创建 GitHub Release 客户端
// 从配置中读取代理设置,支持国内服务器通过代理访问 GitHub
func ProvideGitHubReleaseClient(cfg *config.Config) service.GitHubReleaseClient {
return NewGitHubReleaseClient(cfg.Update.ProxyURL)
}
// ProvidePricingRemoteClient 创建定价数据远程客户端
// 从配置中读取代理设置,支持国内服务器通过代理访问 GitHub 上的定价数据
func ProvidePricingRemoteClient(cfg *config.Config) service.PricingRemoteClient {
return NewPricingRemoteClient(cfg.Update.ProxyURL)
}
// ProviderSet is the Wire provider set for all repositories
var ProviderSet = wire.NewSet(
NewUserRepository,
@@ -53,8 +65,8 @@ var ProviderSet = wire.NewSet(
// HTTP service ports (DI Strategy A: return interface directly)
NewTurnstileVerifier,
NewPricingRemoteClient,
NewGitHubReleaseClient,
ProvidePricingRemoteClient,
ProvideGitHubReleaseClient,
NewProxyExitInfoProber,
NewClaudeUsageFetcher,
NewClaudeOAuthClient,

View File

@@ -1065,6 +1065,10 @@ func (r *stubUsageLogRepo) GetAccountUsageStats(ctx context.Context, accountID i
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetStatsWithFilters(ctx context.Context, filters usagestats.UsageLogFilters) (*usagestats.UsageStats, error) {
return nil, errors.New("not implemented")
}
type stubSettingRepo struct {
all map[string]string
}

View File

@@ -9,21 +9,23 @@ import (
)
type Account struct {
ID int64
Name string
Notes *string
Platform string
Type string
Credentials map[string]any
Extra map[string]any
ProxyID *int64
Concurrency int
Priority int
Status string
ErrorMessage string
LastUsedAt *time.Time
CreatedAt time.Time
UpdatedAt time.Time
ID int64
Name string
Notes *string
Platform string
Type string
Credentials map[string]any
Extra map[string]any
ProxyID *int64
Concurrency int
Priority int
Status string
ErrorMessage string
LastUsedAt *time.Time
ExpiresAt *time.Time
AutoPauseOnExpired bool
CreatedAt time.Time
UpdatedAt time.Time
Schedulable bool
@@ -60,6 +62,9 @@ func (a *Account) IsSchedulable() bool {
return false
}
now := time.Now()
if a.AutoPauseOnExpired && a.ExpiresAt != nil && !now.Before(*a.ExpiresAt) {
return false
}
if a.OverloadUntil != nil && now.Before(*a.OverloadUntil) {
return false
}

View File

@@ -0,0 +1,71 @@
package service
import (
"context"
"log"
"sync"
"time"
)
// AccountExpiryService periodically pauses expired accounts when auto-pause is enabled.
type AccountExpiryService struct {
accountRepo AccountRepository
interval time.Duration
stopCh chan struct{}
stopOnce sync.Once
wg sync.WaitGroup
}
func NewAccountExpiryService(accountRepo AccountRepository, interval time.Duration) *AccountExpiryService {
return &AccountExpiryService{
accountRepo: accountRepo,
interval: interval,
stopCh: make(chan struct{}),
}
}
func (s *AccountExpiryService) Start() {
if s == nil || s.accountRepo == nil || s.interval <= 0 {
return
}
s.wg.Add(1)
go func() {
defer s.wg.Done()
ticker := time.NewTicker(s.interval)
defer ticker.Stop()
s.runOnce()
for {
select {
case <-ticker.C:
s.runOnce()
case <-s.stopCh:
return
}
}
}()
}
func (s *AccountExpiryService) Stop() {
if s == nil {
return
}
s.stopOnce.Do(func() {
close(s.stopCh)
})
s.wg.Wait()
}
func (s *AccountExpiryService) runOnce() {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
updated, err := s.accountRepo.AutoPauseExpiredAccounts(ctx, time.Now())
if err != nil {
log.Printf("[AccountExpiry] Auto pause expired accounts failed: %v", err)
return
}
if updated > 0 {
log.Printf("[AccountExpiry] Auto paused %d expired accounts", updated)
}
}

View File

@@ -38,6 +38,7 @@ type AccountRepository interface {
BatchUpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error
SetError(ctx context.Context, id int64, errorMsg string) error
SetSchedulable(ctx context.Context, id int64, schedulable bool) error
AutoPauseExpiredAccounts(ctx context.Context, now time.Time) (int64, error)
BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error
ListSchedulable(ctx context.Context) ([]Account, error)
@@ -71,29 +72,33 @@ type AccountBulkUpdate struct {
// CreateAccountRequest 创建账号请求
type CreateAccountRequest struct {
Name string `json:"name"`
Notes *string `json:"notes"`
Platform string `json:"platform"`
Type string `json:"type"`
Credentials map[string]any `json:"credentials"`
Extra map[string]any `json:"extra"`
ProxyID *int64 `json:"proxy_id"`
Concurrency int `json:"concurrency"`
Priority int `json:"priority"`
GroupIDs []int64 `json:"group_ids"`
Name string `json:"name"`
Notes *string `json:"notes"`
Platform string `json:"platform"`
Type string `json:"type"`
Credentials map[string]any `json:"credentials"`
Extra map[string]any `json:"extra"`
ProxyID *int64 `json:"proxy_id"`
Concurrency int `json:"concurrency"`
Priority int `json:"priority"`
GroupIDs []int64 `json:"group_ids"`
ExpiresAt *time.Time `json:"expires_at"`
AutoPauseOnExpired *bool `json:"auto_pause_on_expired"`
}
// UpdateAccountRequest 更新账号请求
type UpdateAccountRequest struct {
Name *string `json:"name"`
Notes *string `json:"notes"`
Credentials *map[string]any `json:"credentials"`
Extra *map[string]any `json:"extra"`
ProxyID *int64 `json:"proxy_id"`
Concurrency *int `json:"concurrency"`
Priority *int `json:"priority"`
Status *string `json:"status"`
GroupIDs *[]int64 `json:"group_ids"`
Name *string `json:"name"`
Notes *string `json:"notes"`
Credentials *map[string]any `json:"credentials"`
Extra *map[string]any `json:"extra"`
ProxyID *int64 `json:"proxy_id"`
Concurrency *int `json:"concurrency"`
Priority *int `json:"priority"`
Status *string `json:"status"`
GroupIDs *[]int64 `json:"group_ids"`
ExpiresAt *time.Time `json:"expires_at"`
AutoPauseOnExpired *bool `json:"auto_pause_on_expired"`
}
// AccountService 账号管理服务
@@ -134,6 +139,12 @@ func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) (
Concurrency: req.Concurrency,
Priority: req.Priority,
Status: StatusActive,
ExpiresAt: req.ExpiresAt,
}
if req.AutoPauseOnExpired != nil {
account.AutoPauseOnExpired = *req.AutoPauseOnExpired
} else {
account.AutoPauseOnExpired = true
}
if err := s.accountRepo.Create(ctx, account); err != nil {
@@ -224,6 +235,12 @@ func (s *AccountService) Update(ctx context.Context, id int64, req UpdateAccount
if req.Status != nil {
account.Status = *req.Status
}
if req.ExpiresAt != nil {
account.ExpiresAt = req.ExpiresAt
}
if req.AutoPauseOnExpired != nil {
account.AutoPauseOnExpired = *req.AutoPauseOnExpired
}
// 先验证分组是否存在(在任何写操作之前)
if req.GroupIDs != nil {

View File

@@ -103,6 +103,10 @@ func (s *accountRepoStub) SetSchedulable(ctx context.Context, id int64, schedula
panic("unexpected SetSchedulable call")
}
func (s *accountRepoStub) AutoPauseExpiredAccounts(ctx context.Context, now time.Time) (int64, error) {
panic("unexpected AutoPauseExpiredAccounts call")
}
func (s *accountRepoStub) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error {
panic("unexpected BindGroups call")
}

View File

@@ -47,6 +47,7 @@ type UsageLogRepository interface {
// Admin usage listing/stats
ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters usagestats.UsageLogFilters) ([]UsageLog, *pagination.PaginationResult, error)
GetGlobalStats(ctx context.Context, startTime, endTime time.Time) (*usagestats.UsageStats, error)
GetStatsWithFilters(ctx context.Context, filters usagestats.UsageLogFilters) (*usagestats.UsageStats, error)
// Account stats
GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.AccountUsageStatsResponse, error)

View File

@@ -122,16 +122,18 @@ type UpdateGroupInput struct {
}
type CreateAccountInput struct {
Name string
Notes *string
Platform string
Type string
Credentials map[string]any
Extra map[string]any
ProxyID *int64
Concurrency int
Priority int
GroupIDs []int64
Name string
Notes *string
Platform string
Type string
Credentials map[string]any
Extra map[string]any
ProxyID *int64
Concurrency int
Priority int
GroupIDs []int64
ExpiresAt *int64
AutoPauseOnExpired *bool
// SkipMixedChannelCheck skips the mixed channel risk check when binding groups.
// This should only be set when the caller has explicitly confirmed the risk.
SkipMixedChannelCheck bool
@@ -148,6 +150,8 @@ type UpdateAccountInput struct {
Priority *int // 使用指针区分"未提供"和"设置为0"
Status string
GroupIDs *[]int64
ExpiresAt *int64
AutoPauseOnExpired *bool
SkipMixedChannelCheck bool // 跳过混合渠道检查(用户已确认风险)
}
@@ -700,6 +704,15 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
Status: StatusActive,
Schedulable: true,
}
if input.ExpiresAt != nil && *input.ExpiresAt > 0 {
expiresAt := time.Unix(*input.ExpiresAt, 0)
account.ExpiresAt = &expiresAt
}
if input.AutoPauseOnExpired != nil {
account.AutoPauseOnExpired = *input.AutoPauseOnExpired
} else {
account.AutoPauseOnExpired = true
}
if err := s.accountRepo.Create(ctx, account); err != nil {
return nil, err
}
@@ -755,6 +768,17 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
if input.Status != "" {
account.Status = input.Status
}
if input.ExpiresAt != nil {
if *input.ExpiresAt <= 0 {
account.ExpiresAt = nil
} else {
expiresAt := time.Unix(*input.ExpiresAt, 0)
account.ExpiresAt = &expiresAt
}
}
if input.AutoPauseOnExpired != nil {
account.AutoPauseOnExpired = *input.AutoPauseOnExpired
}
// 先验证分组是否存在(在任何写操作之前)
if input.GroupIDs != nil {

View File

@@ -20,12 +20,16 @@ var (
ErrEmailExists = infraerrors.Conflict("EMAIL_EXISTS", "email already exists")
ErrInvalidToken = infraerrors.Unauthorized("INVALID_TOKEN", "invalid token")
ErrTokenExpired = infraerrors.Unauthorized("TOKEN_EXPIRED", "token has expired")
ErrTokenTooLarge = infraerrors.BadRequest("TOKEN_TOO_LARGE", "token too large")
ErrTokenRevoked = infraerrors.Unauthorized("TOKEN_REVOKED", "token has been revoked")
ErrEmailVerifyRequired = infraerrors.BadRequest("EMAIL_VERIFY_REQUIRED", "email verification is required")
ErrRegDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled")
ErrServiceUnavailable = infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "service temporarily unavailable")
)
// maxTokenLength 限制 token 大小,避免超长 header 触发解析时的异常内存分配。
const maxTokenLength = 8192
// JWTClaims JWT载荷数据
type JWTClaims struct {
UserID int64 `json:"user_id"`
@@ -309,7 +313,20 @@ func (s *AuthService) Login(ctx context.Context, email, password string) (string
// ValidateToken 验证JWT token并返回用户声明
func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) {
token, err := jwt.ParseWithClaims(tokenString, &JWTClaims{}, func(token *jwt.Token) (any, error) {
// 先做长度校验,尽早拒绝异常超长 token降低 DoS 风险。
if len(tokenString) > maxTokenLength {
return nil, ErrTokenTooLarge
}
// 使用解析器并限制可接受的签名算法,防止算法混淆。
parser := jwt.NewParser(jwt.WithValidMethods([]string{
jwt.SigningMethodHS256.Name,
jwt.SigningMethodHS384.Name,
jwt.SigningMethodHS512.Name,
}))
// 保留默认 claims 校验exp/nbf避免放行过期或未生效的 token。
token, err := parser.ParseWithClaims(tokenString, &JWTClaims{}, func(token *jwt.Token) (any, error) {
// 验证签名方法
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])

View File

@@ -140,6 +140,8 @@ func (s *EmailService) SendEmailWithConfig(config *SMTPConfig, to, subject, body
func (s *EmailService) sendMailTLS(addr string, auth smtp.Auth, from, to string, msg []byte, host string) error {
tlsConfig := &tls.Config{
ServerName: host,
// 强制 TLS 1.2+,避免协议降级导致的弱加密风险。
MinVersion: tls.VersionTLS12,
}
conn, err := tls.Dial("tcp", addr, tlsConfig)
@@ -311,7 +313,11 @@ func (s *EmailService) TestSMTPConnectionWithConfig(config *SMTPConfig) error {
addr := fmt.Sprintf("%s:%d", config.Host, config.Port)
if config.UseTLS {
tlsConfig := &tls.Config{ServerName: config.Host}
tlsConfig := &tls.Config{
ServerName: config.Host,
// 与发送逻辑一致,显式要求 TLS 1.2+。
MinVersion: tls.VersionTLS12,
}
conn, err := tls.Dial("tcp", addr, tlsConfig)
if err != nil {
return fmt.Errorf("tls connection failed: %w", err)

View File

@@ -105,6 +105,9 @@ func (m *mockAccountRepoForPlatform) SetError(ctx context.Context, id int64, err
func (m *mockAccountRepoForPlatform) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
return nil
}
func (m *mockAccountRepoForPlatform) AutoPauseExpiredAccounts(ctx context.Context, now time.Time) (int64, error) {
return 0, nil
}
func (m *mockAccountRepoForPlatform) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error {
return nil
}

View File

@@ -35,6 +35,7 @@ const (
stickySessionTTL = time.Hour // 粘性会话TTL
defaultMaxLineSize = 10 * 1024 * 1024
claudeCodeSystemPrompt = "You are Claude Code, Anthropic's official CLI for Claude."
maxCacheControlBlocks = 4 // Anthropic API 允许的最大 cache_control 块数量
)
// sseDataRe matches SSE data lines with optional whitespace after colon.
@@ -43,6 +44,16 @@ var (
sseDataRe = regexp.MustCompile(`^data:\s*`)
sessionIDRegex = regexp.MustCompile(`session_([a-f0-9-]{36})`)
claudeCliUserAgentRe = regexp.MustCompile(`^claude-cli/\d+\.\d+\.\d+`)
// claudeCodePromptPrefixes 用于检测 Claude Code 系统提示词的前缀列表
// 支持多种变体标准版、Agent SDK 版、Explore Agent 版、Compact 版等
// 注意:前缀之间不应存在包含关系,否则会导致冗余匹配
claudeCodePromptPrefixes = []string{
"You are Claude Code, Anthropic's official CLI for Claude", // 标准版 & Agent SDK 版(含 running within...
"You are a Claude agent, built on Anthropic's Claude Agent SDK", // Agent SDK 变体
"You are a file search specialist for Claude Code", // Explore Agent 版
"You are a helpful AI assistant tasked with summarizing conversations", // Compact 版
}
)
// allowedHeaders 白名单headers参考CRS项目
@@ -98,12 +109,13 @@ type ClaudeUsage struct {
// ForwardResult 转发结果
type ForwardResult struct {
RequestID string
Usage ClaudeUsage
Model string
Stream bool
Duration time.Duration
FirstTokenMs *int // 首字时间(流式请求)
RequestID string
Usage ClaudeUsage
Model string
Stream bool
Duration time.Duration
FirstTokenMs *int // 首字时间(流式请求)
ClientDisconnect bool // 客户端是否在流式传输过程中断开
// 图片生成计费字段(仅 gemini-3-pro-image 使用)
ImageCount int // 生成的图片数量
@@ -355,17 +367,8 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
return s.selectAccountWithMixedScheduling(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform)
}
// 强制平台模式:优先按分组查找,找不到再查全部该平台账户
if hasForcePlatform && groupID != nil {
account, err := s.selectAccountForModelWithPlatform(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform)
if err == nil {
return account, nil
}
// 分组中找不到,回退查询全部该平台账户
groupID = nil
}
// antigravity 分组、强制平台模式或无分组使用单平台选择
// 注意:强制平台模式也必须遵守分组限制,不再回退到全平台查询
return s.selectAccountForModelWithPlatform(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform)
}
@@ -443,7 +446,8 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash)
if err == nil && accountID > 0 && !isExcluded(accountID) {
account, err := s.accountRepo.GetByID(ctx, accountID)
if err == nil && s.isAccountAllowedForPlatform(account, platform, useMixed) &&
if err == nil && s.isAccountInGroup(account, groupID) &&
s.isAccountAllowedForPlatform(account, platform, useMixed) &&
account.IsSchedulable() &&
(requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
@@ -660,9 +664,7 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform)
} else if groupID != nil {
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, platform)
if err == nil && len(accounts) == 0 && hasForcePlatform {
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform)
}
// 分组内无账号则返回空列表,由上层处理错误,不再回退到全平台查询
} else {
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform)
}
@@ -685,6 +687,23 @@ func (s *GatewayService) isAccountAllowedForPlatform(account *Account, platform
return account.Platform == platform
}
// isAccountInGroup checks if the account belongs to the specified group.
// Returns true if groupID is nil (no group restriction) or account belongs to the group.
func (s *GatewayService) isAccountInGroup(account *Account, groupID *int64) bool {
if groupID == nil {
return true // 无分组限制
}
if account == nil {
return false
}
for _, ag := range account.AccountGroups {
if ag.GroupID == *groupID {
return true
}
}
return false
}
func (s *GatewayService) tryAcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (*AcquireResult, error) {
if s.concurrencyService == nil {
return &AcquireResult{Acquired: true, ReleaseFunc: func() {}}, nil
@@ -723,8 +742,8 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
if err == nil && accountID > 0 {
if _, excluded := excludedIDs[accountID]; !excluded {
account, err := s.accountRepo.GetByID(ctx, accountID)
// 检查账号平台是否匹配(确保粘性会话不会跨平台)
if err == nil && account.Platform == platform && account.IsSchedulable() && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
// 检查账号分组归属和平台匹配(确保粘性会话不会跨分组或跨平台)
if err == nil && s.isAccountInGroup(account, groupID) && account.Platform == platform && account.IsSchedulable() && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
if err := s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL); err != nil {
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
}
@@ -812,8 +831,8 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
if err == nil && accountID > 0 {
if _, excluded := excludedIDs[accountID]; !excluded {
account, err := s.accountRepo.GetByID(ctx, accountID)
// 检查账号是否有效原生平台直接匹配antigravity 需要启用混合调度
if err == nil && account.IsSchedulable() && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
// 检查账号分组归属和有效原生平台直接匹配antigravity 需要启用混合调度
if err == nil && s.isAccountInGroup(account, groupID) && account.IsSchedulable() && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
if err := s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL); err != nil {
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
@@ -1013,15 +1032,15 @@ func isClaudeCodeClient(userAgent string, metadataUserID string) bool {
}
// systemIncludesClaudeCodePrompt 检查 system 中是否已包含 Claude Code 提示词
// 支持 string 和 []any 两种格式
// 使用前缀匹配支持多种变体标准版、Agent SDK 版等)
func systemIncludesClaudeCodePrompt(system any) bool {
switch v := system.(type) {
case string:
return v == claudeCodeSystemPrompt
return hasClaudeCodePrefix(v)
case []any:
for _, item := range v {
if m, ok := item.(map[string]any); ok {
if text, ok := m["text"].(string); ok && text == claudeCodeSystemPrompt {
if text, ok := m["text"].(string); ok && hasClaudeCodePrefix(text) {
return true
}
}
@@ -1030,6 +1049,16 @@ func systemIncludesClaudeCodePrompt(system any) bool {
return false
}
// hasClaudeCodePrefix 检查文本是否以 Claude Code 提示词的特征前缀开头
func hasClaudeCodePrefix(text string) bool {
for _, prefix := range claudeCodePromptPrefixes {
if strings.HasPrefix(text, prefix) {
return true
}
}
return false
}
// injectClaudeCodePrompt 在 system 开头注入 Claude Code 提示词
// 处理 null、字符串、数组三种格式
func injectClaudeCodePrompt(body []byte, system any) []byte {
@@ -1073,6 +1102,124 @@ func injectClaudeCodePrompt(body []byte, system any) []byte {
return result
}
// enforceCacheControlLimit 强制执行 cache_control 块数量限制(最多 4 个)
// 超限时优先从 messages 中移除 cache_control保护 system 中的缓存控制
func enforceCacheControlLimit(body []byte) []byte {
var data map[string]any
if err := json.Unmarshal(body, &data); err != nil {
return body
}
// 计算当前 cache_control 块数量
count := countCacheControlBlocks(data)
if count <= maxCacheControlBlocks {
return body
}
// 超限:优先从 messages 中移除,再从 system 中移除
for count > maxCacheControlBlocks {
if removeCacheControlFromMessages(data) {
count--
continue
}
if removeCacheControlFromSystem(data) {
count--
continue
}
break
}
result, err := json.Marshal(data)
if err != nil {
return body
}
return result
}
// countCacheControlBlocks 统计 system 和 messages 中的 cache_control 块数量
func countCacheControlBlocks(data map[string]any) int {
count := 0
// 统计 system 中的块
if system, ok := data["system"].([]any); ok {
for _, item := range system {
if m, ok := item.(map[string]any); ok {
if _, has := m["cache_control"]; has {
count++
}
}
}
}
// 统计 messages 中的块
if messages, ok := data["messages"].([]any); ok {
for _, msg := range messages {
if msgMap, ok := msg.(map[string]any); ok {
if content, ok := msgMap["content"].([]any); ok {
for _, item := range content {
if m, ok := item.(map[string]any); ok {
if _, has := m["cache_control"]; has {
count++
}
}
}
}
}
}
}
return count
}
// removeCacheControlFromMessages 从 messages 中移除一个 cache_control从头开始
// 返回 true 表示成功移除false 表示没有可移除的
func removeCacheControlFromMessages(data map[string]any) bool {
messages, ok := data["messages"].([]any)
if !ok {
return false
}
for _, msg := range messages {
msgMap, ok := msg.(map[string]any)
if !ok {
continue
}
content, ok := msgMap["content"].([]any)
if !ok {
continue
}
for _, item := range content {
if m, ok := item.(map[string]any); ok {
if _, has := m["cache_control"]; has {
delete(m, "cache_control")
return true
}
}
}
}
return false
}
// removeCacheControlFromSystem 从 system 中移除一个 cache_control从尾部开始保护注入的 prompt
// 返回 true 表示成功移除false 表示没有可移除的
func removeCacheControlFromSystem(data map[string]any) bool {
system, ok := data["system"].([]any)
if !ok {
return false
}
// 从尾部开始移除,保护开头注入的 Claude Code prompt
for i := len(system) - 1; i >= 0; i-- {
if m, ok := system[i].(map[string]any); ok {
if _, has := m["cache_control"]; has {
delete(m, "cache_control")
return true
}
}
}
return false
}
// Forward 转发请求到Claude API
func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) (*ForwardResult, error) {
startTime := time.Now()
@@ -1093,6 +1240,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
body = injectClaudeCodePrompt(body, parsed.System)
}
// 强制执行 cache_control 块数量限制(最多 4 个)
body = enforceCacheControlLimit(body)
// 应用模型映射仅对apikey类型账号
originalModel := reqModel
if account.Type == AccountTypeAPIKey {
@@ -1316,6 +1466,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
// 处理正常响应
var usage *ClaudeUsage
var firstTokenMs *int
var clientDisconnect bool
if reqStream {
streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, reqModel)
if err != nil {
@@ -1328,6 +1479,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
}
usage = streamResult.usage
firstTokenMs = streamResult.firstTokenMs
clientDisconnect = streamResult.clientDisconnect
} else {
usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, reqModel)
if err != nil {
@@ -1336,12 +1488,13 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
}
return &ForwardResult{
RequestID: resp.Header.Get("x-request-id"),
Usage: *usage,
Model: originalModel, // 使用原始模型用于计费和日志
Stream: reqStream,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
RequestID: resp.Header.Get("x-request-id"),
Usage: *usage,
Model: originalModel, // 使用原始模型用于计费和日志
Stream: reqStream,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
ClientDisconnect: clientDisconnect,
}, nil
}
@@ -1696,8 +1849,9 @@ func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *ht
// streamingResult 流式响应结果
type streamingResult struct {
usage *ClaudeUsage
firstTokenMs *int
usage *ClaudeUsage
firstTokenMs *int
clientDisconnect bool // 客户端是否在流式传输过程中断开
}
func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string) (*streamingResult, error) {
@@ -1793,14 +1947,27 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
}
needModelReplace := originalModel != mappedModel
clientDisconnected := false // 客户端断开标志断开后继续读取上游以获取完整usage
for {
select {
case ev, ok := <-events:
if !ok {
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
// 上游完成,返回结果
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, nil
}
if ev.err != nil {
// 检测 context 取消(客户端断开会导致 context 取消,进而影响上游读取)
if errors.Is(ev.err, context.Canceled) || errors.Is(ev.err, context.DeadlineExceeded) {
log.Printf("Context canceled during streaming, returning collected usage")
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
}
// 客户端已通过写入失败检测到断开,上游也出错了,返回已收集的 usage
if clientDisconnected {
log.Printf("Upstream read error after client disconnect: %v, returning collected usage", ev.err)
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
}
// 客户端未断开,正常的错误处理
if errors.Is(ev.err, bufio.ErrTooLong) {
log.Printf("SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err)
sendErrorEvent("response_too_large")
@@ -1811,38 +1978,40 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
}
line := ev.line
if line == "event: error" {
// 上游返回错误事件,如果客户端已断开仍返回已收集的 usage
if clientDisconnected {
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
}
return nil, errors.New("have error in stream")
}
// Extract data from SSE line (supports both "data: " and "data:" formats)
var data string
if sseDataRe.MatchString(line) {
data := sseDataRe.ReplaceAllString(line, "")
data = sseDataRe.ReplaceAllString(line, "")
// 如果有模型映射替换响应中的model字段
if needModelReplace {
line = s.replaceModelInSSELine(line, mappedModel, originalModel)
}
}
// 转发行
// 写入客户端(统一处理 data 行和非 data 行)
if !clientDisconnected {
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
sendErrorEvent("write_failed")
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
clientDisconnected = true
log.Printf("Client disconnected during streaming, continuing to drain upstream for billing")
} else {
flusher.Flush()
}
flusher.Flush()
}
// 记录首字时间:第一个有效的 content_block_delta 或 message_start
if firstTokenMs == nil && data != "" && data != "[DONE]" {
// 无论客户端是否断开,都解析 usage仅对 data 行)
if data != "" {
if firstTokenMs == nil && data != "[DONE]" {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
s.parseSSEUsage(data, usage)
} else {
// 非 data 行直接转发
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
sendErrorEvent("write_failed")
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
}
flusher.Flush()
}
case <-intervalCh:
@@ -1850,6 +2019,11 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
if time.Since(lastRead) < streamInterval {
continue
}
if clientDisconnected {
// 客户端已断开,上游也超时了,返回已收集的 usage
log.Printf("Upstream timeout after client disconnect, returning collected usage")
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
}
log.Printf("Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval)
sendErrorEvent("stream_timeout")
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
@@ -2003,6 +2177,7 @@ type RecordUsageInput struct {
User *User
Account *Account
Subscription *UserSubscription // 可选:订阅信息
UserAgent string // 请求的 User-Agent
}
// RecordUsage 记录使用量并扣费(或更新订阅用量)
@@ -2088,6 +2263,11 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
CreatedAt: time.Now(),
}
// 添加 UserAgent
if input.UserAgent != "" {
usageLog.UserAgent = &input.UserAgent
}
// 添加分组和订阅关联
if apiKey.GroupID != nil {
usageLog.GroupID = apiKey.GroupID

View File

@@ -90,6 +90,9 @@ func (m *mockAccountRepoForGemini) SetError(ctx context.Context, id int64, error
func (m *mockAccountRepoForGemini) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
return nil
}
func (m *mockAccountRepoForGemini) AutoPauseExpiredAccounts(ctx context.Context, now time.Time) (int64, error) {
return 0, nil
}
func (m *mockAccountRepoForGemini) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error {
return nil
}

View File

@@ -1092,6 +1092,7 @@ type OpenAIRecordUsageInput struct {
User *User
Account *Account
Subscription *UserSubscription
UserAgent string // 请求的 User-Agent
}
// RecordUsage records usage and deducts balance
@@ -1161,6 +1162,11 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
CreatedAt: time.Now(),
}
// 添加 UserAgent
if input.UserAgent != "" {
usageLog.UserAgent = &input.UserAgent
}
if apiKey.GroupID != nil {
usageLog.GroupID = apiKey.GroupID
}

View File

@@ -38,6 +38,7 @@ type UsageLog struct {
Stream bool
DurationMs *int
FirstTokenMs *int
UserAgent *string
// 图片生成字段
ImageCount int

View File

@@ -319,3 +319,12 @@ func (s *UsageService) GetGlobalStats(ctx context.Context, startTime, endTime ti
}
return stats, nil
}
// GetStatsWithFilters returns usage stats with optional filters.
func (s *UsageService) GetStatsWithFilters(ctx context.Context, filters usagestats.UsageLogFilters) (*usagestats.UsageStats, error) {
stats, err := s.usageRepo.GetStatsWithFilters(ctx, filters)
if err != nil {
return nil, fmt.Errorf("get usage stats with filters: %w", err)
}
return stats, nil
}

View File

@@ -47,6 +47,13 @@ func ProvideTokenRefreshService(
return svc
}
// ProvideAccountExpiryService creates and starts AccountExpiryService.
func ProvideAccountExpiryService(accountRepo AccountRepository) *AccountExpiryService {
svc := NewAccountExpiryService(accountRepo, time.Minute)
svc.Start()
return svc
}
// ProvideTimingWheelService creates and starts TimingWheelService
func ProvideTimingWheelService() *TimingWheelService {
svc := NewTimingWheelService()
@@ -110,6 +117,7 @@ var ProviderSet = wire.NewSet(
NewCRSSyncService,
ProvideUpdateService,
ProvideTokenRefreshService,
ProvideAccountExpiryService,
ProvideTimingWheelService,
ProvideDeferredService,
NewAntigravityQuotaFetcher,