merge: 合并 upstream/main
This commit is contained in:
@@ -3,6 +3,7 @@ package admin
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
@@ -13,6 +14,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
// OAuthHandler handles OAuth-related operations for accounts
|
||||
@@ -989,3 +991,164 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) {
|
||||
|
||||
response.Success(c, models)
|
||||
}
|
||||
|
||||
// RefreshTier handles refreshing Google One tier for a single account
|
||||
// POST /api/v1/admin/accounts/:id/refresh-tier
|
||||
func (h *AccountHandler) RefreshTier(c *gin.Context) {
|
||||
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid account ID")
|
||||
return
|
||||
}
|
||||
|
||||
ctx := c.Request.Context()
|
||||
account, err := h.adminService.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
response.NotFound(c, "Account not found")
|
||||
return
|
||||
}
|
||||
|
||||
if account.Platform != service.PlatformGemini || account.Type != service.AccountTypeOAuth {
|
||||
response.BadRequest(c, "Only Gemini OAuth accounts support tier refresh")
|
||||
return
|
||||
}
|
||||
|
||||
oauthType, _ := account.Credentials["oauth_type"].(string)
|
||||
if oauthType != "google_one" {
|
||||
response.BadRequest(c, "Only google_one OAuth accounts support tier refresh")
|
||||
return
|
||||
}
|
||||
|
||||
tierID, extra, creds, err := h.geminiOAuthService.RefreshAccountGoogleOneTier(ctx, account)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
_, updateErr := h.adminService.UpdateAccount(ctx, accountID, &service.UpdateAccountInput{
|
||||
Credentials: creds,
|
||||
Extra: extra,
|
||||
})
|
||||
if updateErr != nil {
|
||||
response.ErrorFrom(c, updateErr)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"tier_id": tierID,
|
||||
"storage_info": extra,
|
||||
"drive_storage_limit": extra["drive_storage_limit"],
|
||||
"drive_storage_usage": extra["drive_storage_usage"],
|
||||
"updated_at": extra["drive_tier_updated_at"],
|
||||
})
|
||||
}
|
||||
|
||||
// BatchRefreshTierRequest represents batch tier refresh request
|
||||
type BatchRefreshTierRequest struct {
|
||||
AccountIDs []int64 `json:"account_ids"`
|
||||
}
|
||||
|
||||
// BatchRefreshTier handles batch refreshing Google One tier
|
||||
// POST /api/v1/admin/accounts/batch-refresh-tier
|
||||
func (h *AccountHandler) BatchRefreshTier(c *gin.Context) {
|
||||
var req BatchRefreshTierRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
req = BatchRefreshTierRequest{}
|
||||
}
|
||||
|
||||
ctx := c.Request.Context()
|
||||
accounts := make([]*service.Account, 0)
|
||||
|
||||
if len(req.AccountIDs) == 0 {
|
||||
allAccounts, _, err := h.adminService.ListAccounts(ctx, 1, 10000, "gemini", "oauth", "", "")
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
for i := range allAccounts {
|
||||
acc := &allAccounts[i]
|
||||
oauthType, _ := acc.Credentials["oauth_type"].(string)
|
||||
if oauthType == "google_one" {
|
||||
accounts = append(accounts, acc)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
fetched, err := h.adminService.GetAccountsByIDs(ctx, req.AccountIDs)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, acc := range fetched {
|
||||
if acc == nil {
|
||||
continue
|
||||
}
|
||||
if acc.Platform != service.PlatformGemini || acc.Type != service.AccountTypeOAuth {
|
||||
continue
|
||||
}
|
||||
oauthType, _ := acc.Credentials["oauth_type"].(string)
|
||||
if oauthType != "google_one" {
|
||||
continue
|
||||
}
|
||||
accounts = append(accounts, acc)
|
||||
}
|
||||
}
|
||||
|
||||
const maxConcurrency = 10
|
||||
g, gctx := errgroup.WithContext(ctx)
|
||||
g.SetLimit(maxConcurrency)
|
||||
|
||||
var mu sync.Mutex
|
||||
var successCount, failedCount int
|
||||
var errors []gin.H
|
||||
|
||||
for _, account := range accounts {
|
||||
acc := account // 闭包捕获
|
||||
g.Go(func() error {
|
||||
_, extra, creds, err := h.geminiOAuthService.RefreshAccountGoogleOneTier(gctx, acc)
|
||||
if err != nil {
|
||||
mu.Lock()
|
||||
failedCount++
|
||||
errors = append(errors, gin.H{
|
||||
"account_id": acc.ID,
|
||||
"error": err.Error(),
|
||||
})
|
||||
mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
_, updateErr := h.adminService.UpdateAccount(gctx, acc.ID, &service.UpdateAccountInput{
|
||||
Credentials: creds,
|
||||
Extra: extra,
|
||||
})
|
||||
|
||||
mu.Lock()
|
||||
if updateErr != nil {
|
||||
failedCount++
|
||||
errors = append(errors, gin.H{
|
||||
"account_id": acc.ID,
|
||||
"error": updateErr.Error(),
|
||||
})
|
||||
} else {
|
||||
successCount++
|
||||
}
|
||||
mu.Unlock()
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
if err := g.Wait(); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
results := gin.H{
|
||||
"total": len(accounts),
|
||||
"success": successCount,
|
||||
"failed": failedCount,
|
||||
"errors": errors,
|
||||
}
|
||||
|
||||
response.Success(c, results)
|
||||
}
|
||||
|
||||
@@ -46,8 +46,8 @@ func (h *GeminiOAuthHandler) GenerateAuthURL(c *gin.Context) {
|
||||
if oauthType == "" {
|
||||
oauthType = "code_assist"
|
||||
}
|
||||
if oauthType != "code_assist" && oauthType != "ai_studio" {
|
||||
response.BadRequest(c, "Invalid oauth_type: must be 'code_assist' or 'ai_studio'")
|
||||
if oauthType != "code_assist" && oauthType != "google_one" && oauthType != "ai_studio" {
|
||||
response.BadRequest(c, "Invalid oauth_type: must be 'code_assist', 'google_one', or 'ai_studio'")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -92,8 +92,8 @@ func (h *GeminiOAuthHandler) ExchangeCode(c *gin.Context) {
|
||||
if oauthType == "" {
|
||||
oauthType = "code_assist"
|
||||
}
|
||||
if oauthType != "code_assist" && oauthType != "ai_studio" {
|
||||
response.BadRequest(c, "Invalid oauth_type: must be 'code_assist' or 'ai_studio'")
|
||||
if oauthType != "code_assist" && oauthType != "google_one" && oauthType != "ai_studio" {
|
||||
response.BadRequest(c, "Invalid oauth_type: must be 'code_assist', 'google_one', or 'ai_studio'")
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -246,7 +246,7 @@ func (h *UsageHandler) SearchUsers(c *gin.Context) {
|
||||
}
|
||||
|
||||
// Limit to 30 results
|
||||
users, _, err := h.adminService.ListUsers(c.Request.Context(), 1, 30, "", "", keyword)
|
||||
users, _, err := h.adminService.ListUsers(c.Request.Context(), 1, 30, service.UserListFilters{Search: keyword})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
|
||||
342
backend/internal/handler/admin/user_attribute_handler.go
Normal file
342
backend/internal/handler/admin/user_attribute_handler.go
Normal file
@@ -0,0 +1,342 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// UserAttributeHandler handles user attribute management
|
||||
type UserAttributeHandler struct {
|
||||
attrService *service.UserAttributeService
|
||||
}
|
||||
|
||||
// NewUserAttributeHandler creates a new handler
|
||||
func NewUserAttributeHandler(attrService *service.UserAttributeService) *UserAttributeHandler {
|
||||
return &UserAttributeHandler{attrService: attrService}
|
||||
}
|
||||
|
||||
// --- Request/Response DTOs ---
|
||||
|
||||
// CreateAttributeDefinitionRequest represents create attribute definition request
|
||||
type CreateAttributeDefinitionRequest struct {
|
||||
Key string `json:"key" binding:"required,min=1,max=100"`
|
||||
Name string `json:"name" binding:"required,min=1,max=255"`
|
||||
Description string `json:"description"`
|
||||
Type string `json:"type" binding:"required"`
|
||||
Options []service.UserAttributeOption `json:"options"`
|
||||
Required bool `json:"required"`
|
||||
Validation service.UserAttributeValidation `json:"validation"`
|
||||
Placeholder string `json:"placeholder"`
|
||||
Enabled bool `json:"enabled"`
|
||||
}
|
||||
|
||||
// UpdateAttributeDefinitionRequest represents update attribute definition request
|
||||
type UpdateAttributeDefinitionRequest struct {
|
||||
Name *string `json:"name"`
|
||||
Description *string `json:"description"`
|
||||
Type *string `json:"type"`
|
||||
Options *[]service.UserAttributeOption `json:"options"`
|
||||
Required *bool `json:"required"`
|
||||
Validation *service.UserAttributeValidation `json:"validation"`
|
||||
Placeholder *string `json:"placeholder"`
|
||||
Enabled *bool `json:"enabled"`
|
||||
}
|
||||
|
||||
// ReorderRequest represents reorder attribute definitions request
|
||||
type ReorderRequest struct {
|
||||
IDs []int64 `json:"ids" binding:"required"`
|
||||
}
|
||||
|
||||
// UpdateUserAttributesRequest represents update user attributes request
|
||||
type UpdateUserAttributesRequest struct {
|
||||
Values map[int64]string `json:"values" binding:"required"`
|
||||
}
|
||||
|
||||
// BatchGetUserAttributesRequest represents batch get user attributes request
|
||||
type BatchGetUserAttributesRequest struct {
|
||||
UserIDs []int64 `json:"user_ids" binding:"required"`
|
||||
}
|
||||
|
||||
// BatchUserAttributesResponse represents batch user attributes response
|
||||
type BatchUserAttributesResponse struct {
|
||||
// Map of userID -> map of attributeID -> value
|
||||
Attributes map[int64]map[int64]string `json:"attributes"`
|
||||
}
|
||||
|
||||
// AttributeDefinitionResponse represents attribute definition response
|
||||
type AttributeDefinitionResponse struct {
|
||||
ID int64 `json:"id"`
|
||||
Key string `json:"key"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Type string `json:"type"`
|
||||
Options []service.UserAttributeOption `json:"options"`
|
||||
Required bool `json:"required"`
|
||||
Validation service.UserAttributeValidation `json:"validation"`
|
||||
Placeholder string `json:"placeholder"`
|
||||
DisplayOrder int `json:"display_order"`
|
||||
Enabled bool `json:"enabled"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
UpdatedAt string `json:"updated_at"`
|
||||
}
|
||||
|
||||
// AttributeValueResponse represents attribute value response
|
||||
type AttributeValueResponse struct {
|
||||
ID int64 `json:"id"`
|
||||
UserID int64 `json:"user_id"`
|
||||
AttributeID int64 `json:"attribute_id"`
|
||||
Value string `json:"value"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
UpdatedAt string `json:"updated_at"`
|
||||
}
|
||||
|
||||
// --- Helpers ---
|
||||
|
||||
func defToResponse(def *service.UserAttributeDefinition) *AttributeDefinitionResponse {
|
||||
return &AttributeDefinitionResponse{
|
||||
ID: def.ID,
|
||||
Key: def.Key,
|
||||
Name: def.Name,
|
||||
Description: def.Description,
|
||||
Type: string(def.Type),
|
||||
Options: def.Options,
|
||||
Required: def.Required,
|
||||
Validation: def.Validation,
|
||||
Placeholder: def.Placeholder,
|
||||
DisplayOrder: def.DisplayOrder,
|
||||
Enabled: def.Enabled,
|
||||
CreatedAt: def.CreatedAt.Format("2006-01-02T15:04:05Z07:00"),
|
||||
UpdatedAt: def.UpdatedAt.Format("2006-01-02T15:04:05Z07:00"),
|
||||
}
|
||||
}
|
||||
|
||||
func valueToResponse(val *service.UserAttributeValue) *AttributeValueResponse {
|
||||
return &AttributeValueResponse{
|
||||
ID: val.ID,
|
||||
UserID: val.UserID,
|
||||
AttributeID: val.AttributeID,
|
||||
Value: val.Value,
|
||||
CreatedAt: val.CreatedAt.Format("2006-01-02T15:04:05Z07:00"),
|
||||
UpdatedAt: val.UpdatedAt.Format("2006-01-02T15:04:05Z07:00"),
|
||||
}
|
||||
}
|
||||
|
||||
// --- Handlers ---
|
||||
|
||||
// ListDefinitions lists all attribute definitions
|
||||
// GET /admin/user-attributes
|
||||
func (h *UserAttributeHandler) ListDefinitions(c *gin.Context) {
|
||||
enabledOnly := c.Query("enabled") == "true"
|
||||
|
||||
defs, err := h.attrService.ListDefinitions(c.Request.Context(), enabledOnly)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]*AttributeDefinitionResponse, 0, len(defs))
|
||||
for i := range defs {
|
||||
out = append(out, defToResponse(&defs[i]))
|
||||
}
|
||||
|
||||
response.Success(c, out)
|
||||
}
|
||||
|
||||
// CreateDefinition creates a new attribute definition
|
||||
// POST /admin/user-attributes
|
||||
func (h *UserAttributeHandler) CreateDefinition(c *gin.Context) {
|
||||
var req CreateAttributeDefinitionRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
def, err := h.attrService.CreateDefinition(c.Request.Context(), service.CreateAttributeDefinitionInput{
|
||||
Key: req.Key,
|
||||
Name: req.Name,
|
||||
Description: req.Description,
|
||||
Type: service.UserAttributeType(req.Type),
|
||||
Options: req.Options,
|
||||
Required: req.Required,
|
||||
Validation: req.Validation,
|
||||
Placeholder: req.Placeholder,
|
||||
Enabled: req.Enabled,
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, defToResponse(def))
|
||||
}
|
||||
|
||||
// UpdateDefinition updates an attribute definition
|
||||
// PUT /admin/user-attributes/:id
|
||||
func (h *UserAttributeHandler) UpdateDefinition(c *gin.Context) {
|
||||
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid attribute ID")
|
||||
return
|
||||
}
|
||||
|
||||
var req UpdateAttributeDefinitionRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
input := service.UpdateAttributeDefinitionInput{
|
||||
Name: req.Name,
|
||||
Description: req.Description,
|
||||
Options: req.Options,
|
||||
Required: req.Required,
|
||||
Validation: req.Validation,
|
||||
Placeholder: req.Placeholder,
|
||||
Enabled: req.Enabled,
|
||||
}
|
||||
if req.Type != nil {
|
||||
t := service.UserAttributeType(*req.Type)
|
||||
input.Type = &t
|
||||
}
|
||||
|
||||
def, err := h.attrService.UpdateDefinition(c.Request.Context(), id, input)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, defToResponse(def))
|
||||
}
|
||||
|
||||
// DeleteDefinition deletes an attribute definition
|
||||
// DELETE /admin/user-attributes/:id
|
||||
func (h *UserAttributeHandler) DeleteDefinition(c *gin.Context) {
|
||||
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid attribute ID")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.attrService.DeleteDefinition(c.Request.Context(), id); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"message": "Attribute definition deleted successfully"})
|
||||
}
|
||||
|
||||
// ReorderDefinitions reorders attribute definitions
|
||||
// PUT /admin/user-attributes/reorder
|
||||
func (h *UserAttributeHandler) ReorderDefinitions(c *gin.Context) {
|
||||
var req ReorderRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Convert IDs array to orders map (position in array = display_order)
|
||||
orders := make(map[int64]int, len(req.IDs))
|
||||
for i, id := range req.IDs {
|
||||
orders[id] = i
|
||||
}
|
||||
|
||||
if err := h.attrService.ReorderDefinitions(c.Request.Context(), orders); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"message": "Reorder successful"})
|
||||
}
|
||||
|
||||
// GetUserAttributes gets a user's attribute values
|
||||
// GET /admin/users/:id/attributes
|
||||
func (h *UserAttributeHandler) GetUserAttributes(c *gin.Context) {
|
||||
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid user ID")
|
||||
return
|
||||
}
|
||||
|
||||
values, err := h.attrService.GetUserAttributes(c.Request.Context(), userID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]*AttributeValueResponse, 0, len(values))
|
||||
for i := range values {
|
||||
out = append(out, valueToResponse(&values[i]))
|
||||
}
|
||||
|
||||
response.Success(c, out)
|
||||
}
|
||||
|
||||
// UpdateUserAttributes updates a user's attribute values
|
||||
// PUT /admin/users/:id/attributes
|
||||
func (h *UserAttributeHandler) UpdateUserAttributes(c *gin.Context) {
|
||||
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid user ID")
|
||||
return
|
||||
}
|
||||
|
||||
var req UpdateUserAttributesRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
inputs := make([]service.UpdateUserAttributeInput, 0, len(req.Values))
|
||||
for attrID, value := range req.Values {
|
||||
inputs = append(inputs, service.UpdateUserAttributeInput{
|
||||
AttributeID: attrID,
|
||||
Value: value,
|
||||
})
|
||||
}
|
||||
|
||||
if err := h.attrService.UpdateUserAttributes(c.Request.Context(), userID, inputs); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Return updated values
|
||||
values, err := h.attrService.GetUserAttributes(c.Request.Context(), userID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]*AttributeValueResponse, 0, len(values))
|
||||
for i := range values {
|
||||
out = append(out, valueToResponse(&values[i]))
|
||||
}
|
||||
|
||||
response.Success(c, out)
|
||||
}
|
||||
|
||||
// GetBatchUserAttributes gets attribute values for multiple users
|
||||
// POST /admin/user-attributes/batch
|
||||
func (h *UserAttributeHandler) GetBatchUserAttributes(c *gin.Context) {
|
||||
var req BatchGetUserAttributesRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.UserIDs) == 0 {
|
||||
response.Success(c, BatchUserAttributesResponse{Attributes: map[int64]map[int64]string{}})
|
||||
return
|
||||
}
|
||||
|
||||
attrs, err := h.attrService.GetBatchUserAttributes(c.Request.Context(), req.UserIDs)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, BatchUserAttributesResponse{Attributes: attrs})
|
||||
}
|
||||
@@ -27,7 +27,6 @@ type CreateUserRequest struct {
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
Password string `json:"password" binding:"required,min=6"`
|
||||
Username string `json:"username"`
|
||||
Wechat string `json:"wechat"`
|
||||
Notes string `json:"notes"`
|
||||
Balance float64 `json:"balance"`
|
||||
Concurrency int `json:"concurrency"`
|
||||
@@ -40,7 +39,6 @@ type UpdateUserRequest struct {
|
||||
Email string `json:"email" binding:"omitempty,email"`
|
||||
Password string `json:"password" binding:"omitempty,min=6"`
|
||||
Username *string `json:"username"`
|
||||
Wechat *string `json:"wechat"`
|
||||
Notes *string `json:"notes"`
|
||||
Balance *float64 `json:"balance"`
|
||||
Concurrency *int `json:"concurrency"`
|
||||
@@ -57,13 +55,22 @@ type UpdateBalanceRequest struct {
|
||||
|
||||
// List handles listing all users with pagination
|
||||
// GET /api/v1/admin/users
|
||||
// Query params:
|
||||
// - status: filter by user status
|
||||
// - role: filter by user role
|
||||
// - search: search in email, username
|
||||
// - attr[{id}]: filter by custom attribute value, e.g. attr[1]=company
|
||||
func (h *UserHandler) List(c *gin.Context) {
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
status := c.Query("status")
|
||||
role := c.Query("role")
|
||||
search := c.Query("search")
|
||||
|
||||
users, total, err := h.adminService.ListUsers(c.Request.Context(), page, pageSize, status, role, search)
|
||||
filters := service.UserListFilters{
|
||||
Status: c.Query("status"),
|
||||
Role: c.Query("role"),
|
||||
Search: c.Query("search"),
|
||||
Attributes: parseAttributeFilters(c),
|
||||
}
|
||||
|
||||
users, total, err := h.adminService.ListUsers(c.Request.Context(), page, pageSize, filters)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
@@ -76,6 +83,29 @@ func (h *UserHandler) List(c *gin.Context) {
|
||||
response.Paginated(c, out, total, page, pageSize)
|
||||
}
|
||||
|
||||
// parseAttributeFilters extracts attribute filters from query params
|
||||
// Format: attr[{attributeID}]=value, e.g. attr[1]=company&attr[2]=developer
|
||||
func parseAttributeFilters(c *gin.Context) map[int64]string {
|
||||
result := make(map[int64]string)
|
||||
|
||||
// Get all query params and look for attr[*] pattern
|
||||
for key, values := range c.Request.URL.Query() {
|
||||
if len(values) == 0 || values[0] == "" {
|
||||
continue
|
||||
}
|
||||
// Check if key matches pattern attr[{id}]
|
||||
if len(key) > 5 && key[:5] == "attr[" && key[len(key)-1] == ']' {
|
||||
idStr := key[5 : len(key)-1]
|
||||
id, err := strconv.ParseInt(idStr, 10, 64)
|
||||
if err == nil && id > 0 {
|
||||
result[id] = values[0]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// GetByID handles getting a user by ID
|
||||
// GET /api/v1/admin/users/:id
|
||||
func (h *UserHandler) GetByID(c *gin.Context) {
|
||||
@@ -107,7 +137,6 @@ func (h *UserHandler) Create(c *gin.Context) {
|
||||
Email: req.Email,
|
||||
Password: req.Password,
|
||||
Username: req.Username,
|
||||
Wechat: req.Wechat,
|
||||
Notes: req.Notes,
|
||||
Balance: req.Balance,
|
||||
Concurrency: req.Concurrency,
|
||||
@@ -141,7 +170,6 @@ func (h *UserHandler) Update(c *gin.Context) {
|
||||
Email: req.Email,
|
||||
Password: req.Password,
|
||||
Username: req.Username,
|
||||
Wechat: req.Wechat,
|
||||
Notes: req.Notes,
|
||||
Balance: req.Balance,
|
||||
Concurrency: req.Concurrency,
|
||||
|
||||
@@ -10,7 +10,6 @@ func UserFromServiceShallow(u *service.User) *User {
|
||||
ID: u.ID,
|
||||
Email: u.Email,
|
||||
Username: u.Username,
|
||||
Wechat: u.Wechat,
|
||||
Notes: u.Notes,
|
||||
Role: u.Role,
|
||||
Balance: u.Balance,
|
||||
|
||||
@@ -6,7 +6,6 @@ type User struct {
|
||||
ID int64 `json:"id"`
|
||||
Email string `json:"email"`
|
||||
Username string `json:"username"`
|
||||
Wechat string `json:"wechat"`
|
||||
Notes string `json:"notes"`
|
||||
Role string `json:"role"`
|
||||
Balance float64 `json:"balance"`
|
||||
|
||||
@@ -142,6 +142,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
} else if apiKey.Group != nil {
|
||||
platform = apiKey.Group.Platform
|
||||
}
|
||||
sessionKey := sessionHash
|
||||
if platform == service.PlatformGemini && sessionHash != "" {
|
||||
sessionKey = "gemini:" + sessionHash
|
||||
}
|
||||
|
||||
if platform == service.PlatformGemini {
|
||||
const maxAccountSwitches = 3
|
||||
@@ -150,7 +154,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
lastFailoverStatus := 0
|
||||
|
||||
for {
|
||||
account, err := h.geminiCompatService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs)
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs)
|
||||
if err != nil {
|
||||
if len(failedAccountIDs) == 0 {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
||||
@@ -159,9 +163,13 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
|
||||
return
|
||||
}
|
||||
account := selection.Account
|
||||
|
||||
// 检查预热请求拦截(在账号选择后、转发前检查)
|
||||
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
|
||||
if selection.Acquired && selection.ReleaseFunc != nil {
|
||||
selection.ReleaseFunc()
|
||||
}
|
||||
if reqStream {
|
||||
sendMockWarmupStream(c, reqModel)
|
||||
} else {
|
||||
@@ -171,11 +179,46 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 3. 获取账号并发槽位
|
||||
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, reqStream, &streamStarted)
|
||||
if err != nil {
|
||||
log.Printf("Account concurrency acquire failed: %v", err)
|
||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||
return
|
||||
accountReleaseFunc := selection.ReleaseFunc
|
||||
var accountWaitRelease func()
|
||||
if !selection.Acquired {
|
||||
if selection.WaitPlan == nil {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
|
||||
return
|
||||
}
|
||||
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
|
||||
if err != nil {
|
||||
log.Printf("Increment account wait count failed: %v", err)
|
||||
} else if !canWait {
|
||||
log.Printf("Account wait queue full: account=%d", account.ID)
|
||||
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
|
||||
return
|
||||
} else {
|
||||
// Only set release function if increment succeeded
|
||||
accountWaitRelease = func() {
|
||||
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
|
||||
}
|
||||
}
|
||||
|
||||
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
|
||||
c,
|
||||
account.ID,
|
||||
selection.WaitPlan.MaxConcurrency,
|
||||
selection.WaitPlan.Timeout,
|
||||
reqStream,
|
||||
&streamStarted,
|
||||
)
|
||||
if err != nil {
|
||||
if accountWaitRelease != nil {
|
||||
accountWaitRelease()
|
||||
}
|
||||
log.Printf("Account concurrency acquire failed: %v", err)
|
||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||
return
|
||||
}
|
||||
if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionKey, account.ID); err != nil {
|
||||
log.Printf("Bind sticky session failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 转发请求 - 根据账号平台分流
|
||||
@@ -188,6 +231,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
if accountReleaseFunc != nil {
|
||||
accountReleaseFunc()
|
||||
}
|
||||
if accountWaitRelease != nil {
|
||||
accountWaitRelease()
|
||||
}
|
||||
if err != nil {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
@@ -232,7 +278,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
|
||||
for {
|
||||
// 选择支持该模型的账号
|
||||
account, err := h.gatewayService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs)
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs)
|
||||
if err != nil {
|
||||
if len(failedAccountIDs) == 0 {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
||||
@@ -241,9 +287,13 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
|
||||
return
|
||||
}
|
||||
account := selection.Account
|
||||
|
||||
// 检查预热请求拦截(在账号选择后、转发前检查)
|
||||
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
|
||||
if selection.Acquired && selection.ReleaseFunc != nil {
|
||||
selection.ReleaseFunc()
|
||||
}
|
||||
if reqStream {
|
||||
sendMockWarmupStream(c, reqModel)
|
||||
} else {
|
||||
@@ -253,11 +303,46 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 3. 获取账号并发槽位
|
||||
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, reqStream, &streamStarted)
|
||||
if err != nil {
|
||||
log.Printf("Account concurrency acquire failed: %v", err)
|
||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||
return
|
||||
accountReleaseFunc := selection.ReleaseFunc
|
||||
var accountWaitRelease func()
|
||||
if !selection.Acquired {
|
||||
if selection.WaitPlan == nil {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
|
||||
return
|
||||
}
|
||||
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
|
||||
if err != nil {
|
||||
log.Printf("Increment account wait count failed: %v", err)
|
||||
} else if !canWait {
|
||||
log.Printf("Account wait queue full: account=%d", account.ID)
|
||||
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
|
||||
return
|
||||
} else {
|
||||
// Only set release function if increment succeeded
|
||||
accountWaitRelease = func() {
|
||||
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
|
||||
}
|
||||
}
|
||||
|
||||
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
|
||||
c,
|
||||
account.ID,
|
||||
selection.WaitPlan.MaxConcurrency,
|
||||
selection.WaitPlan.Timeout,
|
||||
reqStream,
|
||||
&streamStarted,
|
||||
)
|
||||
if err != nil {
|
||||
if accountWaitRelease != nil {
|
||||
accountWaitRelease()
|
||||
}
|
||||
log.Printf("Account concurrency acquire failed: %v", err)
|
||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||
return
|
||||
}
|
||||
if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionKey, account.ID); err != nil {
|
||||
log.Printf("Bind sticky session failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 转发请求 - 根据账号平台分流
|
||||
@@ -270,6 +355,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
if accountReleaseFunc != nil {
|
||||
accountReleaseFunc()
|
||||
}
|
||||
if accountWaitRelease != nil {
|
||||
accountWaitRelease()
|
||||
}
|
||||
if err != nil {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
@@ -309,7 +397,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
|
||||
// Models handles listing available models
|
||||
// GET /v1/models
|
||||
// Returns different model lists based on the API key's group platform or forced platform
|
||||
// Returns models based on account configurations (model_mapping whitelist)
|
||||
// Falls back to default models if no whitelist is configured
|
||||
func (h *GatewayHandler) Models(c *gin.Context) {
|
||||
apiKey, _ := middleware2.GetApiKeyFromContext(c)
|
||||
|
||||
@@ -324,8 +413,37 @@ func (h *GatewayHandler) Models(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// Return OpenAI models for OpenAI platform groups
|
||||
if apiKey != nil && apiKey.Group != nil && apiKey.Group.Platform == "openai" {
|
||||
var groupID *int64
|
||||
var platform string
|
||||
|
||||
if apiKey != nil && apiKey.Group != nil {
|
||||
groupID = &apiKey.Group.ID
|
||||
platform = apiKey.Group.Platform
|
||||
}
|
||||
|
||||
// Get available models from account configurations (without platform filter)
|
||||
availableModels := h.gatewayService.GetAvailableModels(c.Request.Context(), groupID, "")
|
||||
|
||||
if len(availableModels) > 0 {
|
||||
// Build model list from whitelist
|
||||
models := make([]claude.Model, 0, len(availableModels))
|
||||
for _, modelID := range availableModels {
|
||||
models = append(models, claude.Model{
|
||||
ID: modelID,
|
||||
Type: "model",
|
||||
DisplayName: modelID,
|
||||
CreatedAt: "2024-01-01T00:00:00Z",
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"object": "list",
|
||||
"data": models,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Fallback to default models
|
||||
if platform == "openai" {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"object": "list",
|
||||
"data": openai.DefaultModels,
|
||||
@@ -333,7 +451,6 @@ func (h *GatewayHandler) Models(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Default: Claude models
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"object": "list",
|
||||
"data": claude.DefaultModels,
|
||||
|
||||
@@ -83,6 +83,16 @@ func (h *ConcurrencyHelper) DecrementWaitCount(ctx context.Context, userID int64
|
||||
h.concurrencyService.DecrementWaitCount(ctx, userID)
|
||||
}
|
||||
|
||||
// IncrementAccountWaitCount increments the wait count for an account
|
||||
func (h *ConcurrencyHelper) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
|
||||
return h.concurrencyService.IncrementAccountWaitCount(ctx, accountID, maxWait)
|
||||
}
|
||||
|
||||
// DecrementAccountWaitCount decrements the wait count for an account
|
||||
func (h *ConcurrencyHelper) DecrementAccountWaitCount(ctx context.Context, accountID int64) {
|
||||
h.concurrencyService.DecrementAccountWaitCount(ctx, accountID)
|
||||
}
|
||||
|
||||
// AcquireUserSlotWithWait acquires a user concurrency slot, waiting if necessary.
|
||||
// For streaming requests, sends ping events during the wait.
|
||||
// streamStarted is updated if streaming response has begun.
|
||||
@@ -126,7 +136,12 @@ func (h *ConcurrencyHelper) AcquireAccountSlotWithWait(c *gin.Context, accountID
|
||||
// waitForSlotWithPing waits for a concurrency slot, sending ping events for streaming requests.
|
||||
// streamStarted pointer is updated when streaming begins (for proper error handling by caller).
|
||||
func (h *ConcurrencyHelper) waitForSlotWithPing(c *gin.Context, slotType string, id int64, maxConcurrency int, isStream bool, streamStarted *bool) (func(), error) {
|
||||
ctx, cancel := context.WithTimeout(c.Request.Context(), maxConcurrencyWait)
|
||||
return h.waitForSlotWithPingTimeout(c, slotType, id, maxConcurrency, maxConcurrencyWait, isStream, streamStarted)
|
||||
}
|
||||
|
||||
// waitForSlotWithPingTimeout waits for a concurrency slot with a custom timeout.
|
||||
func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType string, id int64, maxConcurrency int, timeout time.Duration, isStream bool, streamStarted *bool) (func(), error) {
|
||||
ctx, cancel := context.WithTimeout(c.Request.Context(), timeout)
|
||||
defer cancel()
|
||||
|
||||
// Determine if ping is needed (streaming + ping format defined)
|
||||
@@ -200,6 +215,11 @@ func (h *ConcurrencyHelper) waitForSlotWithPing(c *gin.Context, slotType string,
|
||||
}
|
||||
}
|
||||
|
||||
// AcquireAccountSlotWithWaitTimeout acquires an account slot with a custom timeout (keeps SSE ping).
|
||||
func (h *ConcurrencyHelper) AcquireAccountSlotWithWaitTimeout(c *gin.Context, accountID int64, maxConcurrency int, timeout time.Duration, isStream bool, streamStarted *bool) (func(), error) {
|
||||
return h.waitForSlotWithPingTimeout(c, "account", accountID, maxConcurrency, timeout, isStream, streamStarted)
|
||||
}
|
||||
|
||||
// nextBackoff 计算下一次退避时间
|
||||
// 性能优化:使用指数退避 + 随机抖动,避免惊群效应
|
||||
// current: 当前退避时间
|
||||
|
||||
@@ -198,13 +198,17 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
// 3) select account (sticky session based on request body)
|
||||
parsedReq, _ := service.ParseGatewayRequest(body)
|
||||
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
|
||||
sessionKey := sessionHash
|
||||
if sessionHash != "" {
|
||||
sessionKey = "gemini:" + sessionHash
|
||||
}
|
||||
const maxAccountSwitches = 3
|
||||
switchCount := 0
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
lastFailoverStatus := 0
|
||||
|
||||
for {
|
||||
account, err := h.geminiCompatService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, modelName, failedAccountIDs)
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, failedAccountIDs)
|
||||
if err != nil {
|
||||
if len(failedAccountIDs) == 0 {
|
||||
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
|
||||
@@ -213,12 +217,48 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
handleGeminiFailoverExhausted(c, lastFailoverStatus)
|
||||
return
|
||||
}
|
||||
account := selection.Account
|
||||
|
||||
// 4) account concurrency slot
|
||||
accountReleaseFunc, err := geminiConcurrency.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, stream, &streamStarted)
|
||||
if err != nil {
|
||||
googleError(c, http.StatusTooManyRequests, err.Error())
|
||||
return
|
||||
accountReleaseFunc := selection.ReleaseFunc
|
||||
var accountWaitRelease func()
|
||||
if !selection.Acquired {
|
||||
if selection.WaitPlan == nil {
|
||||
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts")
|
||||
return
|
||||
}
|
||||
canWait, err := geminiConcurrency.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
|
||||
if err != nil {
|
||||
log.Printf("Increment account wait count failed: %v", err)
|
||||
} else if !canWait {
|
||||
log.Printf("Account wait queue full: account=%d", account.ID)
|
||||
googleError(c, http.StatusTooManyRequests, "Too many pending requests, please retry later")
|
||||
return
|
||||
} else {
|
||||
// Only set release function if increment succeeded
|
||||
accountWaitRelease = func() {
|
||||
geminiConcurrency.DecrementAccountWaitCount(c.Request.Context(), account.ID)
|
||||
}
|
||||
}
|
||||
|
||||
accountReleaseFunc, err = geminiConcurrency.AcquireAccountSlotWithWaitTimeout(
|
||||
c,
|
||||
account.ID,
|
||||
selection.WaitPlan.MaxConcurrency,
|
||||
selection.WaitPlan.Timeout,
|
||||
stream,
|
||||
&streamStarted,
|
||||
)
|
||||
if err != nil {
|
||||
if accountWaitRelease != nil {
|
||||
accountWaitRelease()
|
||||
}
|
||||
googleError(c, http.StatusTooManyRequests, err.Error())
|
||||
return
|
||||
}
|
||||
if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionKey, account.ID); err != nil {
|
||||
log.Printf("Bind sticky session failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 5) forward (根据平台分流)
|
||||
@@ -231,6 +271,9 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
if accountReleaseFunc != nil {
|
||||
accountReleaseFunc()
|
||||
}
|
||||
if accountWaitRelease != nil {
|
||||
accountWaitRelease()
|
||||
}
|
||||
if err != nil {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
|
||||
@@ -20,6 +20,7 @@ type AdminHandlers struct {
|
||||
System *admin.SystemHandler
|
||||
Subscription *admin.SubscriptionHandler
|
||||
Usage *admin.UsageHandler
|
||||
UserAttribute *admin.UserAttributeHandler
|
||||
}
|
||||
|
||||
// Handlers contains all HTTP handlers
|
||||
|
||||
@@ -146,7 +146,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
for {
|
||||
// Select account supporting the requested model
|
||||
log.Printf("[OpenAI Handler] Selecting account: groupID=%v model=%s", apiKey.GroupID, reqModel)
|
||||
account, err := h.gatewayService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs)
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs)
|
||||
if err != nil {
|
||||
log.Printf("[OpenAI Handler] SelectAccount failed: %v", err)
|
||||
if len(failedAccountIDs) == 0 {
|
||||
@@ -156,14 +156,50 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
|
||||
return
|
||||
}
|
||||
account := selection.Account
|
||||
log.Printf("[OpenAI Handler] Selected account: id=%d name=%s", account.ID, account.Name)
|
||||
|
||||
// 3. Acquire account concurrency slot
|
||||
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, reqStream, &streamStarted)
|
||||
if err != nil {
|
||||
log.Printf("Account concurrency acquire failed: %v", err)
|
||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||
return
|
||||
accountReleaseFunc := selection.ReleaseFunc
|
||||
var accountWaitRelease func()
|
||||
if !selection.Acquired {
|
||||
if selection.WaitPlan == nil {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
|
||||
return
|
||||
}
|
||||
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
|
||||
if err != nil {
|
||||
log.Printf("Increment account wait count failed: %v", err)
|
||||
} else if !canWait {
|
||||
log.Printf("Account wait queue full: account=%d", account.ID)
|
||||
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
|
||||
return
|
||||
} else {
|
||||
// Only set release function if increment succeeded
|
||||
accountWaitRelease = func() {
|
||||
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
|
||||
}
|
||||
}
|
||||
|
||||
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
|
||||
c,
|
||||
account.ID,
|
||||
selection.WaitPlan.MaxConcurrency,
|
||||
selection.WaitPlan.Timeout,
|
||||
reqStream,
|
||||
&streamStarted,
|
||||
)
|
||||
if err != nil {
|
||||
if accountWaitRelease != nil {
|
||||
accountWaitRelease()
|
||||
}
|
||||
log.Printf("Account concurrency acquire failed: %v", err)
|
||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||
return
|
||||
}
|
||||
if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionHash, account.ID); err != nil {
|
||||
log.Printf("Bind sticky session failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Forward request
|
||||
@@ -171,6 +207,9 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
if accountReleaseFunc != nil {
|
||||
accountReleaseFunc()
|
||||
}
|
||||
if accountWaitRelease != nil {
|
||||
accountWaitRelease()
|
||||
}
|
||||
if err != nil {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
|
||||
@@ -30,7 +30,6 @@ type ChangePasswordRequest struct {
|
||||
// UpdateProfileRequest represents the update profile request payload
|
||||
type UpdateProfileRequest struct {
|
||||
Username *string `json:"username"`
|
||||
Wechat *string `json:"wechat"`
|
||||
}
|
||||
|
||||
// GetProfile handles getting user profile
|
||||
@@ -99,7 +98,6 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) {
|
||||
|
||||
svcReq := service.UpdateProfileRequest{
|
||||
Username: req.Username,
|
||||
Wechat: req.Wechat,
|
||||
}
|
||||
updatedUser, err := h.userService.UpdateProfile(c.Request.Context(), subject.UserID, svcReq)
|
||||
if err != nil {
|
||||
|
||||
@@ -23,6 +23,7 @@ func ProvideAdminHandlers(
|
||||
systemHandler *admin.SystemHandler,
|
||||
subscriptionHandler *admin.SubscriptionHandler,
|
||||
usageHandler *admin.UsageHandler,
|
||||
userAttributeHandler *admin.UserAttributeHandler,
|
||||
) *AdminHandlers {
|
||||
return &AdminHandlers{
|
||||
Dashboard: dashboardHandler,
|
||||
@@ -39,6 +40,7 @@ func ProvideAdminHandlers(
|
||||
System: systemHandler,
|
||||
Subscription: subscriptionHandler,
|
||||
Usage: usageHandler,
|
||||
UserAttribute: userAttributeHandler,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -107,6 +109,7 @@ var ProviderSet = wire.NewSet(
|
||||
ProvideSystemHandler,
|
||||
admin.NewSubscriptionHandler,
|
||||
admin.NewUsageHandler,
|
||||
admin.NewUserAttributeHandler,
|
||||
|
||||
// AdminHandlers and Handlers constructors
|
||||
ProvideAdminHandlers,
|
||||
|
||||
Reference in New Issue
Block a user