feat(api-key): 添加 IP 白名单/黑名单限制功能 (#221)
* feat(api-key): add IP whitelist/blacklist restriction and usage log IP tracking - Add IP restriction feature for API keys (whitelist/blacklist with CIDR support) - Add IP address logging to usage logs (admin-only visibility) - Remove billing_type column from usage logs UI (redundant) - Use generic "Access denied" error message for security Backend: - New ip package with IP/CIDR validation and matching utilities - Database migrations for ip_whitelist, ip_blacklist (api_keys) and ip_address (usage_logs) - Middleware IP restriction check after API key validation - Input validation for IP/CIDR patterns on create/update Frontend: - API key form with enable toggle for IP restriction - Shield icon indicator in table for keys with IP restriction - Removed billing_type filter and column from usage views * fix: update API contract tests for ip_whitelist/ip_blacklist fields Add ip_whitelist and ip_blacklist fields to expected JSON responses in API contract tests to match the new API key schema.
This commit is contained in:
@@ -3,16 +3,18 @@ package service
|
||||
import "time"
|
||||
|
||||
type APIKey struct {
|
||||
ID int64
|
||||
UserID int64
|
||||
Key string
|
||||
Name string
|
||||
GroupID *int64
|
||||
Status string
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
User *User
|
||||
Group *Group
|
||||
ID int64
|
||||
UserID int64
|
||||
Key string
|
||||
Name string
|
||||
GroupID *int64
|
||||
Status string
|
||||
IPWhitelist []string
|
||||
IPBlacklist []string
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
User *User
|
||||
Group *Group
|
||||
}
|
||||
|
||||
func (k *APIKey) IsActive() bool {
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
)
|
||||
@@ -20,6 +21,7 @@ var (
|
||||
ErrAPIKeyTooShort = infraerrors.BadRequest("API_KEY_TOO_SHORT", "api key must be at least 16 characters")
|
||||
ErrAPIKeyInvalidChars = infraerrors.BadRequest("API_KEY_INVALID_CHARS", "api key can only contain letters, numbers, underscores, and hyphens")
|
||||
ErrAPIKeyRateLimited = infraerrors.TooManyRequests("API_KEY_RATE_LIMITED", "too many failed attempts, please try again later")
|
||||
ErrInvalidIPPattern = infraerrors.BadRequest("INVALID_IP_PATTERN", "invalid IP or CIDR pattern")
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -57,16 +59,20 @@ type APIKeyCache interface {
|
||||
|
||||
// CreateAPIKeyRequest 创建API Key请求
|
||||
type CreateAPIKeyRequest struct {
|
||||
Name string `json:"name"`
|
||||
GroupID *int64 `json:"group_id"`
|
||||
CustomKey *string `json:"custom_key"` // 可选的自定义key
|
||||
Name string `json:"name"`
|
||||
GroupID *int64 `json:"group_id"`
|
||||
CustomKey *string `json:"custom_key"` // 可选的自定义key
|
||||
IPWhitelist []string `json:"ip_whitelist"` // IP 白名单
|
||||
IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单
|
||||
}
|
||||
|
||||
// UpdateAPIKeyRequest 更新API Key请求
|
||||
type UpdateAPIKeyRequest struct {
|
||||
Name *string `json:"name"`
|
||||
GroupID *int64 `json:"group_id"`
|
||||
Status *string `json:"status"`
|
||||
Name *string `json:"name"`
|
||||
GroupID *int64 `json:"group_id"`
|
||||
Status *string `json:"status"`
|
||||
IPWhitelist []string `json:"ip_whitelist"` // IP 白名单(空数组清空)
|
||||
IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单(空数组清空)
|
||||
}
|
||||
|
||||
// APIKeyService API Key服务
|
||||
@@ -186,6 +192,20 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK
|
||||
return nil, fmt.Errorf("get user: %w", err)
|
||||
}
|
||||
|
||||
// 验证 IP 白名单格式
|
||||
if len(req.IPWhitelist) > 0 {
|
||||
if invalid := ip.ValidateIPPatterns(req.IPWhitelist); len(invalid) > 0 {
|
||||
return nil, fmt.Errorf("%w: %v", ErrInvalidIPPattern, invalid)
|
||||
}
|
||||
}
|
||||
|
||||
// 验证 IP 黑名单格式
|
||||
if len(req.IPBlacklist) > 0 {
|
||||
if invalid := ip.ValidateIPPatterns(req.IPBlacklist); len(invalid) > 0 {
|
||||
return nil, fmt.Errorf("%w: %v", ErrInvalidIPPattern, invalid)
|
||||
}
|
||||
}
|
||||
|
||||
// 验证分组权限(如果指定了分组)
|
||||
if req.GroupID != nil {
|
||||
group, err := s.groupRepo.GetByID(ctx, *req.GroupID)
|
||||
@@ -236,11 +256,13 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK
|
||||
|
||||
// 创建API Key记录
|
||||
apiKey := &APIKey{
|
||||
UserID: userID,
|
||||
Key: key,
|
||||
Name: req.Name,
|
||||
GroupID: req.GroupID,
|
||||
Status: StatusActive,
|
||||
UserID: userID,
|
||||
Key: key,
|
||||
Name: req.Name,
|
||||
GroupID: req.GroupID,
|
||||
Status: StatusActive,
|
||||
IPWhitelist: req.IPWhitelist,
|
||||
IPBlacklist: req.IPBlacklist,
|
||||
}
|
||||
|
||||
if err := s.apiKeyRepo.Create(ctx, apiKey); err != nil {
|
||||
@@ -312,6 +334,20 @@ func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req
|
||||
return nil, ErrInsufficientPerms
|
||||
}
|
||||
|
||||
// 验证 IP 白名单格式
|
||||
if len(req.IPWhitelist) > 0 {
|
||||
if invalid := ip.ValidateIPPatterns(req.IPWhitelist); len(invalid) > 0 {
|
||||
return nil, fmt.Errorf("%w: %v", ErrInvalidIPPattern, invalid)
|
||||
}
|
||||
}
|
||||
|
||||
// 验证 IP 黑名单格式
|
||||
if len(req.IPBlacklist) > 0 {
|
||||
if invalid := ip.ValidateIPPatterns(req.IPBlacklist); len(invalid) > 0 {
|
||||
return nil, fmt.Errorf("%w: %v", ErrInvalidIPPattern, invalid)
|
||||
}
|
||||
}
|
||||
|
||||
// 更新字段
|
||||
if req.Name != nil {
|
||||
apiKey.Name = *req.Name
|
||||
@@ -344,6 +380,10 @@ func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req
|
||||
}
|
||||
}
|
||||
|
||||
// 更新 IP 限制(空数组会清空设置)
|
||||
apiKey.IPWhitelist = req.IPWhitelist
|
||||
apiKey.IPBlacklist = req.IPBlacklist
|
||||
|
||||
if err := s.apiKeyRepo.Update(ctx, apiKey); err != nil {
|
||||
return nil, fmt.Errorf("update api key: %w", err)
|
||||
}
|
||||
|
||||
@@ -2247,6 +2247,7 @@ type RecordUsageInput struct {
|
||||
Account *Account
|
||||
Subscription *UserSubscription // 可选:订阅信息
|
||||
UserAgent string // 请求的 User-Agent
|
||||
IPAddress string // 请求的客户端 IP 地址
|
||||
}
|
||||
|
||||
// RecordUsage 记录使用量并扣费(或更新订阅用量)
|
||||
@@ -2337,6 +2338,11 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
usageLog.UserAgent = &input.UserAgent
|
||||
}
|
||||
|
||||
// 添加 IPAddress
|
||||
if input.IPAddress != "" {
|
||||
usageLog.IPAddress = &input.IPAddress
|
||||
}
|
||||
|
||||
// 添加分组和订阅关联
|
||||
if apiKey.GroupID != nil {
|
||||
usageLog.GroupID = apiKey.GroupID
|
||||
|
||||
@@ -1197,6 +1197,7 @@ type OpenAIRecordUsageInput struct {
|
||||
Account *Account
|
||||
Subscription *UserSubscription
|
||||
UserAgent string // 请求的 User-Agent
|
||||
IPAddress string // 请求的客户端 IP 地址
|
||||
}
|
||||
|
||||
// RecordUsage records usage and deducts balance
|
||||
@@ -1271,6 +1272,11 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
||||
usageLog.UserAgent = &input.UserAgent
|
||||
}
|
||||
|
||||
// 添加 IPAddress
|
||||
if input.IPAddress != "" {
|
||||
usageLog.IPAddress = &input.IPAddress
|
||||
}
|
||||
|
||||
if apiKey.GroupID != nil {
|
||||
usageLog.GroupID = apiKey.GroupID
|
||||
}
|
||||
|
||||
@@ -39,6 +39,7 @@ type UsageLog struct {
|
||||
DurationMs *int
|
||||
FirstTokenMs *int
|
||||
UserAgent *string
|
||||
IPAddress *string
|
||||
|
||||
// 图片生成字段
|
||||
ImageCount int
|
||||
|
||||
Reference in New Issue
Block a user