First commit
This commit is contained in:
74
.dockerignore
Normal file
74
.dockerignore
Normal file
@@ -0,0 +1,74 @@
|
||||
# =============================================================================
|
||||
# Docker Ignore File for Sub2API
|
||||
# =============================================================================
|
||||
|
||||
# Git
|
||||
.git
|
||||
.gitignore
|
||||
.gitattributes
|
||||
|
||||
# Documentation
|
||||
*.md
|
||||
!deploy/DOCKER.md
|
||||
docs/
|
||||
|
||||
# IDE
|
||||
.idea/
|
||||
.vscode/
|
||||
*.swp
|
||||
*.swo
|
||||
|
||||
# OS files
|
||||
.DS_Store
|
||||
Thumbs.db
|
||||
|
||||
# Build artifacts
|
||||
dist/
|
||||
build/
|
||||
|
||||
# Node modules (will be installed in container)
|
||||
frontend/node_modules/
|
||||
node_modules/
|
||||
|
||||
# Go build cache (will be built in container)
|
||||
backend/vendor/
|
||||
|
||||
# Test files
|
||||
*_test.go
|
||||
**/*.test.js
|
||||
coverage/
|
||||
.nyc_output/
|
||||
|
||||
# Environment files
|
||||
.env
|
||||
.env.*
|
||||
!.env.example
|
||||
|
||||
# Local config
|
||||
config.yaml
|
||||
config.local.yaml
|
||||
|
||||
# Logs
|
||||
*.log
|
||||
logs/
|
||||
|
||||
# Temporary files
|
||||
tmp/
|
||||
temp/
|
||||
*.tmp
|
||||
|
||||
# Deploy files (not needed in image)
|
||||
deploy/install.sh
|
||||
deploy/sub2api.service
|
||||
deploy/sub2api-sudoers
|
||||
|
||||
# GoReleaser
|
||||
.goreleaser.yaml
|
||||
|
||||
# GitHub
|
||||
.github/
|
||||
|
||||
# Claude files
|
||||
.claude/
|
||||
issues/
|
||||
CLAUDE.md
|
||||
178
.github/workflows/release.yml
vendored
Normal file
178
.github/workflows/release.yml
vendored
Normal file
@@ -0,0 +1,178 @@
|
||||
name: Release
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- 'v*'
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
packages: write
|
||||
|
||||
jobs:
|
||||
# Update VERSION file with tag version
|
||||
update-version:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Update VERSION file
|
||||
run: |
|
||||
VERSION=${GITHUB_REF#refs/tags/v}
|
||||
echo "$VERSION" > backend/cmd/server/VERSION
|
||||
echo "Updated VERSION file to: $VERSION"
|
||||
|
||||
- name: Upload VERSION artifact
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: version-file
|
||||
path: backend/cmd/server/VERSION
|
||||
retention-days: 1
|
||||
|
||||
build-frontend:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: '20'
|
||||
cache: 'npm'
|
||||
cache-dependency-path: frontend/package-lock.json
|
||||
|
||||
- name: Install dependencies
|
||||
run: npm ci
|
||||
working-directory: frontend
|
||||
|
||||
- name: Build frontend
|
||||
run: npm run build
|
||||
working-directory: frontend
|
||||
|
||||
- name: Upload frontend artifact
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: frontend-dist
|
||||
path: backend/internal/web/dist/
|
||||
retention-days: 1
|
||||
|
||||
release:
|
||||
needs: [update-version, build-frontend]
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Download VERSION artifact
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
name: version-file
|
||||
path: backend/cmd/server/
|
||||
|
||||
- name: Download frontend artifact
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
name: frontend-dist
|
||||
path: backend/internal/web/dist/
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache-dependency-path: backend/go.sum
|
||||
|
||||
- name: Run GoReleaser
|
||||
uses: goreleaser/goreleaser-action@v6
|
||||
with:
|
||||
version: '~> v2'
|
||||
args: release --clean
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
# ===========================================================================
|
||||
# Docker Build and Push
|
||||
# ===========================================================================
|
||||
docker:
|
||||
needs: [update-version, build-frontend]
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Download VERSION artifact
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
name: version-file
|
||||
path: backend/cmd/server/
|
||||
|
||||
- name: Download frontend artifact
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
name: frontend-dist
|
||||
path: backend/internal/web/dist/
|
||||
|
||||
# Extract version from tag
|
||||
- name: Extract version
|
||||
id: version
|
||||
run: |
|
||||
VERSION=${GITHUB_REF#refs/tags/v}
|
||||
echo "version=$VERSION" >> $GITHUB_OUTPUT
|
||||
echo "Version: $VERSION"
|
||||
|
||||
# Set up Docker Buildx for multi-platform builds
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v3
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
# Login to DockerHub
|
||||
- name: Login to DockerHub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
|
||||
# Extract metadata for Docker
|
||||
- name: Extract Docker metadata
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: |
|
||||
weishaw/sub2api
|
||||
tags: |
|
||||
type=semver,pattern={{version}}
|
||||
type=semver,pattern={{major}}.{{minor}}
|
||||
type=semver,pattern={{major}}
|
||||
type=raw,value=latest,enable={{is_default_branch}}
|
||||
|
||||
# Build and push Docker image
|
||||
- name: Build and push Docker image
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: .
|
||||
file: ./Dockerfile
|
||||
platforms: linux/amd64,linux/arm64
|
||||
push: true
|
||||
tags: ${{ steps.meta.outputs.tags }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
build-args: |
|
||||
VERSION=${{ steps.version.outputs.version }}
|
||||
COMMIT=${{ github.sha }}
|
||||
DATE=${{ github.event.head_commit.timestamp }}
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
|
||||
# Update DockerHub description (optional)
|
||||
- name: Update DockerHub description
|
||||
uses: peter-evans/dockerhub-description@v4
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
repository: weishaw/sub2api
|
||||
short-description: "Sub2API - AI API Gateway Platform"
|
||||
readme-filepath: ./deploy/DOCKER.md
|
||||
93
.gitignore
vendored
Normal file
93
.gitignore
vendored
Normal file
@@ -0,0 +1,93 @@
|
||||
docs/claude-relay-service/
|
||||
|
||||
# ===================
|
||||
# Go 后端
|
||||
# ===================
|
||||
# 二进制文件
|
||||
*.exe
|
||||
*.exe~
|
||||
*.dll
|
||||
*.so
|
||||
*.dylib
|
||||
backend/bin/
|
||||
backend/server
|
||||
backend/sub2api
|
||||
|
||||
# 测试覆盖率
|
||||
*.out
|
||||
coverage.html
|
||||
|
||||
# 依赖(使用 go mod)
|
||||
vendor/
|
||||
|
||||
# ===================
|
||||
# Node.js / Vue 前端
|
||||
# ===================
|
||||
node_modules/
|
||||
frontend/node_modules/
|
||||
frontend/dist/
|
||||
*.local
|
||||
|
||||
# 日志
|
||||
npm-debug.log*
|
||||
yarn-debug.log*
|
||||
yarn-error.log*
|
||||
pnpm-debug.log*
|
||||
|
||||
# ===================
|
||||
# 环境配置
|
||||
# ===================
|
||||
.env
|
||||
.env.local
|
||||
.env.*.local
|
||||
*.env
|
||||
!.env.example
|
||||
|
||||
# ===================
|
||||
# IDE / 编辑器
|
||||
# ===================
|
||||
.idea/
|
||||
.vscode/
|
||||
*.swp
|
||||
*.swo
|
||||
*~
|
||||
.project
|
||||
.settings/
|
||||
.classpath
|
||||
|
||||
# ===================
|
||||
# 操作系统
|
||||
# ===================
|
||||
.DS_Store
|
||||
Thumbs.db
|
||||
Desktop.ini
|
||||
|
||||
# ===================
|
||||
# 临时文件
|
||||
# ===================
|
||||
tmp/
|
||||
temp/
|
||||
*.tmp
|
||||
*.temp
|
||||
*.log
|
||||
*.bak
|
||||
|
||||
# ===================
|
||||
# 构建产物
|
||||
# ===================
|
||||
dist/
|
||||
build/
|
||||
release/
|
||||
|
||||
# 后端嵌入的前端构建产物
|
||||
backend/internal/web/dist/
|
||||
|
||||
# 后端运行时缓存数据
|
||||
backend/data/
|
||||
|
||||
# ===================
|
||||
# 其他
|
||||
# ===================
|
||||
tests
|
||||
CLAUDE.md
|
||||
.claude
|
||||
85
.goreleaser.yaml
Normal file
85
.goreleaser.yaml
Normal file
@@ -0,0 +1,85 @@
|
||||
version: 2
|
||||
|
||||
project_name: sub2api
|
||||
|
||||
before:
|
||||
hooks:
|
||||
- go mod tidy -C backend
|
||||
|
||||
builds:
|
||||
- id: sub2api
|
||||
dir: backend
|
||||
main: ./cmd/server
|
||||
binary: sub2api
|
||||
env:
|
||||
- CGO_ENABLED=0
|
||||
goos:
|
||||
- linux
|
||||
- windows
|
||||
- darwin
|
||||
goarch:
|
||||
- amd64
|
||||
- arm64
|
||||
ignore:
|
||||
- goos: windows
|
||||
goarch: arm64
|
||||
ldflags:
|
||||
- -s -w
|
||||
- -X main.Commit={{.Commit}}
|
||||
- -X main.Date={{.Date}}
|
||||
- -X main.BuildType=release
|
||||
|
||||
archives:
|
||||
- id: default
|
||||
format: tar.gz
|
||||
name_template: >-
|
||||
{{ .ProjectName }}_{{ .Version }}_{{ .Os }}_{{ .Arch }}
|
||||
format_overrides:
|
||||
- goos: windows
|
||||
format: zip
|
||||
files:
|
||||
- LICENSE*
|
||||
- README*
|
||||
- deploy/*
|
||||
|
||||
checksum:
|
||||
name_template: 'checksums.txt'
|
||||
algorithm: sha256
|
||||
|
||||
changelog:
|
||||
# 禁用自动 changelog,完全使用 tag 消息
|
||||
disable: true
|
||||
|
||||
release:
|
||||
github:
|
||||
owner: Wei-Shaw
|
||||
name: sub2api
|
||||
draft: false
|
||||
prerelease: auto
|
||||
name_template: "v{{.Version}}"
|
||||
# 完全使用 tag 消息作为 release 内容
|
||||
header: |
|
||||
## Sub2API {{.Version}}
|
||||
|
||||
> AI API Gateway Platform - 将 AI 订阅配额分发和管理
|
||||
|
||||
{{ .TagBody }}
|
||||
|
||||
footer: |
|
||||
|
||||
---
|
||||
|
||||
## 📥 Installation
|
||||
|
||||
**One-line install (Linux):**
|
||||
```bash
|
||||
curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/install.sh | sudo bash
|
||||
```
|
||||
|
||||
**Manual download:**
|
||||
Download the appropriate archive for your platform from the assets below.
|
||||
|
||||
## 📚 Documentation
|
||||
|
||||
- [GitHub Repository](https://github.com/Wei-Shaw/sub2api)
|
||||
- [Installation Guide](https://github.com/Wei-Shaw/sub2api/blob/main/deploy/README.md)
|
||||
96
Dockerfile
Normal file
96
Dockerfile
Normal file
@@ -0,0 +1,96 @@
|
||||
# =============================================================================
|
||||
# Sub2API Multi-Stage Dockerfile
|
||||
# =============================================================================
|
||||
# Stage 1: Build frontend
|
||||
# Stage 2: Build Go backend with embedded frontend
|
||||
# Stage 3: Final minimal image
|
||||
# =============================================================================
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Stage 1: Frontend Builder
|
||||
# -----------------------------------------------------------------------------
|
||||
FROM node:20-alpine AS frontend-builder
|
||||
|
||||
WORKDIR /app/frontend
|
||||
|
||||
# Install dependencies first (better caching)
|
||||
COPY frontend/package*.json ./
|
||||
RUN npm ci
|
||||
|
||||
# Copy frontend source and build
|
||||
COPY frontend/ ./
|
||||
RUN npm run build
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Stage 2: Backend Builder
|
||||
# -----------------------------------------------------------------------------
|
||||
FROM golang:1.24-alpine AS backend-builder
|
||||
|
||||
# Build arguments for version info (set by CI)
|
||||
ARG VERSION=docker
|
||||
ARG COMMIT=docker
|
||||
ARG DATE
|
||||
|
||||
# Install build dependencies
|
||||
RUN apk add --no-cache git ca-certificates tzdata
|
||||
|
||||
WORKDIR /app/backend
|
||||
|
||||
# Copy go mod files first (better caching)
|
||||
COPY backend/go.mod backend/go.sum ./
|
||||
RUN go mod download
|
||||
|
||||
# Copy frontend dist from previous stage
|
||||
COPY --from=frontend-builder /app/frontend/../backend/internal/web/dist ./internal/web/dist
|
||||
|
||||
# Copy backend source
|
||||
COPY backend/ ./
|
||||
|
||||
# Build the binary (BuildType=release for CI builds)
|
||||
RUN CGO_ENABLED=0 GOOS=linux go build \
|
||||
-ldflags="-s -w -X main.Commit=${COMMIT} -X main.Date=${DATE:-$(date -u +%Y-%m-%dT%H:%M:%SZ)} -X main.BuildType=release" \
|
||||
-o /app/sub2api \
|
||||
./cmd/server
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Stage 3: Final Runtime Image
|
||||
# -----------------------------------------------------------------------------
|
||||
FROM alpine:3.19
|
||||
|
||||
# Labels
|
||||
LABEL maintainer="Wei-Shaw <github.com/Wei-Shaw>"
|
||||
LABEL description="Sub2API - AI API Gateway Platform"
|
||||
LABEL org.opencontainers.image.source="https://github.com/Wei-Shaw/sub2api"
|
||||
|
||||
# Install runtime dependencies
|
||||
RUN apk add --no-cache \
|
||||
ca-certificates \
|
||||
tzdata \
|
||||
curl \
|
||||
&& rm -rf /var/cache/apk/*
|
||||
|
||||
# Create non-root user
|
||||
RUN addgroup -g 1000 sub2api && \
|
||||
adduser -u 1000 -G sub2api -s /bin/sh -D sub2api
|
||||
|
||||
# Set working directory
|
||||
WORKDIR /app
|
||||
|
||||
# Copy binary from builder
|
||||
COPY --from=backend-builder /app/sub2api /app/sub2api
|
||||
|
||||
# Create data directory
|
||||
RUN mkdir -p /app/data && chown -R sub2api:sub2api /app
|
||||
|
||||
# Switch to non-root user
|
||||
USER sub2api
|
||||
|
||||
# Expose port (can be overridden by SERVER_PORT env var)
|
||||
EXPOSE 8080
|
||||
|
||||
# Health check
|
||||
HEALTHCHECK --interval=30s --timeout=10s --start-period=10s --retries=3 \
|
||||
CMD curl -f http://localhost:${SERVER_PORT:-8080}/health || exit 1
|
||||
|
||||
# Run the application
|
||||
ENTRYPOINT ["/app/sub2api"]
|
||||
318
README.md
Normal file
318
README.md
Normal file
@@ -0,0 +1,318 @@
|
||||
# Sub2API
|
||||
|
||||
<div align="center">
|
||||
|
||||
[](https://golang.org/)
|
||||
[](https://vuejs.org/)
|
||||
[](https://www.postgresql.org/)
|
||||
[](https://redis.io/)
|
||||
[](https://www.docker.com/)
|
||||
|
||||
**AI API Gateway Platform for Subscription Quota Distribution**
|
||||
|
||||
English | [中文](README_CN.md)
|
||||
|
||||
</div>
|
||||
|
||||
---
|
||||
|
||||
## Overview
|
||||
|
||||
Sub2API is an AI API gateway platform designed to distribute and manage API quotas from AI product subscriptions (like Claude Code $200/month). Users can access upstream AI services through platform-generated API Keys, while the platform handles authentication, billing, load balancing, and request forwarding.
|
||||
|
||||
## Features
|
||||
|
||||
- **Multi-Account Management** - Support multiple upstream account types (OAuth, API Key)
|
||||
- **API Key Distribution** - Generate and manage API Keys for users
|
||||
- **Precise Billing** - Token-level usage tracking and cost calculation
|
||||
- **Smart Scheduling** - Intelligent account selection with sticky sessions
|
||||
- **Concurrency Control** - Per-user and per-account concurrency limits
|
||||
- **Rate Limiting** - Configurable request and token rate limits
|
||||
- **Admin Dashboard** - Web interface for monitoring and management
|
||||
|
||||
## Tech Stack
|
||||
|
||||
| Component | Technology |
|
||||
|-----------|------------|
|
||||
| Backend | Go 1.21+, Gin, GORM |
|
||||
| Frontend | Vue 3.4+, Vite 5+, TailwindCSS |
|
||||
| Database | PostgreSQL 15+ |
|
||||
| Cache/Queue | Redis 7+ |
|
||||
|
||||
---
|
||||
|
||||
## Deployment
|
||||
|
||||
### Method 1: Script Installation (Recommended)
|
||||
|
||||
One-click installation script that downloads pre-built binaries from GitHub Releases.
|
||||
|
||||
#### Prerequisites
|
||||
|
||||
- Linux server (amd64 or arm64)
|
||||
- PostgreSQL 15+ (installed and running)
|
||||
- Redis 7+ (installed and running)
|
||||
- Root privileges
|
||||
|
||||
#### Installation Steps
|
||||
|
||||
```bash
|
||||
# Download and run the installation script
|
||||
curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/install.sh | sudo bash
|
||||
```
|
||||
|
||||
The script will:
|
||||
1. Detect your system architecture
|
||||
2. Download the latest release
|
||||
3. Install binary to `/opt/sub2api`
|
||||
4. Create systemd service
|
||||
5. Configure system user and permissions
|
||||
|
||||
#### Post-Installation
|
||||
|
||||
```bash
|
||||
# 1. Start the service
|
||||
sudo systemctl start sub2api
|
||||
|
||||
# 2. Enable auto-start on boot
|
||||
sudo systemctl enable sub2api
|
||||
|
||||
# 3. Open Setup Wizard in browser
|
||||
# http://YOUR_SERVER_IP:8080
|
||||
```
|
||||
|
||||
The Setup Wizard will guide you through:
|
||||
- Database configuration
|
||||
- Redis configuration
|
||||
- Admin account creation
|
||||
|
||||
#### Upgrade
|
||||
|
||||
You can upgrade directly from the **Admin Dashboard** by clicking the **Check for Updates** button in the top-left corner.
|
||||
|
||||
The web interface will:
|
||||
- Check for new versions automatically
|
||||
- Download and apply updates with one click
|
||||
- Support rollback if needed
|
||||
|
||||
#### Useful Commands
|
||||
|
||||
```bash
|
||||
# Check status
|
||||
sudo systemctl status sub2api
|
||||
|
||||
# View logs
|
||||
sudo journalctl -u sub2api -f
|
||||
|
||||
# Restart service
|
||||
sudo systemctl restart sub2api
|
||||
|
||||
# Uninstall
|
||||
curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/install.sh | sudo bash -s uninstall
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Method 2: Docker Compose
|
||||
|
||||
Deploy with Docker Compose, including PostgreSQL and Redis containers.
|
||||
|
||||
#### Prerequisites
|
||||
|
||||
- Docker 20.10+
|
||||
- Docker Compose v2+
|
||||
|
||||
#### Installation Steps
|
||||
|
||||
```bash
|
||||
# 1. Clone the repository
|
||||
git clone https://github.com/Wei-Shaw/sub2api.git
|
||||
cd sub2api
|
||||
|
||||
# 2. Enter the deploy directory
|
||||
cd deploy
|
||||
|
||||
# 3. Copy environment configuration
|
||||
cp .env.example .env
|
||||
|
||||
# 4. Edit configuration (set your passwords)
|
||||
nano .env
|
||||
```
|
||||
|
||||
**Required configuration in `.env`:**
|
||||
|
||||
```bash
|
||||
# PostgreSQL password (REQUIRED - change this!)
|
||||
POSTGRES_PASSWORD=your_secure_password_here
|
||||
|
||||
# Optional: Admin account
|
||||
ADMIN_EMAIL=admin@example.com
|
||||
ADMIN_PASSWORD=your_admin_password
|
||||
|
||||
# Optional: Custom port
|
||||
SERVER_PORT=8080
|
||||
```
|
||||
|
||||
```bash
|
||||
# 5. Start all services
|
||||
docker-compose up -d
|
||||
|
||||
# 6. Check status
|
||||
docker-compose ps
|
||||
|
||||
# 7. View logs
|
||||
docker-compose logs -f sub2api
|
||||
```
|
||||
|
||||
#### Access
|
||||
|
||||
Open `http://YOUR_SERVER_IP:8080` in your browser.
|
||||
|
||||
#### Upgrade
|
||||
|
||||
```bash
|
||||
# Pull latest image and recreate container
|
||||
docker-compose pull
|
||||
docker-compose up -d
|
||||
```
|
||||
|
||||
#### Useful Commands
|
||||
|
||||
```bash
|
||||
# Stop all services
|
||||
docker-compose down
|
||||
|
||||
# Restart
|
||||
docker-compose restart
|
||||
|
||||
# View all logs
|
||||
docker-compose logs -f
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Method 3: Build from Source
|
||||
|
||||
Build and run from source code for development or customization.
|
||||
|
||||
#### Prerequisites
|
||||
|
||||
- Go 1.21+
|
||||
- Node.js 18+
|
||||
- PostgreSQL 15+
|
||||
- Redis 7+
|
||||
|
||||
#### Build Steps
|
||||
|
||||
```bash
|
||||
# 1. Clone the repository
|
||||
git clone https://github.com/Wei-Shaw/sub2api.git
|
||||
cd sub2api
|
||||
|
||||
# 2. Build backend
|
||||
cd backend
|
||||
go build -o sub2api ./cmd/server
|
||||
|
||||
# 3. Build frontend
|
||||
cd ../frontend
|
||||
npm install
|
||||
npm run build
|
||||
|
||||
# 4. Copy frontend build to backend (for embedding)
|
||||
cp -r dist ../backend/internal/web/
|
||||
|
||||
# 5. Create configuration file
|
||||
cd ../backend
|
||||
cp ../deploy/config.example.yaml ./config.yaml
|
||||
|
||||
# 6. Edit configuration
|
||||
nano config.yaml
|
||||
```
|
||||
|
||||
**Key configuration in `config.yaml`:**
|
||||
|
||||
```yaml
|
||||
server:
|
||||
host: "0.0.0.0"
|
||||
port: 8080
|
||||
mode: "release"
|
||||
|
||||
database:
|
||||
host: "localhost"
|
||||
port: 5432
|
||||
user: "postgres"
|
||||
password: "your_password"
|
||||
dbname: "sub2api"
|
||||
|
||||
redis:
|
||||
host: "localhost"
|
||||
port: 6379
|
||||
password: ""
|
||||
|
||||
jwt:
|
||||
secret: "change-this-to-a-secure-random-string"
|
||||
expire_hour: 24
|
||||
|
||||
default:
|
||||
admin_email: "admin@example.com"
|
||||
admin_password: "admin123"
|
||||
```
|
||||
|
||||
```bash
|
||||
# 7. Run the application
|
||||
./sub2api
|
||||
```
|
||||
|
||||
#### Development Mode
|
||||
|
||||
```bash
|
||||
# Backend (with hot reload)
|
||||
cd backend
|
||||
go run ./cmd/server
|
||||
|
||||
# Frontend (with hot reload)
|
||||
cd frontend
|
||||
npm run dev
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Project Structure
|
||||
|
||||
```
|
||||
sub2api/
|
||||
├── backend/ # Go backend service
|
||||
│ ├── cmd/server/ # Application entry
|
||||
│ ├── internal/ # Internal modules
|
||||
│ │ ├── config/ # Configuration
|
||||
│ │ ├── model/ # Data models
|
||||
│ │ ├── service/ # Business logic
|
||||
│ │ ├── handler/ # HTTP handlers
|
||||
│ │ └── gateway/ # API gateway core
|
||||
│ └── resources/ # Static resources
|
||||
│
|
||||
├── frontend/ # Vue 3 frontend
|
||||
│ └── src/
|
||||
│ ├── api/ # API calls
|
||||
│ ├── stores/ # State management
|
||||
│ ├── views/ # Page components
|
||||
│ └── components/ # Reusable components
|
||||
│
|
||||
└── deploy/ # Deployment files
|
||||
├── docker-compose.yml # Docker Compose configuration
|
||||
├── .env.example # Environment variables for Docker Compose
|
||||
├── config.example.yaml # Full config file for binary deployment
|
||||
└── install.sh # One-click installation script
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
MIT License
|
||||
|
||||
---
|
||||
|
||||
<div align="center">
|
||||
|
||||
**If you find this project useful, please give it a star!**
|
||||
|
||||
</div>
|
||||
318
README_CN.md
Normal file
318
README_CN.md
Normal file
@@ -0,0 +1,318 @@
|
||||
# Sub2API
|
||||
|
||||
<div align="center">
|
||||
|
||||
[](https://golang.org/)
|
||||
[](https://vuejs.org/)
|
||||
[](https://www.postgresql.org/)
|
||||
[](https://redis.io/)
|
||||
[](https://www.docker.com/)
|
||||
|
||||
**AI API 网关平台 - 订阅配额分发管理**
|
||||
|
||||
[English](README.md) | 中文
|
||||
|
||||
</div>
|
||||
|
||||
---
|
||||
|
||||
## 项目概述
|
||||
|
||||
Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅(如 Claude Code $200/月)的 API 配额。用户通过平台生成的 API Key 调用上游 AI 服务,平台负责鉴权、计费、负载均衡和请求转发。
|
||||
|
||||
## 核心功能
|
||||
|
||||
- **多账号管理** - 支持多种上游账号类型(OAuth、API Key)
|
||||
- **API Key 分发** - 为用户生成和管理 API Key
|
||||
- **精确计费** - Token 级别的用量追踪和成本计算
|
||||
- **智能调度** - 智能账号选择,支持粘性会话
|
||||
- **并发控制** - 用户级和账号级并发限制
|
||||
- **速率限制** - 可配置的请求和 Token 速率限制
|
||||
- **管理后台** - Web 界面进行监控和管理
|
||||
|
||||
## 技术栈
|
||||
|
||||
| 组件 | 技术 |
|
||||
|------|------|
|
||||
| 后端 | Go 1.21+, Gin, GORM |
|
||||
| 前端 | Vue 3.4+, Vite 5+, TailwindCSS |
|
||||
| 数据库 | PostgreSQL 15+ |
|
||||
| 缓存/队列 | Redis 7+ |
|
||||
|
||||
---
|
||||
|
||||
## 部署方式
|
||||
|
||||
### 方式一:脚本安装(推荐)
|
||||
|
||||
一键安装脚本,自动从 GitHub Releases 下载预编译的二进制文件。
|
||||
|
||||
#### 前置条件
|
||||
|
||||
- Linux 服务器(amd64 或 arm64)
|
||||
- PostgreSQL 15+(已安装并运行)
|
||||
- Redis 7+(已安装并运行)
|
||||
- Root 权限
|
||||
|
||||
#### 安装步骤
|
||||
|
||||
```bash
|
||||
# 下载并运行安装脚本
|
||||
curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/install.sh | sudo bash
|
||||
```
|
||||
|
||||
脚本会自动:
|
||||
1. 检测系统架构
|
||||
2. 下载最新版本
|
||||
3. 安装二进制文件到 `/opt/sub2api`
|
||||
4. 创建 systemd 服务
|
||||
5. 配置系统用户和权限
|
||||
|
||||
#### 安装后配置
|
||||
|
||||
```bash
|
||||
# 1. 启动服务
|
||||
sudo systemctl start sub2api
|
||||
|
||||
# 2. 设置开机自启
|
||||
sudo systemctl enable sub2api
|
||||
|
||||
# 3. 在浏览器中打开设置向导
|
||||
# http://你的服务器IP:8080
|
||||
```
|
||||
|
||||
设置向导将引导你完成:
|
||||
- 数据库配置
|
||||
- Redis 配置
|
||||
- 管理员账号创建
|
||||
|
||||
#### 升级
|
||||
|
||||
可以直接在 **管理后台** 左上角点击 **检测更新** 按钮进行在线升级。
|
||||
|
||||
网页升级功能支持:
|
||||
- 自动检测新版本
|
||||
- 一键下载并应用更新
|
||||
- 支持回滚
|
||||
|
||||
#### 常用命令
|
||||
|
||||
```bash
|
||||
# 查看状态
|
||||
sudo systemctl status sub2api
|
||||
|
||||
# 查看日志
|
||||
sudo journalctl -u sub2api -f
|
||||
|
||||
# 重启服务
|
||||
sudo systemctl restart sub2api
|
||||
|
||||
# 卸载
|
||||
curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/install.sh | sudo bash -s uninstall
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 方式二:Docker Compose
|
||||
|
||||
使用 Docker Compose 部署,包含 PostgreSQL 和 Redis 容器。
|
||||
|
||||
#### 前置条件
|
||||
|
||||
- Docker 20.10+
|
||||
- Docker Compose v2+
|
||||
|
||||
#### 安装步骤
|
||||
|
||||
```bash
|
||||
# 1. 克隆仓库
|
||||
git clone https://github.com/Wei-Shaw/sub2api.git
|
||||
cd sub2api
|
||||
|
||||
# 2. 进入 deploy 目录
|
||||
cd deploy
|
||||
|
||||
# 3. 复制环境配置文件
|
||||
cp .env.example .env
|
||||
|
||||
# 4. 编辑配置(设置密码等)
|
||||
nano .env
|
||||
```
|
||||
|
||||
**`.env` 必须配置项:**
|
||||
|
||||
```bash
|
||||
# PostgreSQL 密码(必须修改!)
|
||||
POSTGRES_PASSWORD=your_secure_password_here
|
||||
|
||||
# 可选:管理员账号
|
||||
ADMIN_EMAIL=admin@example.com
|
||||
ADMIN_PASSWORD=your_admin_password
|
||||
|
||||
# 可选:自定义端口
|
||||
SERVER_PORT=8080
|
||||
```
|
||||
|
||||
```bash
|
||||
# 5. 启动所有服务
|
||||
docker-compose up -d
|
||||
|
||||
# 6. 查看状态
|
||||
docker-compose ps
|
||||
|
||||
# 7. 查看日志
|
||||
docker-compose logs -f sub2api
|
||||
```
|
||||
|
||||
#### 访问
|
||||
|
||||
在浏览器中打开 `http://你的服务器IP:8080`
|
||||
|
||||
#### 升级
|
||||
|
||||
```bash
|
||||
# 拉取最新镜像并重建容器
|
||||
docker-compose pull
|
||||
docker-compose up -d
|
||||
```
|
||||
|
||||
#### 常用命令
|
||||
|
||||
```bash
|
||||
# 停止所有服务
|
||||
docker-compose down
|
||||
|
||||
# 重启
|
||||
docker-compose restart
|
||||
|
||||
# 查看所有日志
|
||||
docker-compose logs -f
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 方式三:源码编译
|
||||
|
||||
从源码编译安装,适合开发或定制需求。
|
||||
|
||||
#### 前置条件
|
||||
|
||||
- Go 1.21+
|
||||
- Node.js 18+
|
||||
- PostgreSQL 15+
|
||||
- Redis 7+
|
||||
|
||||
#### 编译步骤
|
||||
|
||||
```bash
|
||||
# 1. 克隆仓库
|
||||
git clone https://github.com/Wei-Shaw/sub2api.git
|
||||
cd sub2api
|
||||
|
||||
# 2. 编译后端
|
||||
cd backend
|
||||
go build -o sub2api ./cmd/server
|
||||
|
||||
# 3. 编译前端
|
||||
cd ../frontend
|
||||
npm install
|
||||
npm run build
|
||||
|
||||
# 4. 复制前端构建产物到后端(用于嵌入)
|
||||
cp -r dist ../backend/internal/web/
|
||||
|
||||
# 5. 创建配置文件
|
||||
cd ../backend
|
||||
cp ../deploy/config.example.yaml ./config.yaml
|
||||
|
||||
# 6. 编辑配置
|
||||
nano config.yaml
|
||||
```
|
||||
|
||||
**`config.yaml` 关键配置:**
|
||||
|
||||
```yaml
|
||||
server:
|
||||
host: "0.0.0.0"
|
||||
port: 8080
|
||||
mode: "release"
|
||||
|
||||
database:
|
||||
host: "localhost"
|
||||
port: 5432
|
||||
user: "postgres"
|
||||
password: "your_password"
|
||||
dbname: "sub2api"
|
||||
|
||||
redis:
|
||||
host: "localhost"
|
||||
port: 6379
|
||||
password: ""
|
||||
|
||||
jwt:
|
||||
secret: "change-this-to-a-secure-random-string"
|
||||
expire_hour: 24
|
||||
|
||||
default:
|
||||
admin_email: "admin@example.com"
|
||||
admin_password: "admin123"
|
||||
```
|
||||
|
||||
```bash
|
||||
# 7. 运行应用
|
||||
./sub2api
|
||||
```
|
||||
|
||||
#### 开发模式
|
||||
|
||||
```bash
|
||||
# 后端(支持热重载)
|
||||
cd backend
|
||||
go run ./cmd/server
|
||||
|
||||
# 前端(支持热重载)
|
||||
cd frontend
|
||||
npm run dev
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 项目结构
|
||||
|
||||
```
|
||||
sub2api/
|
||||
├── backend/ # Go 后端服务
|
||||
│ ├── cmd/server/ # 应用入口
|
||||
│ ├── internal/ # 内部模块
|
||||
│ │ ├── config/ # 配置管理
|
||||
│ │ ├── model/ # 数据模型
|
||||
│ │ ├── service/ # 业务逻辑
|
||||
│ │ ├── handler/ # HTTP 处理器
|
||||
│ │ └── gateway/ # API 网关核心
|
||||
│ └── resources/ # 静态资源
|
||||
│
|
||||
├── frontend/ # Vue 3 前端
|
||||
│ └── src/
|
||||
│ ├── api/ # API 调用
|
||||
│ ├── stores/ # 状态管理
|
||||
│ ├── views/ # 页面组件
|
||||
│ └── components/ # 通用组件
|
||||
│
|
||||
└── deploy/ # 部署文件
|
||||
├── docker-compose.yml # Docker Compose 配置
|
||||
├── .env.example # Docker Compose 环境变量
|
||||
├── config.example.yaml # 二进制部署完整配置文件
|
||||
└── install.sh # 一键安装脚本
|
||||
```
|
||||
|
||||
## 许可证
|
||||
|
||||
MIT License
|
||||
|
||||
---
|
||||
|
||||
<div align="center">
|
||||
|
||||
**如果觉得有用,请给个 Star 支持一下!**
|
||||
|
||||
</div>
|
||||
24
backend/Dockerfile
Normal file
24
backend/Dockerfile
Normal file
@@ -0,0 +1,24 @@
|
||||
FROM golang:1.21-alpine
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# 安装必要的工具
|
||||
RUN apk add --no-cache git
|
||||
|
||||
# 复制go.mod和go.sum
|
||||
COPY go.mod go.sum ./
|
||||
|
||||
# 下载依赖
|
||||
RUN go mod download
|
||||
|
||||
# 复制源代码
|
||||
COPY . .
|
||||
|
||||
# 构建应用
|
||||
RUN go build -o main cmd/server/main.go
|
||||
|
||||
# 暴露端口
|
||||
EXPOSE 8080
|
||||
|
||||
# 运行应用
|
||||
CMD ["./main"]
|
||||
1
backend/cmd/server/VERSION
Normal file
1
backend/cmd/server/VERSION
Normal file
@@ -0,0 +1 @@
|
||||
0.1.1
|
||||
470
backend/cmd/server/main.go
Normal file
470
backend/cmd/server/main.go
Normal file
@@ -0,0 +1,470 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
_ "embed"
|
||||
"flag"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/config"
|
||||
"sub2api/internal/handler"
|
||||
"sub2api/internal/middleware"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/timezone"
|
||||
"sub2api/internal/repository"
|
||||
"sub2api/internal/service"
|
||||
"sub2api/internal/setup"
|
||||
"sub2api/internal/web"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
//go:embed VERSION
|
||||
var embeddedVersion string
|
||||
|
||||
// Build-time variables (can be set by ldflags)
|
||||
var (
|
||||
Version = ""
|
||||
Commit = "unknown"
|
||||
Date = "unknown"
|
||||
BuildType = "source" // "source" for manual builds, "release" for CI builds (set by ldflags)
|
||||
)
|
||||
|
||||
func init() {
|
||||
// Read version from embedded VERSION file
|
||||
Version = strings.TrimSpace(embeddedVersion)
|
||||
if Version == "" {
|
||||
Version = "0.0.0-dev"
|
||||
}
|
||||
}
|
||||
|
||||
func main() {
|
||||
// Parse command line flags
|
||||
setupMode := flag.Bool("setup", false, "Run setup wizard in CLI mode")
|
||||
showVersion := flag.Bool("version", false, "Show version information")
|
||||
flag.Parse()
|
||||
|
||||
if *showVersion {
|
||||
log.Printf("Sub2API %s (commit: %s, built: %s)\n", Version, Commit, Date)
|
||||
return
|
||||
}
|
||||
|
||||
// CLI setup mode
|
||||
if *setupMode {
|
||||
if err := setup.RunCLI(); err != nil {
|
||||
log.Fatalf("Setup failed: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Check if setup is needed
|
||||
if setup.NeedsSetup() {
|
||||
// Check if auto-setup is enabled (for Docker deployment)
|
||||
if setup.AutoSetupEnabled() {
|
||||
log.Println("Auto setup mode enabled...")
|
||||
if err := setup.AutoSetupFromEnv(); err != nil {
|
||||
log.Fatalf("Auto setup failed: %v", err)
|
||||
}
|
||||
// Continue to main server after auto-setup
|
||||
} else {
|
||||
log.Println("First run detected, starting setup wizard...")
|
||||
runSetupServer()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Normal server mode
|
||||
runMainServer()
|
||||
}
|
||||
|
||||
func runSetupServer() {
|
||||
r := gin.New()
|
||||
r.Use(gin.Recovery())
|
||||
r.Use(middleware.CORS())
|
||||
|
||||
// Register setup routes
|
||||
setup.RegisterRoutes(r)
|
||||
|
||||
// Serve embedded frontend if available
|
||||
if web.HasEmbeddedFrontend() {
|
||||
r.Use(web.ServeEmbeddedFrontend())
|
||||
}
|
||||
|
||||
addr := ":8080"
|
||||
log.Printf("Setup wizard available at http://localhost%s", addr)
|
||||
log.Println("Complete the setup wizard to configure Sub2API")
|
||||
|
||||
if err := r.Run(addr); err != nil {
|
||||
log.Fatalf("Failed to start setup server: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func runMainServer() {
|
||||
// 加载配置
|
||||
cfg, err := config.Load()
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to load config: %v", err)
|
||||
}
|
||||
|
||||
// 初始化时区(类似 PHP 的 date_default_timezone_set)
|
||||
if err := timezone.Init(cfg.Timezone); err != nil {
|
||||
log.Fatalf("Failed to initialize timezone: %v", err)
|
||||
}
|
||||
|
||||
// 初始化数据库
|
||||
db, err := initDB(cfg)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to connect to database: %v", err)
|
||||
}
|
||||
|
||||
// 初始化Redis
|
||||
rdb := initRedis(cfg)
|
||||
|
||||
// 初始化Repository
|
||||
repos := repository.NewRepositories(db)
|
||||
|
||||
// 初始化Service
|
||||
services := service.NewServices(repos, rdb, cfg)
|
||||
|
||||
// 初始化Handler
|
||||
buildInfo := handler.BuildInfo{
|
||||
Version: Version,
|
||||
BuildType: BuildType,
|
||||
}
|
||||
handlers := handler.NewHandlers(services, repos, rdb, buildInfo)
|
||||
|
||||
// 设置Gin模式
|
||||
if cfg.Server.Mode == "release" {
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
}
|
||||
|
||||
// 创建路由
|
||||
r := gin.New()
|
||||
r.Use(gin.Recovery())
|
||||
r.Use(middleware.Logger())
|
||||
r.Use(middleware.CORS())
|
||||
|
||||
// 注册路由
|
||||
registerRoutes(r, handlers, services, repos)
|
||||
|
||||
// Serve embedded frontend if available
|
||||
if web.HasEmbeddedFrontend() {
|
||||
r.Use(web.ServeEmbeddedFrontend())
|
||||
}
|
||||
|
||||
// 启动服务器
|
||||
srv := &http.Server{
|
||||
Addr: cfg.Server.Address(),
|
||||
Handler: r,
|
||||
// ReadHeaderTimeout: 读取请求头的超时时间,防止慢速请求头攻击
|
||||
ReadHeaderTimeout: time.Duration(cfg.Server.ReadHeaderTimeout) * time.Second,
|
||||
// IdleTimeout: 空闲连接超时时间,释放不活跃的连接资源
|
||||
IdleTimeout: time.Duration(cfg.Server.IdleTimeout) * time.Second,
|
||||
// 注意:不设置 WriteTimeout,因为流式响应可能持续十几分钟
|
||||
// 不设置 ReadTimeout,因为大请求体可能需要较长时间读取
|
||||
}
|
||||
|
||||
// 优雅关闭
|
||||
go func() {
|
||||
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
log.Fatalf("Failed to start server: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
log.Printf("Server started on %s", cfg.Server.Address())
|
||||
|
||||
// 等待中断信号
|
||||
quit := make(chan os.Signal, 1)
|
||||
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
||||
<-quit
|
||||
|
||||
log.Println("Shutting down server...")
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := srv.Shutdown(ctx); err != nil {
|
||||
log.Fatalf("Server forced to shutdown: %v", err)
|
||||
}
|
||||
|
||||
log.Println("Server exited")
|
||||
}
|
||||
|
||||
func initDB(cfg *config.Config) (*gorm.DB, error) {
|
||||
gormConfig := &gorm.Config{}
|
||||
if cfg.Server.Mode == "debug" {
|
||||
gormConfig.Logger = logger.Default.LogMode(logger.Info)
|
||||
}
|
||||
|
||||
// 使用带时区的 DSN 连接数据库
|
||||
db, err := gorm.Open(postgres.Open(cfg.Database.DSNWithTimezone(cfg.Timezone)), gormConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 自动迁移(开发环境)
|
||||
if cfg.Server.Mode == "debug" {
|
||||
if err := model.AutoMigrate(db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
func initRedis(cfg *config.Config) *redis.Client {
|
||||
return redis.NewClient(&redis.Options{
|
||||
Addr: cfg.Redis.Address(),
|
||||
Password: cfg.Redis.Password,
|
||||
DB: cfg.Redis.DB,
|
||||
})
|
||||
}
|
||||
|
||||
func registerRoutes(r *gin.Engine, h *handler.Handlers, s *service.Services, repos *repository.Repositories) {
|
||||
// 健康检查
|
||||
r.GET("/health", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
||||
})
|
||||
|
||||
// API v1
|
||||
v1 := r.Group("/api/v1")
|
||||
{
|
||||
// 公开接口
|
||||
auth := v1.Group("/auth")
|
||||
{
|
||||
auth.POST("/register", h.Auth.Register)
|
||||
auth.POST("/login", h.Auth.Login)
|
||||
auth.POST("/send-verify-code", h.Auth.SendVerifyCode)
|
||||
}
|
||||
|
||||
// 公开设置(无需认证)
|
||||
settings := v1.Group("/settings")
|
||||
{
|
||||
settings.GET("/public", h.Setting.GetPublicSettings)
|
||||
}
|
||||
|
||||
// 需要认证的接口
|
||||
authenticated := v1.Group("")
|
||||
authenticated.Use(middleware.JWTAuth(s.Auth, repos.User))
|
||||
{
|
||||
// 当前用户信息
|
||||
authenticated.GET("/auth/me", h.Auth.GetCurrentUser)
|
||||
|
||||
// 用户接口
|
||||
user := authenticated.Group("/user")
|
||||
{
|
||||
user.GET("/profile", h.User.GetProfile)
|
||||
user.PUT("/password", h.User.ChangePassword)
|
||||
}
|
||||
|
||||
// API Key管理
|
||||
keys := authenticated.Group("/keys")
|
||||
{
|
||||
keys.GET("", h.APIKey.List)
|
||||
keys.GET("/:id", h.APIKey.GetByID)
|
||||
keys.POST("", h.APIKey.Create)
|
||||
keys.PUT("/:id", h.APIKey.Update)
|
||||
keys.DELETE("/:id", h.APIKey.Delete)
|
||||
}
|
||||
|
||||
// 用户可用分组(非管理员接口)
|
||||
groups := authenticated.Group("/groups")
|
||||
{
|
||||
groups.GET("/available", h.APIKey.GetAvailableGroups)
|
||||
}
|
||||
|
||||
// 使用记录
|
||||
usage := authenticated.Group("/usage")
|
||||
{
|
||||
usage.GET("", h.Usage.List)
|
||||
usage.GET("/:id", h.Usage.GetByID)
|
||||
usage.GET("/stats", h.Usage.Stats)
|
||||
// User dashboard endpoints
|
||||
usage.GET("/dashboard/stats", h.Usage.DashboardStats)
|
||||
usage.GET("/dashboard/trend", h.Usage.DashboardTrend)
|
||||
usage.GET("/dashboard/models", h.Usage.DashboardModels)
|
||||
usage.POST("/dashboard/api-keys-usage", h.Usage.DashboardApiKeysUsage)
|
||||
}
|
||||
|
||||
// 卡密兑换
|
||||
redeem := authenticated.Group("/redeem")
|
||||
{
|
||||
redeem.POST("", h.Redeem.Redeem)
|
||||
redeem.GET("/history", h.Redeem.GetHistory)
|
||||
}
|
||||
|
||||
// 用户订阅
|
||||
subscriptions := authenticated.Group("/subscriptions")
|
||||
{
|
||||
subscriptions.GET("", h.Subscription.List)
|
||||
subscriptions.GET("/active", h.Subscription.GetActive)
|
||||
subscriptions.GET("/progress", h.Subscription.GetProgress)
|
||||
subscriptions.GET("/summary", h.Subscription.GetSummary)
|
||||
}
|
||||
}
|
||||
|
||||
// 管理员接口
|
||||
admin := v1.Group("/admin")
|
||||
admin.Use(middleware.JWTAuth(s.Auth, repos.User), middleware.AdminOnly())
|
||||
{
|
||||
// 仪表盘
|
||||
dashboard := admin.Group("/dashboard")
|
||||
{
|
||||
dashboard.GET("/stats", h.Admin.Dashboard.GetStats)
|
||||
dashboard.GET("/realtime", h.Admin.Dashboard.GetRealtimeMetrics)
|
||||
dashboard.GET("/trend", h.Admin.Dashboard.GetUsageTrend)
|
||||
dashboard.GET("/models", h.Admin.Dashboard.GetModelStats)
|
||||
dashboard.GET("/api-keys-trend", h.Admin.Dashboard.GetApiKeyUsageTrend)
|
||||
dashboard.GET("/users-trend", h.Admin.Dashboard.GetUserUsageTrend)
|
||||
dashboard.POST("/users-usage", h.Admin.Dashboard.GetBatchUsersUsage)
|
||||
dashboard.POST("/api-keys-usage", h.Admin.Dashboard.GetBatchApiKeysUsage)
|
||||
}
|
||||
|
||||
// 用户管理
|
||||
users := admin.Group("/users")
|
||||
{
|
||||
users.GET("", h.Admin.User.List)
|
||||
users.GET("/:id", h.Admin.User.GetByID)
|
||||
users.POST("", h.Admin.User.Create)
|
||||
users.PUT("/:id", h.Admin.User.Update)
|
||||
users.DELETE("/:id", h.Admin.User.Delete)
|
||||
users.POST("/:id/balance", h.Admin.User.UpdateBalance)
|
||||
users.GET("/:id/api-keys", h.Admin.User.GetUserAPIKeys)
|
||||
users.GET("/:id/usage", h.Admin.User.GetUserUsage)
|
||||
}
|
||||
|
||||
// 分组管理
|
||||
groups := admin.Group("/groups")
|
||||
{
|
||||
groups.GET("", h.Admin.Group.List)
|
||||
groups.GET("/all", h.Admin.Group.GetAll)
|
||||
groups.GET("/:id", h.Admin.Group.GetByID)
|
||||
groups.POST("", h.Admin.Group.Create)
|
||||
groups.PUT("/:id", h.Admin.Group.Update)
|
||||
groups.DELETE("/:id", h.Admin.Group.Delete)
|
||||
groups.GET("/:id/stats", h.Admin.Group.GetStats)
|
||||
groups.GET("/:id/api-keys", h.Admin.Group.GetGroupAPIKeys)
|
||||
}
|
||||
|
||||
// 账号管理
|
||||
accounts := admin.Group("/accounts")
|
||||
{
|
||||
accounts.GET("", h.Admin.Account.List)
|
||||
accounts.GET("/:id", h.Admin.Account.GetByID)
|
||||
accounts.POST("", h.Admin.Account.Create)
|
||||
accounts.PUT("/:id", h.Admin.Account.Update)
|
||||
accounts.DELETE("/:id", h.Admin.Account.Delete)
|
||||
accounts.POST("/:id/test", h.Admin.Account.Test)
|
||||
accounts.POST("/:id/refresh", h.Admin.Account.Refresh)
|
||||
accounts.GET("/:id/stats", h.Admin.Account.GetStats)
|
||||
accounts.POST("/:id/clear-error", h.Admin.Account.ClearError)
|
||||
accounts.GET("/:id/usage", h.Admin.Account.GetUsage)
|
||||
accounts.GET("/:id/today-stats", h.Admin.Account.GetTodayStats)
|
||||
accounts.POST("/:id/clear-rate-limit", h.Admin.Account.ClearRateLimit)
|
||||
accounts.POST("/:id/schedulable", h.Admin.Account.SetSchedulable)
|
||||
accounts.POST("/batch", h.Admin.Account.BatchCreate)
|
||||
|
||||
// OAuth routes
|
||||
accounts.POST("/generate-auth-url", h.Admin.OAuth.GenerateAuthURL)
|
||||
accounts.POST("/generate-setup-token-url", h.Admin.OAuth.GenerateSetupTokenURL)
|
||||
accounts.POST("/exchange-code", h.Admin.OAuth.ExchangeCode)
|
||||
accounts.POST("/exchange-setup-token-code", h.Admin.OAuth.ExchangeSetupTokenCode)
|
||||
accounts.POST("/cookie-auth", h.Admin.OAuth.CookieAuth)
|
||||
accounts.POST("/setup-token-cookie-auth", h.Admin.OAuth.SetupTokenCookieAuth)
|
||||
}
|
||||
|
||||
// 代理管理
|
||||
proxies := admin.Group("/proxies")
|
||||
{
|
||||
proxies.GET("", h.Admin.Proxy.List)
|
||||
proxies.GET("/all", h.Admin.Proxy.GetAll)
|
||||
proxies.GET("/:id", h.Admin.Proxy.GetByID)
|
||||
proxies.POST("", h.Admin.Proxy.Create)
|
||||
proxies.PUT("/:id", h.Admin.Proxy.Update)
|
||||
proxies.DELETE("/:id", h.Admin.Proxy.Delete)
|
||||
proxies.POST("/:id/test", h.Admin.Proxy.Test)
|
||||
proxies.GET("/:id/stats", h.Admin.Proxy.GetStats)
|
||||
proxies.GET("/:id/accounts", h.Admin.Proxy.GetProxyAccounts)
|
||||
proxies.POST("/batch", h.Admin.Proxy.BatchCreate)
|
||||
}
|
||||
|
||||
// 卡密管理
|
||||
codes := admin.Group("/redeem-codes")
|
||||
{
|
||||
codes.GET("", h.Admin.Redeem.List)
|
||||
codes.GET("/stats", h.Admin.Redeem.GetStats)
|
||||
codes.GET("/export", h.Admin.Redeem.Export)
|
||||
codes.GET("/:id", h.Admin.Redeem.GetByID)
|
||||
codes.POST("/generate", h.Admin.Redeem.Generate)
|
||||
codes.DELETE("/:id", h.Admin.Redeem.Delete)
|
||||
codes.POST("/batch-delete", h.Admin.Redeem.BatchDelete)
|
||||
codes.POST("/:id/expire", h.Admin.Redeem.Expire)
|
||||
}
|
||||
|
||||
// 系统设置
|
||||
adminSettings := admin.Group("/settings")
|
||||
{
|
||||
adminSettings.GET("", h.Admin.Setting.GetSettings)
|
||||
adminSettings.PUT("", h.Admin.Setting.UpdateSettings)
|
||||
adminSettings.POST("/test-smtp", h.Admin.Setting.TestSmtpConnection)
|
||||
adminSettings.POST("/send-test-email", h.Admin.Setting.SendTestEmail)
|
||||
}
|
||||
|
||||
// 系统管理
|
||||
system := admin.Group("/system")
|
||||
{
|
||||
system.GET("/version", h.Admin.System.GetVersion)
|
||||
system.GET("/check-updates", h.Admin.System.CheckUpdates)
|
||||
system.POST("/update", h.Admin.System.PerformUpdate)
|
||||
system.POST("/rollback", h.Admin.System.Rollback)
|
||||
system.POST("/restart", h.Admin.System.RestartService)
|
||||
}
|
||||
|
||||
// 订阅管理
|
||||
subscriptions := admin.Group("/subscriptions")
|
||||
{
|
||||
subscriptions.GET("", h.Admin.Subscription.List)
|
||||
subscriptions.GET("/:id", h.Admin.Subscription.GetByID)
|
||||
subscriptions.GET("/:id/progress", h.Admin.Subscription.GetProgress)
|
||||
subscriptions.POST("/assign", h.Admin.Subscription.Assign)
|
||||
subscriptions.POST("/bulk-assign", h.Admin.Subscription.BulkAssign)
|
||||
subscriptions.POST("/:id/extend", h.Admin.Subscription.Extend)
|
||||
subscriptions.DELETE("/:id", h.Admin.Subscription.Revoke)
|
||||
}
|
||||
|
||||
// 分组下的订阅列表
|
||||
admin.GET("/groups/:id/subscriptions", h.Admin.Subscription.ListByGroup)
|
||||
|
||||
// 用户下的订阅列表
|
||||
admin.GET("/users/:id/subscriptions", h.Admin.Subscription.ListByUser)
|
||||
|
||||
// 使用记录管理
|
||||
usage := admin.Group("/usage")
|
||||
{
|
||||
usage.GET("", h.Admin.Usage.List)
|
||||
usage.GET("/stats", h.Admin.Usage.Stats)
|
||||
usage.GET("/search-users", h.Admin.Usage.SearchUsers)
|
||||
usage.GET("/search-api-keys", h.Admin.Usage.SearchApiKeys)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// API网关(Claude API兼容)
|
||||
gateway := r.Group("/v1")
|
||||
gateway.Use(middleware.ApiKeyAuthWithSubscription(s.ApiKey, s.Subscription))
|
||||
{
|
||||
gateway.POST("/messages", h.Gateway.Messages)
|
||||
gateway.GET("/models", h.Gateway.Models)
|
||||
gateway.GET("/usage", h.Gateway.Usage)
|
||||
}
|
||||
}
|
||||
38
backend/config.yaml
Normal file
38
backend/config.yaml
Normal file
@@ -0,0 +1,38 @@
|
||||
server:
|
||||
host: "0.0.0.0"
|
||||
port: 8080
|
||||
mode: "debug" # debug/release
|
||||
|
||||
database:
|
||||
host: "127.0.0.1"
|
||||
port: 5432
|
||||
user: "postgres"
|
||||
password: "XZeRr7nkjHWhm8fw"
|
||||
dbname: "sub2api"
|
||||
sslmode: "disable"
|
||||
|
||||
redis:
|
||||
host: "127.0.0.1"
|
||||
port: 6379
|
||||
password: ""
|
||||
db: 0
|
||||
|
||||
jwt:
|
||||
secret: "your-secret-key-change-in-production"
|
||||
expire_hour: 24
|
||||
|
||||
default:
|
||||
admin_email: "admin@sub2api.com"
|
||||
admin_password: "admin123"
|
||||
user_concurrency: 5
|
||||
user_balance: 0
|
||||
api_key_prefix: "sk-"
|
||||
rate_multiplier: 1.0
|
||||
|
||||
# Timezone configuration (similar to PHP's date_default_timezone_set)
|
||||
# This affects ALL time operations:
|
||||
# - Database timestamps
|
||||
# - Usage statistics "today" boundary
|
||||
# - Subscription expiry times
|
||||
# Common values: Asia/Shanghai, America/New_York, Europe/London, UTC
|
||||
timezone: "Asia/Shanghai"
|
||||
74
backend/go.mod
Normal file
74
backend/go.mod
Normal file
@@ -0,0 +1,74 @@
|
||||
module sub2api
|
||||
|
||||
go 1.24.0
|
||||
|
||||
toolchain go1.24.11
|
||||
|
||||
require (
|
||||
github.com/gin-gonic/gin v1.9.1
|
||||
github.com/golang-jwt/jwt/v5 v5.2.0
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/imroc/req/v3 v3.56.0
|
||||
github.com/lib/pq v1.10.9
|
||||
github.com/redis/go-redis/v9 v9.3.0
|
||||
github.com/spf13/viper v1.18.2
|
||||
golang.org/x/crypto v0.44.0
|
||||
golang.org/x/term v0.37.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
gorm.io/driver/postgres v1.5.4
|
||||
gorm.io/gorm v1.25.5
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/andybalholm/brotli v1.2.0 // indirect
|
||||
github.com/bytedance/sonic v1.9.1 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.2.0 // indirect
|
||||
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||
github.com/fsnotify/fsnotify v1.7.0 // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.2 // indirect
|
||||
github.com/gin-contrib/sse v0.1.0 // indirect
|
||||
github.com/go-playground/locales v0.14.1 // indirect
|
||||
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||
github.com/go-playground/validator/v10 v10.14.0 // indirect
|
||||
github.com/goccy/go-json v0.10.2 // indirect
|
||||
github.com/google/go-querystring v1.1.0 // indirect
|
||||
github.com/hashicorp/hcl v1.0.0 // indirect
|
||||
github.com/icholy/digest v1.1.0 // indirect
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
|
||||
github.com/jackc/pgx/v5 v5.4.3 // indirect
|
||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||
github.com/jinzhu/now v1.1.5 // indirect
|
||||
github.com/json-iterator/go v1.1.12 // indirect
|
||||
github.com/klauspost/compress v1.18.1 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.2.4 // indirect
|
||||
github.com/leodido/go-urn v1.2.4 // indirect
|
||||
github.com/magiconair/properties v1.8.7 // indirect
|
||||
github.com/mattn/go-isatty v0.0.19 // indirect
|
||||
github.com/mitchellh/mapstructure v1.5.0 // indirect
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.1.0 // indirect
|
||||
github.com/quic-go/qpack v0.5.1 // indirect
|
||||
github.com/quic-go/quic-go v0.56.0 // indirect
|
||||
github.com/refraction-networking/utls v1.8.1 // indirect
|
||||
github.com/sagikazarmark/locafero v0.4.0 // indirect
|
||||
github.com/sagikazarmark/slog-shim v0.1.0 // indirect
|
||||
github.com/sourcegraph/conc v0.3.0 // indirect
|
||||
github.com/spf13/afero v1.11.0 // indirect
|
||||
github.com/spf13/cast v1.6.0 // indirect
|
||||
github.com/spf13/pflag v1.0.5 // indirect
|
||||
github.com/subosito/gotenv v1.6.0 // indirect
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
github.com/ugorji/go/codec v1.2.11 // indirect
|
||||
go.uber.org/atomic v1.9.0 // indirect
|
||||
go.uber.org/multierr v1.9.0 // indirect
|
||||
golang.org/x/arch v0.3.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20230905200255-921286631fa9 // indirect
|
||||
golang.org/x/net v0.47.0 // indirect
|
||||
golang.org/x/sys v0.38.0 // indirect
|
||||
golang.org/x/text v0.31.0 // indirect
|
||||
google.golang.org/protobuf v1.31.0 // indirect
|
||||
gopkg.in/ini.v1 v1.67.0 // indirect
|
||||
)
|
||||
187
backend/go.sum
Normal file
187
backend/go.sum
Normal file
@@ -0,0 +1,187 @@
|
||||
github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ=
|
||||
github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY=
|
||||
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
|
||||
github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
|
||||
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
|
||||
github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0=
|
||||
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
|
||||
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
|
||||
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
|
||||
github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44=
|
||||
github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
|
||||
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
|
||||
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
|
||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
||||
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
|
||||
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
|
||||
github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA=
|
||||
github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM=
|
||||
github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU=
|
||||
github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA=
|
||||
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
|
||||
github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
|
||||
github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg=
|
||||
github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU=
|
||||
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
|
||||
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
|
||||
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
|
||||
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
|
||||
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
|
||||
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
|
||||
github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg/+t63MyGU2n5js=
|
||||
github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU=
|
||||
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
|
||||
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.0 h1:d/ix8ftRUorsN+5eMIlF4T6J8CAt9rch3My2winC1Jw=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
||||
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
|
||||
github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8=
|
||||
github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU=
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4=
|
||||
github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ=
|
||||
github.com/icholy/digest v1.1.0 h1:HfGg9Irj7i+IX1o1QAmPfIBNu/Q5A5Tu3n/MED9k9H4=
|
||||
github.com/icholy/digest v1.1.0/go.mod h1:QNrsSGQ5v7v9cReDI0+eyjsXGUoRSUZQHeQ5C4XLa0Y=
|
||||
github.com/imroc/req/v3 v3.56.0 h1:t6YdqqerYBXhZ9+VjqsQs5wlKxdUNEvsgBhxWc1AEEo=
|
||||
github.com/imroc/req/v3 v3.56.0/go.mod h1:cUZSooE8hhzFNOrAbdxuemXDQxFXLQTnu3066jr7ZGk=
|
||||
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk=
|
||||
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
|
||||
github.com/jackc/pgx/v5 v5.4.3 h1:cxFyXhxlvAifxnkKKdlxv8XqUf59tDlYjnV5YYfsJJY=
|
||||
github.com/jackc/pgx/v5 v5.4.3/go.mod h1:Ig06C2Vu0t5qXC60W8sqIthScaEnFvojjj9dSljmHRA=
|
||||
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
||||
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
||||
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
|
||||
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
||||
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
||||
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
||||
github.com/klauspost/compress v1.18.1 h1:bcSGx7UbpBqMChDtsF28Lw6v/G94LPrrbMbdC3JH2co=
|
||||
github.com/klauspost/compress v1.18.1/go.mod h1:ZQFFVG+MdnR0P+l6wpXgIL4NTtwiKIdBnrBd8Nrxr+0=
|
||||
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
|
||||
github.com/klauspost/cpuid/v2 v2.2.4 h1:acbojRNwl3o09bUq+yDCtZFc1aiwaAAxtcn8YkZXnvk=
|
||||
github.com/klauspost/cpuid/v2 v2.2.4/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY=
|
||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q=
|
||||
github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4=
|
||||
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
|
||||
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
|
||||
github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY=
|
||||
github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0=
|
||||
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
|
||||
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
|
||||
github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
|
||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
|
||||
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
|
||||
github.com/pelletier/go-toml/v2 v2.1.0 h1:FnwAJ4oYMvbT/34k9zzHuZNrhlz48GB3/s6at6/MHO4=
|
||||
github.com/pelletier/go-toml/v2 v2.1.0/go.mod h1:tJU2Z3ZkXwnxa4DPO899bsyIoywizdUvyaeZurnPPDc=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/quic-go/qpack v0.5.1 h1:giqksBPnT/HDtZ6VhtFKgoLOWmlyo9Ei6u9PqzIMbhI=
|
||||
github.com/quic-go/qpack v0.5.1/go.mod h1:+PC4XFrEskIVkcLzpEkbLqq1uCoxPhQuvK5rH1ZgaEg=
|
||||
github.com/quic-go/quic-go v0.56.0 h1:q/TW+OLismmXAehgFLczhCDTYB3bFmua4D9lsNBWxvY=
|
||||
github.com/quic-go/quic-go v0.56.0/go.mod h1:9gx5KsFQtw2oZ6GZTyh+7YEvOxWCL9WZAepnHxgAo6c=
|
||||
github.com/redis/go-redis/v9 v9.3.0 h1:RiVDjmig62jIWp7Kk4XVLs0hzV6pI3PyTnnL0cnn0u0=
|
||||
github.com/redis/go-redis/v9 v9.3.0/go.mod h1:hdY0cQFCN4fnSYT6TkisLufl/4W5UIXyv0b/CLO2V2M=
|
||||
github.com/refraction-networking/utls v1.8.1 h1:yNY1kapmQU8JeM1sSw2H2asfTIwWxIkrMJI0pRUOCAo=
|
||||
github.com/refraction-networking/utls v1.8.1/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM=
|
||||
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
|
||||
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
|
||||
github.com/sagikazarmark/locafero v0.4.0 h1:HApY1R9zGo4DBgr7dqsTH/JJxLTTsOt7u6keLGt6kNQ=
|
||||
github.com/sagikazarmark/locafero v0.4.0/go.mod h1:Pe1W6UlPYUk/+wc/6KFhbORCfqzgYEpgQ3O5fPuL3H4=
|
||||
github.com/sagikazarmark/slog-shim v0.1.0 h1:diDBnUNK9N/354PgrxMywXnAwEr1QZcOr6gto+ugjYE=
|
||||
github.com/sagikazarmark/slog-shim v0.1.0/go.mod h1:SrcSrq8aKtyuqEI1uvTDTK1arOWRIczQRv+GVI1AkeQ=
|
||||
github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo=
|
||||
github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0=
|
||||
github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8=
|
||||
github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY=
|
||||
github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0=
|
||||
github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
|
||||
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
|
||||
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||
github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ=
|
||||
github.com/spf13/viper v1.18.2/go.mod h1:EKmWIqdnk5lOcmR72yw6hS+8OPYcwD0jteitLMVB+yk=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
|
||||
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
|
||||
github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU=
|
||||
github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
|
||||
github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU=
|
||||
github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
|
||||
go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE=
|
||||
go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
|
||||
go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y=
|
||||
go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU=
|
||||
go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI=
|
||||
go.uber.org/multierr v1.9.0/go.mod h1:X2jQV1h+kxSjClGpnseKVIxpmcjrj7MNnI0bnlfKTVQ=
|
||||
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
||||
golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
|
||||
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
||||
golang.org/x/crypto v0.44.0 h1:A97SsFvM3AIwEEmTBiaxPPTYpDC47w720rdiiUvgoAU=
|
||||
golang.org/x/crypto v0.44.0/go.mod h1:013i+Nw79BMiQiMsOPcVCB5ZIJbYkerPrGnOa00tvmc=
|
||||
golang.org/x/exp v0.0.0-20230905200255-921286631fa9 h1:GoHiUyI/Tp2nVkLI2mCxVkOjsbSXD66ic0XW0js0R9g=
|
||||
golang.org/x/exp v0.0.0-20230905200255-921286631fa9/go.mod h1:S2oDrQGGwySpoQPVqRShND87VCbxmc6bL1Yd2oYrm6k=
|
||||
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
|
||||
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
|
||||
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
|
||||
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/term v0.37.0 h1:8EGAD0qCmHYZg6J17DvsMy9/wJ7/D/4pV/wfnld5lTU=
|
||||
golang.org/x/term v0.37.0/go.mod h1:5pB4lxRNYYVZuTLmy8oR2BH8dflOR+IbTYFD8fi3254=
|
||||
golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM=
|
||||
golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM=
|
||||
golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
|
||||
golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
|
||||
google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8=
|
||||
google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||
gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA=
|
||||
gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gorm.io/driver/postgres v1.5.4 h1:Iyrp9Meh3GmbSuyIAGyjkN+n9K+GHX9b9MqsTL4EJCo=
|
||||
gorm.io/driver/postgres v1.5.4/go.mod h1:Bgo89+h0CRcdA33Y6frlaHHVuTdOf87pmyzwW9C/BH0=
|
||||
gorm.io/gorm v1.25.5 h1:zR9lOiiYf09VNh5Q1gphfyia1JpiClIWG9hQaxB/mls=
|
||||
gorm.io/gorm v1.25.5/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
|
||||
gotest.tools/v3 v3.5.1 h1:EENdUnS3pdur5nybKYIh2Vfgc8IUNBjxDPSjtiJcOzU=
|
||||
gotest.tools/v3 v3.5.1/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU=
|
||||
rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=
|
||||
205
backend/internal/config/config.go
Normal file
205
backend/internal/config/config.go
Normal file
@@ -0,0 +1,205 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Server ServerConfig `mapstructure:"server"`
|
||||
Database DatabaseConfig `mapstructure:"database"`
|
||||
Redis RedisConfig `mapstructure:"redis"`
|
||||
JWT JWTConfig `mapstructure:"jwt"`
|
||||
Default DefaultConfig `mapstructure:"default"`
|
||||
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
|
||||
Pricing PricingConfig `mapstructure:"pricing"`
|
||||
Gateway GatewayConfig `mapstructure:"gateway"`
|
||||
Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC"
|
||||
}
|
||||
|
||||
type PricingConfig struct {
|
||||
// 价格数据远程URL(默认使用LiteLLM镜像)
|
||||
RemoteURL string `mapstructure:"remote_url"`
|
||||
// 哈希校验文件URL
|
||||
HashURL string `mapstructure:"hash_url"`
|
||||
// 本地数据目录
|
||||
DataDir string `mapstructure:"data_dir"`
|
||||
// 回退文件路径
|
||||
FallbackFile string `mapstructure:"fallback_file"`
|
||||
// 更新间隔(小时)
|
||||
UpdateIntervalHours int `mapstructure:"update_interval_hours"`
|
||||
// 哈希校验间隔(分钟)
|
||||
HashCheckIntervalMinutes int `mapstructure:"hash_check_interval_minutes"`
|
||||
}
|
||||
|
||||
type ServerConfig struct {
|
||||
Host string `mapstructure:"host"`
|
||||
Port int `mapstructure:"port"`
|
||||
Mode string `mapstructure:"mode"` // debug/release
|
||||
ReadHeaderTimeout int `mapstructure:"read_header_timeout"` // 读取请求头超时(秒)
|
||||
IdleTimeout int `mapstructure:"idle_timeout"` // 空闲连接超时(秒)
|
||||
}
|
||||
|
||||
// GatewayConfig API网关相关配置
|
||||
type GatewayConfig struct {
|
||||
// 等待上游响应头的超时时间(秒),0表示无超时
|
||||
// 注意:这不影响流式数据传输,只控制等待响应头的时间
|
||||
ResponseHeaderTimeout int `mapstructure:"response_header_timeout"`
|
||||
}
|
||||
|
||||
func (s *ServerConfig) Address() string {
|
||||
return fmt.Sprintf("%s:%d", s.Host, s.Port)
|
||||
}
|
||||
|
||||
type DatabaseConfig struct {
|
||||
Host string `mapstructure:"host"`
|
||||
Port int `mapstructure:"port"`
|
||||
User string `mapstructure:"user"`
|
||||
Password string `mapstructure:"password"`
|
||||
DBName string `mapstructure:"dbname"`
|
||||
SSLMode string `mapstructure:"sslmode"`
|
||||
}
|
||||
|
||||
func (d *DatabaseConfig) DSN() string {
|
||||
return fmt.Sprintf(
|
||||
"host=%s port=%d user=%s password=%s dbname=%s sslmode=%s",
|
||||
d.Host, d.Port, d.User, d.Password, d.DBName, d.SSLMode,
|
||||
)
|
||||
}
|
||||
|
||||
// DSNWithTimezone returns DSN with timezone setting
|
||||
func (d *DatabaseConfig) DSNWithTimezone(tz string) string {
|
||||
if tz == "" {
|
||||
tz = "Asia/Shanghai"
|
||||
}
|
||||
return fmt.Sprintf(
|
||||
"host=%s port=%d user=%s password=%s dbname=%s sslmode=%s TimeZone=%s",
|
||||
d.Host, d.Port, d.User, d.Password, d.DBName, d.SSLMode, tz,
|
||||
)
|
||||
}
|
||||
|
||||
type RedisConfig struct {
|
||||
Host string `mapstructure:"host"`
|
||||
Port int `mapstructure:"port"`
|
||||
Password string `mapstructure:"password"`
|
||||
DB int `mapstructure:"db"`
|
||||
}
|
||||
|
||||
func (r *RedisConfig) Address() string {
|
||||
return fmt.Sprintf("%s:%d", r.Host, r.Port)
|
||||
}
|
||||
|
||||
type JWTConfig struct {
|
||||
Secret string `mapstructure:"secret"`
|
||||
ExpireHour int `mapstructure:"expire_hour"`
|
||||
}
|
||||
|
||||
type DefaultConfig struct {
|
||||
AdminEmail string `mapstructure:"admin_email"`
|
||||
AdminPassword string `mapstructure:"admin_password"`
|
||||
UserConcurrency int `mapstructure:"user_concurrency"`
|
||||
UserBalance float64 `mapstructure:"user_balance"`
|
||||
ApiKeyPrefix string `mapstructure:"api_key_prefix"`
|
||||
RateMultiplier float64 `mapstructure:"rate_multiplier"`
|
||||
}
|
||||
|
||||
type RateLimitConfig struct {
|
||||
OverloadCooldownMinutes int `mapstructure:"overload_cooldown_minutes"` // 529过载冷却时间(分钟)
|
||||
}
|
||||
|
||||
func Load() (*Config, error) {
|
||||
viper.SetConfigName("config")
|
||||
viper.SetConfigType("yaml")
|
||||
viper.AddConfigPath(".")
|
||||
viper.AddConfigPath("./config")
|
||||
viper.AddConfigPath("/etc/sub2api")
|
||||
|
||||
// 环境变量支持
|
||||
viper.AutomaticEnv()
|
||||
viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
|
||||
|
||||
// 默认值
|
||||
setDefaults()
|
||||
|
||||
if err := viper.ReadInConfig(); err != nil {
|
||||
if _, ok := err.(viper.ConfigFileNotFoundError); !ok {
|
||||
return nil, fmt.Errorf("read config error: %w", err)
|
||||
}
|
||||
// 配置文件不存在时使用默认值
|
||||
}
|
||||
|
||||
var cfg Config
|
||||
if err := viper.Unmarshal(&cfg); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal config error: %w", err)
|
||||
}
|
||||
|
||||
if err := cfg.Validate(); err != nil {
|
||||
return nil, fmt.Errorf("validate config error: %w", err)
|
||||
}
|
||||
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
func setDefaults() {
|
||||
// Server
|
||||
viper.SetDefault("server.host", "0.0.0.0")
|
||||
viper.SetDefault("server.port", 8080)
|
||||
viper.SetDefault("server.mode", "debug")
|
||||
viper.SetDefault("server.read_header_timeout", 30) // 30秒读取请求头
|
||||
viper.SetDefault("server.idle_timeout", 120) // 120秒空闲超时
|
||||
|
||||
// Database
|
||||
viper.SetDefault("database.host", "localhost")
|
||||
viper.SetDefault("database.port", 5432)
|
||||
viper.SetDefault("database.user", "postgres")
|
||||
viper.SetDefault("database.password", "postgres")
|
||||
viper.SetDefault("database.dbname", "sub2api")
|
||||
viper.SetDefault("database.sslmode", "disable")
|
||||
|
||||
// Redis
|
||||
viper.SetDefault("redis.host", "localhost")
|
||||
viper.SetDefault("redis.port", 6379)
|
||||
viper.SetDefault("redis.password", "")
|
||||
viper.SetDefault("redis.db", 0)
|
||||
|
||||
// JWT
|
||||
viper.SetDefault("jwt.secret", "change-me-in-production")
|
||||
viper.SetDefault("jwt.expire_hour", 24)
|
||||
|
||||
// Default
|
||||
viper.SetDefault("default.admin_email", "admin@sub2api.com")
|
||||
viper.SetDefault("default.admin_password", "admin123")
|
||||
viper.SetDefault("default.user_concurrency", 5)
|
||||
viper.SetDefault("default.user_balance", 0)
|
||||
viper.SetDefault("default.api_key_prefix", "sk-")
|
||||
viper.SetDefault("default.rate_multiplier", 1.0)
|
||||
|
||||
// RateLimit
|
||||
viper.SetDefault("rate_limit.overload_cooldown_minutes", 10)
|
||||
|
||||
// Pricing - 从 price-mirror 分支同步,该分支维护了 sha256 哈希文件用于增量更新检查
|
||||
viper.SetDefault("pricing.remote_url", "https://raw.githubusercontent.com/Wei-Shaw/claude-relay-service/price-mirror/model_prices_and_context_window.json")
|
||||
viper.SetDefault("pricing.hash_url", "https://raw.githubusercontent.com/Wei-Shaw/claude-relay-service/price-mirror/model_prices_and_context_window.sha256")
|
||||
viper.SetDefault("pricing.data_dir", "./data")
|
||||
viper.SetDefault("pricing.fallback_file", "./resources/model-pricing/model_prices_and_context_window.json")
|
||||
viper.SetDefault("pricing.update_interval_hours", 24)
|
||||
viper.SetDefault("pricing.hash_check_interval_minutes", 10)
|
||||
|
||||
// Timezone (default to Asia/Shanghai for Chinese users)
|
||||
viper.SetDefault("timezone", "Asia/Shanghai")
|
||||
|
||||
// Gateway
|
||||
viper.SetDefault("gateway.response_header_timeout", 300) // 300秒(5分钟)等待上游响应头,LLM高负载时可能排队较久
|
||||
}
|
||||
|
||||
func (c *Config) Validate() error {
|
||||
if c.JWT.Secret == "" {
|
||||
return fmt.Errorf("jwt.secret is required")
|
||||
}
|
||||
if c.JWT.Secret == "change-me-in-production" && c.Server.Mode == "release" {
|
||||
return fmt.Errorf("jwt.secret must be changed in production")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
537
backend/internal/handler/admin/account_handler.go
Normal file
537
backend/internal/handler/admin/account_handler.go
Normal file
@@ -0,0 +1,537 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"sub2api/internal/pkg/response"
|
||||
"sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// OAuthHandler handles OAuth-related operations for accounts
|
||||
type OAuthHandler struct {
|
||||
oauthService *service.OAuthService
|
||||
adminService service.AdminService
|
||||
}
|
||||
|
||||
// NewOAuthHandler creates a new OAuth handler
|
||||
func NewOAuthHandler(oauthService *service.OAuthService, adminService service.AdminService) *OAuthHandler {
|
||||
return &OAuthHandler{
|
||||
oauthService: oauthService,
|
||||
adminService: adminService,
|
||||
}
|
||||
}
|
||||
|
||||
// AccountHandler handles admin account management
|
||||
type AccountHandler struct {
|
||||
adminService service.AdminService
|
||||
oauthService *service.OAuthService
|
||||
rateLimitService *service.RateLimitService
|
||||
accountUsageService *service.AccountUsageService
|
||||
accountTestService *service.AccountTestService
|
||||
}
|
||||
|
||||
// NewAccountHandler creates a new admin account handler
|
||||
func NewAccountHandler(adminService service.AdminService, oauthService *service.OAuthService, rateLimitService *service.RateLimitService, accountUsageService *service.AccountUsageService, accountTestService *service.AccountTestService) *AccountHandler {
|
||||
return &AccountHandler{
|
||||
adminService: adminService,
|
||||
oauthService: oauthService,
|
||||
rateLimitService: rateLimitService,
|
||||
accountUsageService: accountUsageService,
|
||||
accountTestService: accountTestService,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateAccountRequest represents create account request
|
||||
type CreateAccountRequest struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
Platform string `json:"platform" binding:"required"`
|
||||
Type string `json:"type" binding:"required,oneof=oauth setup-token apikey"`
|
||||
Credentials map[string]interface{} `json:"credentials" binding:"required"`
|
||||
Extra map[string]interface{} `json:"extra"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
Concurrency int `json:"concurrency"`
|
||||
Priority int `json:"priority"`
|
||||
GroupIDs []int64 `json:"group_ids"`
|
||||
}
|
||||
|
||||
// UpdateAccountRequest represents update account request
|
||||
// 使用指针类型来区分"未提供"和"设置为0"
|
||||
type UpdateAccountRequest struct {
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey"`
|
||||
Credentials map[string]interface{} `json:"credentials"`
|
||||
Extra map[string]interface{} `json:"extra"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
Concurrency *int `json:"concurrency"`
|
||||
Priority *int `json:"priority"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
|
||||
GroupIDs *[]int64 `json:"group_ids"`
|
||||
}
|
||||
|
||||
// List handles listing all accounts with pagination
|
||||
// GET /api/v1/admin/accounts
|
||||
func (h *AccountHandler) List(c *gin.Context) {
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
platform := c.Query("platform")
|
||||
accountType := c.Query("type")
|
||||
status := c.Query("status")
|
||||
search := c.Query("search")
|
||||
|
||||
accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to list accounts: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Paginated(c, accounts, total, page, pageSize)
|
||||
}
|
||||
|
||||
// GetByID handles getting an account by ID
|
||||
// GET /api/v1/admin/accounts/:id
|
||||
func (h *AccountHandler) GetByID(c *gin.Context) {
|
||||
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid account ID")
|
||||
return
|
||||
}
|
||||
|
||||
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
|
||||
if err != nil {
|
||||
response.NotFound(c, "Account not found")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, account)
|
||||
}
|
||||
|
||||
// Create handles creating a new account
|
||||
// POST /api/v1/admin/accounts
|
||||
func (h *AccountHandler) Create(c *gin.Context) {
|
||||
var req CreateAccountRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
account, err := h.adminService.CreateAccount(c.Request.Context(), &service.CreateAccountInput{
|
||||
Name: req.Name,
|
||||
Platform: req.Platform,
|
||||
Type: req.Type,
|
||||
Credentials: req.Credentials,
|
||||
Extra: req.Extra,
|
||||
ProxyID: req.ProxyID,
|
||||
Concurrency: req.Concurrency,
|
||||
Priority: req.Priority,
|
||||
GroupIDs: req.GroupIDs,
|
||||
})
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Failed to create account: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, account)
|
||||
}
|
||||
|
||||
// Update handles updating an account
|
||||
// PUT /api/v1/admin/accounts/:id
|
||||
func (h *AccountHandler) Update(c *gin.Context) {
|
||||
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid account ID")
|
||||
return
|
||||
}
|
||||
|
||||
var req UpdateAccountRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
account, err := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{
|
||||
Name: req.Name,
|
||||
Type: req.Type,
|
||||
Credentials: req.Credentials,
|
||||
Extra: req.Extra,
|
||||
ProxyID: req.ProxyID,
|
||||
Concurrency: req.Concurrency, // 指针类型,nil 表示未提供
|
||||
Priority: req.Priority, // 指针类型,nil 表示未提供
|
||||
Status: req.Status,
|
||||
GroupIDs: req.GroupIDs,
|
||||
})
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to update account: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, account)
|
||||
}
|
||||
|
||||
// Delete handles deleting an account
|
||||
// DELETE /api/v1/admin/accounts/:id
|
||||
func (h *AccountHandler) Delete(c *gin.Context) {
|
||||
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid account ID")
|
||||
return
|
||||
}
|
||||
|
||||
err = h.adminService.DeleteAccount(c.Request.Context(), accountID)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to delete account: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"message": "Account deleted successfully"})
|
||||
}
|
||||
|
||||
// Test handles testing account connectivity with SSE streaming
|
||||
// POST /api/v1/admin/accounts/:id/test
|
||||
func (h *AccountHandler) Test(c *gin.Context) {
|
||||
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid account ID")
|
||||
return
|
||||
}
|
||||
|
||||
// Use AccountTestService to test the account with SSE streaming
|
||||
if err := h.accountTestService.TestAccountConnection(c, accountID); err != nil {
|
||||
// Error already sent via SSE, just log
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Refresh handles refreshing account credentials
|
||||
// POST /api/v1/admin/accounts/:id/refresh
|
||||
func (h *AccountHandler) Refresh(c *gin.Context) {
|
||||
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid account ID")
|
||||
return
|
||||
}
|
||||
|
||||
// Get account
|
||||
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
|
||||
if err != nil {
|
||||
response.NotFound(c, "Account not found")
|
||||
return
|
||||
}
|
||||
|
||||
// Only refresh OAuth-based accounts (oauth and setup-token)
|
||||
if !account.IsOAuth() {
|
||||
response.BadRequest(c, "Cannot refresh non-OAuth account credentials")
|
||||
return
|
||||
}
|
||||
|
||||
// Use OAuth service to refresh token
|
||||
tokenInfo, err := h.oauthService.RefreshAccountToken(c.Request.Context(), account)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to refresh credentials: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Update account credentials
|
||||
newCredentials := map[string]interface{}{
|
||||
"access_token": tokenInfo.AccessToken,
|
||||
"token_type": tokenInfo.TokenType,
|
||||
"expires_in": tokenInfo.ExpiresIn,
|
||||
"expires_at": tokenInfo.ExpiresAt,
|
||||
"refresh_token": tokenInfo.RefreshToken,
|
||||
"scope": tokenInfo.Scope,
|
||||
}
|
||||
|
||||
updatedAccount, err := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{
|
||||
Credentials: newCredentials,
|
||||
})
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to update account credentials: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, updatedAccount)
|
||||
}
|
||||
|
||||
// GetStats handles getting account statistics
|
||||
// GET /api/v1/admin/accounts/:id/stats
|
||||
func (h *AccountHandler) GetStats(c *gin.Context) {
|
||||
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid account ID")
|
||||
return
|
||||
}
|
||||
|
||||
// Return mock data for now
|
||||
_ = accountID
|
||||
response.Success(c, gin.H{
|
||||
"total_requests": 0,
|
||||
"successful_requests": 0,
|
||||
"failed_requests": 0,
|
||||
"total_tokens": 0,
|
||||
"average_response_time": 0,
|
||||
})
|
||||
}
|
||||
|
||||
// ClearError handles clearing account error
|
||||
// POST /api/v1/admin/accounts/:id/clear-error
|
||||
func (h *AccountHandler) ClearError(c *gin.Context) {
|
||||
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid account ID")
|
||||
return
|
||||
}
|
||||
|
||||
account, err := h.adminService.ClearAccountError(c.Request.Context(), accountID)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to clear error: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, account)
|
||||
}
|
||||
|
||||
// BatchCreate handles batch creating accounts
|
||||
// POST /api/v1/admin/accounts/batch
|
||||
func (h *AccountHandler) BatchCreate(c *gin.Context) {
|
||||
var req struct {
|
||||
Accounts []CreateAccountRequest `json:"accounts" binding:"required,min=1"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Return mock data for now
|
||||
response.Success(c, gin.H{
|
||||
"success": len(req.Accounts),
|
||||
"failed": 0,
|
||||
"results": []gin.H{},
|
||||
})
|
||||
}
|
||||
|
||||
// ========== OAuth Handlers ==========
|
||||
|
||||
// GenerateAuthURLRequest represents the request for generating auth URL
|
||||
type GenerateAuthURLRequest struct {
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
}
|
||||
|
||||
// GenerateAuthURL generates OAuth authorization URL with full scope
|
||||
// POST /api/v1/admin/accounts/generate-auth-url
|
||||
func (h *OAuthHandler) GenerateAuthURL(c *gin.Context) {
|
||||
var req GenerateAuthURLRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
// Allow empty body
|
||||
req = GenerateAuthURLRequest{}
|
||||
}
|
||||
|
||||
result, err := h.oauthService.GenerateAuthURL(c.Request.Context(), req.ProxyID)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to generate auth URL: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
// GenerateSetupTokenURL generates OAuth authorization URL for setup token (inference only)
|
||||
// POST /api/v1/admin/accounts/generate-setup-token-url
|
||||
func (h *OAuthHandler) GenerateSetupTokenURL(c *gin.Context) {
|
||||
var req GenerateAuthURLRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
// Allow empty body
|
||||
req = GenerateAuthURLRequest{}
|
||||
}
|
||||
|
||||
result, err := h.oauthService.GenerateSetupTokenURL(c.Request.Context(), req.ProxyID)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to generate setup token URL: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
// ExchangeCodeRequest represents the request for exchanging auth code
|
||||
type ExchangeCodeRequest struct {
|
||||
SessionID string `json:"session_id" binding:"required"`
|
||||
Code string `json:"code" binding:"required"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
}
|
||||
|
||||
// ExchangeCode exchanges authorization code for tokens
|
||||
// POST /api/v1/admin/accounts/exchange-code
|
||||
func (h *OAuthHandler) ExchangeCode(c *gin.Context) {
|
||||
var req ExchangeCodeRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
tokenInfo, err := h.oauthService.ExchangeCode(c.Request.Context(), &service.ExchangeCodeInput{
|
||||
SessionID: req.SessionID,
|
||||
Code: req.Code,
|
||||
ProxyID: req.ProxyID,
|
||||
})
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Failed to exchange code: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, tokenInfo)
|
||||
}
|
||||
|
||||
// ExchangeSetupTokenCode exchanges authorization code for setup token
|
||||
// POST /api/v1/admin/accounts/exchange-setup-token-code
|
||||
func (h *OAuthHandler) ExchangeSetupTokenCode(c *gin.Context) {
|
||||
var req ExchangeCodeRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
tokenInfo, err := h.oauthService.ExchangeCode(c.Request.Context(), &service.ExchangeCodeInput{
|
||||
SessionID: req.SessionID,
|
||||
Code: req.Code,
|
||||
ProxyID: req.ProxyID,
|
||||
})
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Failed to exchange code: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, tokenInfo)
|
||||
}
|
||||
|
||||
// CookieAuthRequest represents the request for cookie-based authentication
|
||||
type CookieAuthRequest struct {
|
||||
SessionKey string `json:"code" binding:"required"` // Using 'code' field as sessionKey (frontend sends it this way)
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
}
|
||||
|
||||
// CookieAuth performs OAuth using sessionKey (cookie-based auto-auth)
|
||||
// POST /api/v1/admin/accounts/cookie-auth
|
||||
func (h *OAuthHandler) CookieAuth(c *gin.Context) {
|
||||
var req CookieAuthRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
tokenInfo, err := h.oauthService.CookieAuth(c.Request.Context(), &service.CookieAuthInput{
|
||||
SessionKey: req.SessionKey,
|
||||
ProxyID: req.ProxyID,
|
||||
Scope: "full",
|
||||
})
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Cookie auth failed: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, tokenInfo)
|
||||
}
|
||||
|
||||
// SetupTokenCookieAuth performs OAuth using sessionKey for setup token (inference only)
|
||||
// POST /api/v1/admin/accounts/setup-token-cookie-auth
|
||||
func (h *OAuthHandler) SetupTokenCookieAuth(c *gin.Context) {
|
||||
var req CookieAuthRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
tokenInfo, err := h.oauthService.CookieAuth(c.Request.Context(), &service.CookieAuthInput{
|
||||
SessionKey: req.SessionKey,
|
||||
ProxyID: req.ProxyID,
|
||||
Scope: "inference",
|
||||
})
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Cookie auth failed: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, tokenInfo)
|
||||
}
|
||||
|
||||
// GetUsage handles getting account usage information
|
||||
// GET /api/v1/admin/accounts/:id/usage
|
||||
func (h *AccountHandler) GetUsage(c *gin.Context) {
|
||||
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid account ID")
|
||||
return
|
||||
}
|
||||
|
||||
usage, err := h.accountUsageService.GetUsage(c.Request.Context(), accountID)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to get usage: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, usage)
|
||||
}
|
||||
|
||||
// ClearRateLimit handles clearing account rate limit status
|
||||
// POST /api/v1/admin/accounts/:id/clear-rate-limit
|
||||
func (h *AccountHandler) ClearRateLimit(c *gin.Context) {
|
||||
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid account ID")
|
||||
return
|
||||
}
|
||||
|
||||
err = h.rateLimitService.ClearRateLimit(c.Request.Context(), accountID)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to clear rate limit: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"message": "Rate limit cleared successfully"})
|
||||
}
|
||||
|
||||
// GetTodayStats handles getting account today statistics
|
||||
// GET /api/v1/admin/accounts/:id/today-stats
|
||||
func (h *AccountHandler) GetTodayStats(c *gin.Context) {
|
||||
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid account ID")
|
||||
return
|
||||
}
|
||||
|
||||
stats, err := h.accountUsageService.GetTodayStats(c.Request.Context(), accountID)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to get today stats: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, stats)
|
||||
}
|
||||
|
||||
// SetSchedulableRequest represents the request body for setting schedulable status
|
||||
type SetSchedulableRequest struct {
|
||||
Schedulable bool `json:"schedulable"`
|
||||
}
|
||||
|
||||
// SetSchedulable handles toggling account schedulable status
|
||||
// POST /api/v1/admin/accounts/:id/schedulable
|
||||
func (h *AccountHandler) SetSchedulable(c *gin.Context) {
|
||||
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid account ID")
|
||||
return
|
||||
}
|
||||
|
||||
var req SetSchedulableRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
account, err := h.adminService.SetAccountSchedulable(c.Request.Context(), accountID, req.Schedulable)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to update schedulable status: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, account)
|
||||
}
|
||||
274
backend/internal/handler/admin/dashboard_handler.go
Normal file
274
backend/internal/handler/admin/dashboard_handler.go
Normal file
@@ -0,0 +1,274 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"sub2api/internal/pkg/response"
|
||||
"sub2api/internal/pkg/timezone"
|
||||
"sub2api/internal/repository"
|
||||
"sub2api/internal/service"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// DashboardHandler handles admin dashboard statistics
|
||||
type DashboardHandler struct {
|
||||
adminService service.AdminService
|
||||
usageRepo *repository.UsageLogRepository
|
||||
startTime time.Time // Server start time for uptime calculation
|
||||
}
|
||||
|
||||
// NewDashboardHandler creates a new admin dashboard handler
|
||||
func NewDashboardHandler(adminService service.AdminService, usageRepo *repository.UsageLogRepository) *DashboardHandler {
|
||||
return &DashboardHandler{
|
||||
adminService: adminService,
|
||||
usageRepo: usageRepo,
|
||||
startTime: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
// parseTimeRange parses start_date, end_date query parameters
|
||||
func parseTimeRange(c *gin.Context) (time.Time, time.Time) {
|
||||
now := timezone.Now()
|
||||
startDate := c.Query("start_date")
|
||||
endDate := c.Query("end_date")
|
||||
|
||||
var startTime, endTime time.Time
|
||||
|
||||
if startDate != "" {
|
||||
if t, err := timezone.ParseInLocation("2006-01-02", startDate); err == nil {
|
||||
startTime = t
|
||||
} else {
|
||||
startTime = timezone.StartOfDay(now.AddDate(0, 0, -7))
|
||||
}
|
||||
} else {
|
||||
startTime = timezone.StartOfDay(now.AddDate(0, 0, -7))
|
||||
}
|
||||
|
||||
if endDate != "" {
|
||||
if t, err := timezone.ParseInLocation("2006-01-02", endDate); err == nil {
|
||||
endTime = t.Add(24 * time.Hour) // Include the end date
|
||||
} else {
|
||||
endTime = timezone.StartOfDay(now.AddDate(0, 0, 1))
|
||||
}
|
||||
} else {
|
||||
endTime = timezone.StartOfDay(now.AddDate(0, 0, 1))
|
||||
}
|
||||
|
||||
return startTime, endTime
|
||||
}
|
||||
|
||||
// GetStats handles getting dashboard statistics
|
||||
// GET /api/v1/admin/dashboard/stats
|
||||
func (h *DashboardHandler) GetStats(c *gin.Context) {
|
||||
stats, err := h.usageRepo.GetDashboardStats(c.Request.Context())
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get dashboard statistics")
|
||||
return
|
||||
}
|
||||
|
||||
// Calculate uptime in seconds
|
||||
uptime := int64(time.Since(h.startTime).Seconds())
|
||||
|
||||
response.Success(c, gin.H{
|
||||
// 用户统计
|
||||
"total_users": stats.TotalUsers,
|
||||
"today_new_users": stats.TodayNewUsers,
|
||||
"active_users": stats.ActiveUsers,
|
||||
|
||||
// API Key 统计
|
||||
"total_api_keys": stats.TotalApiKeys,
|
||||
"active_api_keys": stats.ActiveApiKeys,
|
||||
|
||||
// 账户统计
|
||||
"total_accounts": stats.TotalAccounts,
|
||||
"normal_accounts": stats.NormalAccounts,
|
||||
"error_accounts": stats.ErrorAccounts,
|
||||
"ratelimit_accounts": stats.RateLimitAccounts,
|
||||
"overload_accounts": stats.OverloadAccounts,
|
||||
|
||||
// 累计 Token 使用统计
|
||||
"total_requests": stats.TotalRequests,
|
||||
"total_input_tokens": stats.TotalInputTokens,
|
||||
"total_output_tokens": stats.TotalOutputTokens,
|
||||
"total_cache_creation_tokens": stats.TotalCacheCreationTokens,
|
||||
"total_cache_read_tokens": stats.TotalCacheReadTokens,
|
||||
"total_tokens": stats.TotalTokens,
|
||||
"total_cost": stats.TotalCost, // 标准计费
|
||||
"total_actual_cost": stats.TotalActualCost, // 实际扣除
|
||||
|
||||
// 今日 Token 使用统计
|
||||
"today_requests": stats.TodayRequests,
|
||||
"today_input_tokens": stats.TodayInputTokens,
|
||||
"today_output_tokens": stats.TodayOutputTokens,
|
||||
"today_cache_creation_tokens": stats.TodayCacheCreationTokens,
|
||||
"today_cache_read_tokens": stats.TodayCacheReadTokens,
|
||||
"today_tokens": stats.TodayTokens,
|
||||
"today_cost": stats.TodayCost, // 今日标准计费
|
||||
"today_actual_cost": stats.TodayActualCost, // 今日实际扣除
|
||||
|
||||
// 系统运行统计
|
||||
"average_duration_ms": stats.AverageDurationMs,
|
||||
"uptime": uptime,
|
||||
})
|
||||
}
|
||||
|
||||
// GetRealtimeMetrics handles getting real-time system metrics
|
||||
// GET /api/v1/admin/dashboard/realtime
|
||||
func (h *DashboardHandler) GetRealtimeMetrics(c *gin.Context) {
|
||||
// Return mock data for now
|
||||
response.Success(c, gin.H{
|
||||
"active_requests": 0,
|
||||
"requests_per_minute": 0,
|
||||
"average_response_time": 0,
|
||||
"error_rate": 0.0,
|
||||
})
|
||||
}
|
||||
|
||||
// GetUsageTrend handles getting usage trend data
|
||||
// GET /api/v1/admin/dashboard/trend
|
||||
// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour)
|
||||
func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
|
||||
startTime, endTime := parseTimeRange(c)
|
||||
granularity := c.DefaultQuery("granularity", "day")
|
||||
|
||||
trend, err := h.usageRepo.GetUsageTrend(c.Request.Context(), startTime, endTime, granularity)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get usage trend")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"trend": trend,
|
||||
"start_date": startTime.Format("2006-01-02"),
|
||||
"end_date": endTime.Add(-24 * time.Hour).Format("2006-01-02"),
|
||||
"granularity": granularity,
|
||||
})
|
||||
}
|
||||
|
||||
// GetModelStats handles getting model usage statistics
|
||||
// GET /api/v1/admin/dashboard/models
|
||||
// Query params: start_date, end_date (YYYY-MM-DD)
|
||||
func (h *DashboardHandler) GetModelStats(c *gin.Context) {
|
||||
startTime, endTime := parseTimeRange(c)
|
||||
|
||||
stats, err := h.usageRepo.GetModelStats(c.Request.Context(), startTime, endTime)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get model statistics")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"models": stats,
|
||||
"start_date": startTime.Format("2006-01-02"),
|
||||
"end_date": endTime.Add(-24 * time.Hour).Format("2006-01-02"),
|
||||
})
|
||||
}
|
||||
|
||||
// GetApiKeyUsageTrend handles getting API key usage trend data
|
||||
// GET /api/v1/admin/dashboard/api-keys-trend
|
||||
// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), limit (default 5)
|
||||
func (h *DashboardHandler) GetApiKeyUsageTrend(c *gin.Context) {
|
||||
startTime, endTime := parseTimeRange(c)
|
||||
granularity := c.DefaultQuery("granularity", "day")
|
||||
limitStr := c.DefaultQuery("limit", "5")
|
||||
limit, err := strconv.Atoi(limitStr)
|
||||
if err != nil || limit <= 0 {
|
||||
limit = 5
|
||||
}
|
||||
|
||||
trend, err := h.usageRepo.GetApiKeyUsageTrend(c.Request.Context(), startTime, endTime, granularity, limit)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get API key usage trend")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"trend": trend,
|
||||
"start_date": startTime.Format("2006-01-02"),
|
||||
"end_date": endTime.Add(-24 * time.Hour).Format("2006-01-02"),
|
||||
"granularity": granularity,
|
||||
})
|
||||
}
|
||||
|
||||
// GetUserUsageTrend handles getting user usage trend data
|
||||
// GET /api/v1/admin/dashboard/users-trend
|
||||
// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), limit (default 12)
|
||||
func (h *DashboardHandler) GetUserUsageTrend(c *gin.Context) {
|
||||
startTime, endTime := parseTimeRange(c)
|
||||
granularity := c.DefaultQuery("granularity", "day")
|
||||
limitStr := c.DefaultQuery("limit", "12")
|
||||
limit, err := strconv.Atoi(limitStr)
|
||||
if err != nil || limit <= 0 {
|
||||
limit = 12
|
||||
}
|
||||
|
||||
trend, err := h.usageRepo.GetUserUsageTrend(c.Request.Context(), startTime, endTime, granularity, limit)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get user usage trend")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"trend": trend,
|
||||
"start_date": startTime.Format("2006-01-02"),
|
||||
"end_date": endTime.Add(-24 * time.Hour).Format("2006-01-02"),
|
||||
"granularity": granularity,
|
||||
})
|
||||
}
|
||||
|
||||
// BatchUsersUsageRequest represents the request body for batch user usage stats
|
||||
type BatchUsersUsageRequest struct {
|
||||
UserIDs []int64 `json:"user_ids" binding:"required"`
|
||||
}
|
||||
|
||||
// GetBatchUsersUsage handles getting usage stats for multiple users
|
||||
// POST /api/v1/admin/dashboard/users-usage
|
||||
func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) {
|
||||
var req BatchUsersUsageRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.UserIDs) == 0 {
|
||||
response.Success(c, gin.H{"stats": map[string]interface{}{}})
|
||||
return
|
||||
}
|
||||
|
||||
stats, err := h.usageRepo.GetBatchUserUsageStats(c.Request.Context(), req.UserIDs)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get user usage stats")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"stats": stats})
|
||||
}
|
||||
|
||||
// BatchApiKeysUsageRequest represents the request body for batch api key usage stats
|
||||
type BatchApiKeysUsageRequest struct {
|
||||
ApiKeyIDs []int64 `json:"api_key_ids" binding:"required"`
|
||||
}
|
||||
|
||||
// GetBatchApiKeysUsage handles getting usage stats for multiple API keys
|
||||
// POST /api/v1/admin/dashboard/api-keys-usage
|
||||
func (h *DashboardHandler) GetBatchApiKeysUsage(c *gin.Context) {
|
||||
var req BatchApiKeysUsageRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.ApiKeyIDs) == 0 {
|
||||
response.Success(c, gin.H{"stats": map[string]interface{}{}})
|
||||
return
|
||||
}
|
||||
|
||||
stats, err := h.usageRepo.GetBatchApiKeyUsageStats(c.Request.Context(), req.ApiKeyIDs)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get API key usage stats")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"stats": stats})
|
||||
}
|
||||
233
backend/internal/handler/admin/group_handler.go
Normal file
233
backend/internal/handler/admin/group_handler.go
Normal file
@@ -0,0 +1,233 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/response"
|
||||
"sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// GroupHandler handles admin group management
|
||||
type GroupHandler struct {
|
||||
adminService service.AdminService
|
||||
}
|
||||
|
||||
// NewGroupHandler creates a new admin group handler
|
||||
func NewGroupHandler(adminService service.AdminService) *GroupHandler {
|
||||
return &GroupHandler{
|
||||
adminService: adminService,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateGroupRequest represents create group request
|
||||
type CreateGroupRequest struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
Description string `json:"description"`
|
||||
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini"`
|
||||
RateMultiplier float64 `json:"rate_multiplier"`
|
||||
IsExclusive bool `json:"is_exclusive"`
|
||||
SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"`
|
||||
DailyLimitUSD *float64 `json:"daily_limit_usd"`
|
||||
WeeklyLimitUSD *float64 `json:"weekly_limit_usd"`
|
||||
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
|
||||
}
|
||||
|
||||
// UpdateGroupRequest represents update group request
|
||||
type UpdateGroupRequest struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini"`
|
||||
RateMultiplier *float64 `json:"rate_multiplier"`
|
||||
IsExclusive *bool `json:"is_exclusive"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
|
||||
SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"`
|
||||
DailyLimitUSD *float64 `json:"daily_limit_usd"`
|
||||
WeeklyLimitUSD *float64 `json:"weekly_limit_usd"`
|
||||
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
|
||||
}
|
||||
|
||||
// List handles listing all groups with pagination
|
||||
// GET /api/v1/admin/groups
|
||||
func (h *GroupHandler) List(c *gin.Context) {
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
platform := c.Query("platform")
|
||||
status := c.Query("status")
|
||||
isExclusiveStr := c.Query("is_exclusive")
|
||||
|
||||
var isExclusive *bool
|
||||
if isExclusiveStr != "" {
|
||||
val := isExclusiveStr == "true"
|
||||
isExclusive = &val
|
||||
}
|
||||
|
||||
groups, total, err := h.adminService.ListGroups(c.Request.Context(), page, pageSize, platform, status, isExclusive)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to list groups: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Paginated(c, groups, total, page, pageSize)
|
||||
}
|
||||
|
||||
// GetAll handles getting all active groups without pagination
|
||||
// GET /api/v1/admin/groups/all
|
||||
func (h *GroupHandler) GetAll(c *gin.Context) {
|
||||
platform := c.Query("platform")
|
||||
|
||||
var groups []model.Group
|
||||
var err error
|
||||
|
||||
if platform != "" {
|
||||
groups, err = h.adminService.GetAllGroupsByPlatform(c.Request.Context(), platform)
|
||||
} else {
|
||||
groups, err = h.adminService.GetAllGroups(c.Request.Context())
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to get groups: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, groups)
|
||||
}
|
||||
|
||||
// GetByID handles getting a group by ID
|
||||
// GET /api/v1/admin/groups/:id
|
||||
func (h *GroupHandler) GetByID(c *gin.Context) {
|
||||
groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid group ID")
|
||||
return
|
||||
}
|
||||
|
||||
group, err := h.adminService.GetGroup(c.Request.Context(), groupID)
|
||||
if err != nil {
|
||||
response.NotFound(c, "Group not found")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, group)
|
||||
}
|
||||
|
||||
// Create handles creating a new group
|
||||
// POST /api/v1/admin/groups
|
||||
func (h *GroupHandler) Create(c *gin.Context) {
|
||||
var req CreateGroupRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
group, err := h.adminService.CreateGroup(c.Request.Context(), &service.CreateGroupInput{
|
||||
Name: req.Name,
|
||||
Description: req.Description,
|
||||
Platform: req.Platform,
|
||||
RateMultiplier: req.RateMultiplier,
|
||||
IsExclusive: req.IsExclusive,
|
||||
SubscriptionType: req.SubscriptionType,
|
||||
DailyLimitUSD: req.DailyLimitUSD,
|
||||
WeeklyLimitUSD: req.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: req.MonthlyLimitUSD,
|
||||
})
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Failed to create group: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, group)
|
||||
}
|
||||
|
||||
// Update handles updating a group
|
||||
// PUT /api/v1/admin/groups/:id
|
||||
func (h *GroupHandler) Update(c *gin.Context) {
|
||||
groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid group ID")
|
||||
return
|
||||
}
|
||||
|
||||
var req UpdateGroupRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
group, err := h.adminService.UpdateGroup(c.Request.Context(), groupID, &service.UpdateGroupInput{
|
||||
Name: req.Name,
|
||||
Description: req.Description,
|
||||
Platform: req.Platform,
|
||||
RateMultiplier: req.RateMultiplier,
|
||||
IsExclusive: req.IsExclusive,
|
||||
Status: req.Status,
|
||||
SubscriptionType: req.SubscriptionType,
|
||||
DailyLimitUSD: req.DailyLimitUSD,
|
||||
WeeklyLimitUSD: req.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: req.MonthlyLimitUSD,
|
||||
})
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to update group: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, group)
|
||||
}
|
||||
|
||||
// Delete handles deleting a group
|
||||
// DELETE /api/v1/admin/groups/:id
|
||||
func (h *GroupHandler) Delete(c *gin.Context) {
|
||||
groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid group ID")
|
||||
return
|
||||
}
|
||||
|
||||
err = h.adminService.DeleteGroup(c.Request.Context(), groupID)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to delete group: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"message": "Group deleted successfully"})
|
||||
}
|
||||
|
||||
// GetStats handles getting group statistics
|
||||
// GET /api/v1/admin/groups/:id/stats
|
||||
func (h *GroupHandler) GetStats(c *gin.Context) {
|
||||
groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid group ID")
|
||||
return
|
||||
}
|
||||
|
||||
// Return mock data for now
|
||||
response.Success(c, gin.H{
|
||||
"total_api_keys": 0,
|
||||
"active_api_keys": 0,
|
||||
"total_requests": 0,
|
||||
"total_cost": 0.0,
|
||||
})
|
||||
_ = groupID // TODO: implement actual stats
|
||||
}
|
||||
|
||||
// GetGroupAPIKeys handles getting API keys in a group
|
||||
// GET /api/v1/admin/groups/:id/api-keys
|
||||
func (h *GroupHandler) GetGroupAPIKeys(c *gin.Context) {
|
||||
groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid group ID")
|
||||
return
|
||||
}
|
||||
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
|
||||
keys, total, err := h.adminService.GetGroupAPIKeys(c.Request.Context(), groupID, page, pageSize)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to get group API keys: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Paginated(c, keys, total, page, pageSize)
|
||||
}
|
||||
300
backend/internal/handler/admin/proxy_handler.go
Normal file
300
backend/internal/handler/admin/proxy_handler.go
Normal file
@@ -0,0 +1,300 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"sub2api/internal/pkg/response"
|
||||
"sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// ProxyHandler handles admin proxy management
|
||||
type ProxyHandler struct {
|
||||
adminService service.AdminService
|
||||
}
|
||||
|
||||
// NewProxyHandler creates a new admin proxy handler
|
||||
func NewProxyHandler(adminService service.AdminService) *ProxyHandler {
|
||||
return &ProxyHandler{
|
||||
adminService: adminService,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateProxyRequest represents create proxy request
|
||||
type CreateProxyRequest struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
Protocol string `json:"protocol" binding:"required,oneof=http https socks5"`
|
||||
Host string `json:"host" binding:"required"`
|
||||
Port int `json:"port" binding:"required,min=1,max=65535"`
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
// UpdateProxyRequest represents update proxy request
|
||||
type UpdateProxyRequest struct {
|
||||
Name string `json:"name"`
|
||||
Protocol string `json:"protocol" binding:"omitempty,oneof=http https socks5"`
|
||||
Host string `json:"host"`
|
||||
Port int `json:"port" binding:"omitempty,min=1,max=65535"`
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
|
||||
}
|
||||
|
||||
// List handles listing all proxies with pagination
|
||||
// GET /api/v1/admin/proxies
|
||||
func (h *ProxyHandler) List(c *gin.Context) {
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
protocol := c.Query("protocol")
|
||||
status := c.Query("status")
|
||||
search := c.Query("search")
|
||||
|
||||
proxies, total, err := h.adminService.ListProxies(c.Request.Context(), page, pageSize, protocol, status, search)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to list proxies: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Paginated(c, proxies, total, page, pageSize)
|
||||
}
|
||||
|
||||
// GetAll handles getting all active proxies without pagination
|
||||
// GET /api/v1/admin/proxies/all
|
||||
// Optional query param: with_count=true to include account count per proxy
|
||||
func (h *ProxyHandler) GetAll(c *gin.Context) {
|
||||
withCount := c.Query("with_count") == "true"
|
||||
|
||||
if withCount {
|
||||
proxies, err := h.adminService.GetAllProxiesWithAccountCount(c.Request.Context())
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to get proxies: "+err.Error())
|
||||
return
|
||||
}
|
||||
response.Success(c, proxies)
|
||||
return
|
||||
}
|
||||
|
||||
proxies, err := h.adminService.GetAllProxies(c.Request.Context())
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to get proxies: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, proxies)
|
||||
}
|
||||
|
||||
// GetByID handles getting a proxy by ID
|
||||
// GET /api/v1/admin/proxies/:id
|
||||
func (h *ProxyHandler) GetByID(c *gin.Context) {
|
||||
proxyID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid proxy ID")
|
||||
return
|
||||
}
|
||||
|
||||
proxy, err := h.adminService.GetProxy(c.Request.Context(), proxyID)
|
||||
if err != nil {
|
||||
response.NotFound(c, "Proxy not found")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, proxy)
|
||||
}
|
||||
|
||||
// Create handles creating a new proxy
|
||||
// POST /api/v1/admin/proxies
|
||||
func (h *ProxyHandler) Create(c *gin.Context) {
|
||||
var req CreateProxyRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
proxy, err := h.adminService.CreateProxy(c.Request.Context(), &service.CreateProxyInput{
|
||||
Name: req.Name,
|
||||
Protocol: req.Protocol,
|
||||
Host: req.Host,
|
||||
Port: req.Port,
|
||||
Username: req.Username,
|
||||
Password: req.Password,
|
||||
})
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Failed to create proxy: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, proxy)
|
||||
}
|
||||
|
||||
// Update handles updating a proxy
|
||||
// PUT /api/v1/admin/proxies/:id
|
||||
func (h *ProxyHandler) Update(c *gin.Context) {
|
||||
proxyID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid proxy ID")
|
||||
return
|
||||
}
|
||||
|
||||
var req UpdateProxyRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
proxy, err := h.adminService.UpdateProxy(c.Request.Context(), proxyID, &service.UpdateProxyInput{
|
||||
Name: req.Name,
|
||||
Protocol: req.Protocol,
|
||||
Host: req.Host,
|
||||
Port: req.Port,
|
||||
Username: req.Username,
|
||||
Password: req.Password,
|
||||
Status: req.Status,
|
||||
})
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to update proxy: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, proxy)
|
||||
}
|
||||
|
||||
// Delete handles deleting a proxy
|
||||
// DELETE /api/v1/admin/proxies/:id
|
||||
func (h *ProxyHandler) Delete(c *gin.Context) {
|
||||
proxyID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid proxy ID")
|
||||
return
|
||||
}
|
||||
|
||||
err = h.adminService.DeleteProxy(c.Request.Context(), proxyID)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to delete proxy: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"message": "Proxy deleted successfully"})
|
||||
}
|
||||
|
||||
// Test handles testing proxy connectivity
|
||||
// POST /api/v1/admin/proxies/:id/test
|
||||
func (h *ProxyHandler) Test(c *gin.Context) {
|
||||
proxyID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid proxy ID")
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.adminService.TestProxy(c.Request.Context(), proxyID)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to test proxy: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
// GetStats handles getting proxy statistics
|
||||
// GET /api/v1/admin/proxies/:id/stats
|
||||
func (h *ProxyHandler) GetStats(c *gin.Context) {
|
||||
proxyID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid proxy ID")
|
||||
return
|
||||
}
|
||||
|
||||
// Return mock data for now
|
||||
_ = proxyID
|
||||
response.Success(c, gin.H{
|
||||
"total_accounts": 0,
|
||||
"active_accounts": 0,
|
||||
"total_requests": 0,
|
||||
"success_rate": 100.0,
|
||||
"average_latency": 0,
|
||||
})
|
||||
}
|
||||
|
||||
// GetProxyAccounts handles getting accounts using a proxy
|
||||
// GET /api/v1/admin/proxies/:id/accounts
|
||||
func (h *ProxyHandler) GetProxyAccounts(c *gin.Context) {
|
||||
proxyID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid proxy ID")
|
||||
return
|
||||
}
|
||||
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
|
||||
accounts, total, err := h.adminService.GetProxyAccounts(c.Request.Context(), proxyID, page, pageSize)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to get proxy accounts: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Paginated(c, accounts, total, page, pageSize)
|
||||
}
|
||||
|
||||
|
||||
// BatchCreateProxyItem represents a single proxy in batch create request
|
||||
type BatchCreateProxyItem struct {
|
||||
Protocol string `json:"protocol" binding:"required,oneof=http https socks5"`
|
||||
Host string `json:"host" binding:"required"`
|
||||
Port int `json:"port" binding:"required,min=1,max=65535"`
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
// BatchCreateRequest represents batch create proxies request
|
||||
type BatchCreateRequest struct {
|
||||
Proxies []BatchCreateProxyItem `json:"proxies" binding:"required,min=1"`
|
||||
}
|
||||
|
||||
// BatchCreate handles batch creating proxies
|
||||
// POST /api/v1/admin/proxies/batch
|
||||
func (h *ProxyHandler) BatchCreate(c *gin.Context) {
|
||||
var req BatchCreateRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
created := 0
|
||||
skipped := 0
|
||||
|
||||
for _, item := range req.Proxies {
|
||||
// Check for duplicates (same host, port, username, password)
|
||||
exists, err := h.adminService.CheckProxyExists(c.Request.Context(), item.Host, item.Port, item.Username, item.Password)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to check proxy existence: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if exists {
|
||||
skipped++
|
||||
continue
|
||||
}
|
||||
|
||||
// Create proxy with default name
|
||||
_, err = h.adminService.CreateProxy(c.Request.Context(), &service.CreateProxyInput{
|
||||
Name: "default",
|
||||
Protocol: item.Protocol,
|
||||
Host: item.Host,
|
||||
Port: item.Port,
|
||||
Username: item.Username,
|
||||
Password: item.Password,
|
||||
})
|
||||
if err != nil {
|
||||
// If creation fails due to duplicate, count as skipped
|
||||
skipped++
|
||||
continue
|
||||
}
|
||||
|
||||
created++
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"created": created,
|
||||
"skipped": skipped,
|
||||
})
|
||||
}
|
||||
219
backend/internal/handler/admin/redeem_handler.go
Normal file
219
backend/internal/handler/admin/redeem_handler.go
Normal file
@@ -0,0 +1,219 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/csv"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"sub2api/internal/pkg/response"
|
||||
"sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// RedeemHandler handles admin redeem code management
|
||||
type RedeemHandler struct {
|
||||
adminService service.AdminService
|
||||
}
|
||||
|
||||
// NewRedeemHandler creates a new admin redeem handler
|
||||
func NewRedeemHandler(adminService service.AdminService) *RedeemHandler {
|
||||
return &RedeemHandler{
|
||||
adminService: adminService,
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateRedeemCodesRequest represents generate redeem codes request
|
||||
type GenerateRedeemCodesRequest struct {
|
||||
Count int `json:"count" binding:"required,min=1,max=100"`
|
||||
Type string `json:"type" binding:"required,oneof=balance concurrency subscription"`
|
||||
Value float64 `json:"value" binding:"min=0"`
|
||||
GroupID *int64 `json:"group_id"` // 订阅类型必填
|
||||
ValidityDays int `json:"validity_days"` // 订阅类型使用,默认30天
|
||||
}
|
||||
|
||||
// List handles listing all redeem codes with pagination
|
||||
// GET /api/v1/admin/redeem-codes
|
||||
func (h *RedeemHandler) List(c *gin.Context) {
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
codeType := c.Query("type")
|
||||
status := c.Query("status")
|
||||
search := c.Query("search")
|
||||
|
||||
codes, total, err := h.adminService.ListRedeemCodes(c.Request.Context(), page, pageSize, codeType, status, search)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to list redeem codes: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Paginated(c, codes, total, page, pageSize)
|
||||
}
|
||||
|
||||
// GetByID handles getting a redeem code by ID
|
||||
// GET /api/v1/admin/redeem-codes/:id
|
||||
func (h *RedeemHandler) GetByID(c *gin.Context) {
|
||||
codeID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid redeem code ID")
|
||||
return
|
||||
}
|
||||
|
||||
code, err := h.adminService.GetRedeemCode(c.Request.Context(), codeID)
|
||||
if err != nil {
|
||||
response.NotFound(c, "Redeem code not found")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, code)
|
||||
}
|
||||
|
||||
// Generate handles generating new redeem codes
|
||||
// POST /api/v1/admin/redeem-codes/generate
|
||||
func (h *RedeemHandler) Generate(c *gin.Context) {
|
||||
var req GenerateRedeemCodesRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
codes, err := h.adminService.GenerateRedeemCodes(c.Request.Context(), &service.GenerateRedeemCodesInput{
|
||||
Count: req.Count,
|
||||
Type: req.Type,
|
||||
Value: req.Value,
|
||||
GroupID: req.GroupID,
|
||||
ValidityDays: req.ValidityDays,
|
||||
})
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to generate redeem codes: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, codes)
|
||||
}
|
||||
|
||||
// Delete handles deleting a redeem code
|
||||
// DELETE /api/v1/admin/redeem-codes/:id
|
||||
func (h *RedeemHandler) Delete(c *gin.Context) {
|
||||
codeID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid redeem code ID")
|
||||
return
|
||||
}
|
||||
|
||||
err = h.adminService.DeleteRedeemCode(c.Request.Context(), codeID)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to delete redeem code: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"message": "Redeem code deleted successfully"})
|
||||
}
|
||||
|
||||
// BatchDelete handles batch deleting redeem codes
|
||||
// POST /api/v1/admin/redeem-codes/batch-delete
|
||||
func (h *RedeemHandler) BatchDelete(c *gin.Context) {
|
||||
var req struct {
|
||||
IDs []int64 `json:"ids" binding:"required,min=1"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
deleted, err := h.adminService.BatchDeleteRedeemCodes(c.Request.Context(), req.IDs)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to batch delete redeem codes: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"deleted": deleted,
|
||||
"message": "Redeem codes deleted successfully",
|
||||
})
|
||||
}
|
||||
|
||||
// Expire handles expiring a redeem code
|
||||
// POST /api/v1/admin/redeem-codes/:id/expire
|
||||
func (h *RedeemHandler) Expire(c *gin.Context) {
|
||||
codeID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid redeem code ID")
|
||||
return
|
||||
}
|
||||
|
||||
code, err := h.adminService.ExpireRedeemCode(c.Request.Context(), codeID)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to expire redeem code: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, code)
|
||||
}
|
||||
|
||||
// GetStats handles getting redeem code statistics
|
||||
// GET /api/v1/admin/redeem-codes/stats
|
||||
func (h *RedeemHandler) GetStats(c *gin.Context) {
|
||||
// Return mock data for now
|
||||
response.Success(c, gin.H{
|
||||
"total_codes": 0,
|
||||
"active_codes": 0,
|
||||
"used_codes": 0,
|
||||
"expired_codes": 0,
|
||||
"total_value_distributed": 0.0,
|
||||
"by_type": gin.H{
|
||||
"balance": 0,
|
||||
"concurrency": 0,
|
||||
"trial": 0,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// Export handles exporting redeem codes to CSV
|
||||
// GET /api/v1/admin/redeem-codes/export
|
||||
func (h *RedeemHandler) Export(c *gin.Context) {
|
||||
codeType := c.Query("type")
|
||||
status := c.Query("status")
|
||||
|
||||
// Get all codes without pagination (use large page size)
|
||||
codes, _, err := h.adminService.ListRedeemCodes(c.Request.Context(), 1, 10000, codeType, status, "")
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to export redeem codes: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Create CSV buffer
|
||||
var buf bytes.Buffer
|
||||
writer := csv.NewWriter(&buf)
|
||||
|
||||
// Write header
|
||||
writer.Write([]string{"id", "code", "type", "value", "status", "used_by", "used_at", "created_at"})
|
||||
|
||||
// Write data rows
|
||||
for _, code := range codes {
|
||||
usedBy := ""
|
||||
if code.UsedBy != nil {
|
||||
usedBy = fmt.Sprintf("%d", *code.UsedBy)
|
||||
}
|
||||
usedAt := ""
|
||||
if code.UsedAt != nil {
|
||||
usedAt = code.UsedAt.Format("2006-01-02 15:04:05")
|
||||
}
|
||||
writer.Write([]string{
|
||||
fmt.Sprintf("%d", code.ID),
|
||||
code.Code,
|
||||
code.Type,
|
||||
fmt.Sprintf("%.2f", code.Value),
|
||||
code.Status,
|
||||
usedBy,
|
||||
usedAt,
|
||||
code.CreatedAt.Format("2006-01-02 15:04:05"),
|
||||
})
|
||||
}
|
||||
|
||||
writer.Flush()
|
||||
|
||||
c.Header("Content-Type", "text/csv")
|
||||
c.Header("Content-Disposition", "attachment; filename=redeem_codes.csv")
|
||||
c.Data(200, "text/csv", buf.Bytes())
|
||||
}
|
||||
258
backend/internal/handler/admin/setting_handler.go
Normal file
258
backend/internal/handler/admin/setting_handler.go
Normal file
@@ -0,0 +1,258 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/response"
|
||||
"sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// SettingHandler 系统设置处理器
|
||||
type SettingHandler struct {
|
||||
settingService *service.SettingService
|
||||
emailService *service.EmailService
|
||||
}
|
||||
|
||||
// NewSettingHandler 创建系统设置处理器
|
||||
func NewSettingHandler(settingService *service.SettingService, emailService *service.EmailService) *SettingHandler {
|
||||
return &SettingHandler{
|
||||
settingService: settingService,
|
||||
emailService: emailService,
|
||||
}
|
||||
}
|
||||
|
||||
// GetSettings 获取所有系统设置
|
||||
// GET /api/v1/admin/settings
|
||||
func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
settings, err := h.settingService.GetAllSettings(c.Request.Context())
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to get settings: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, settings)
|
||||
}
|
||||
|
||||
// UpdateSettingsRequest 更新设置请求
|
||||
type UpdateSettingsRequest struct {
|
||||
// 注册设置
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
|
||||
// 邮件服务设置
|
||||
SmtpHost string `json:"smtp_host"`
|
||||
SmtpPort int `json:"smtp_port"`
|
||||
SmtpUsername string `json:"smtp_username"`
|
||||
SmtpPassword string `json:"smtp_password"`
|
||||
SmtpFrom string `json:"smtp_from_email"`
|
||||
SmtpFromName string `json:"smtp_from_name"`
|
||||
SmtpUseTLS bool `json:"smtp_use_tls"`
|
||||
|
||||
// Cloudflare Turnstile 设置
|
||||
TurnstileEnabled bool `json:"turnstile_enabled"`
|
||||
TurnstileSiteKey string `json:"turnstile_site_key"`
|
||||
TurnstileSecretKey string `json:"turnstile_secret_key"`
|
||||
|
||||
// OEM设置
|
||||
SiteName string `json:"site_name"`
|
||||
SiteLogo string `json:"site_logo"`
|
||||
SiteSubtitle string `json:"site_subtitle"`
|
||||
ApiBaseUrl string `json:"api_base_url"`
|
||||
ContactInfo string `json:"contact_info"`
|
||||
|
||||
// 默认配置
|
||||
DefaultConcurrency int `json:"default_concurrency"`
|
||||
DefaultBalance float64 `json:"default_balance"`
|
||||
}
|
||||
|
||||
// UpdateSettings 更新系统设置
|
||||
// PUT /api/v1/admin/settings
|
||||
func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
var req UpdateSettingsRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 验证参数
|
||||
if req.DefaultConcurrency < 1 {
|
||||
req.DefaultConcurrency = 1
|
||||
}
|
||||
if req.DefaultBalance < 0 {
|
||||
req.DefaultBalance = 0
|
||||
}
|
||||
if req.SmtpPort <= 0 {
|
||||
req.SmtpPort = 587
|
||||
}
|
||||
|
||||
settings := &model.SystemSettings{
|
||||
RegistrationEnabled: req.RegistrationEnabled,
|
||||
EmailVerifyEnabled: req.EmailVerifyEnabled,
|
||||
SmtpHost: req.SmtpHost,
|
||||
SmtpPort: req.SmtpPort,
|
||||
SmtpUsername: req.SmtpUsername,
|
||||
SmtpPassword: req.SmtpPassword,
|
||||
SmtpFrom: req.SmtpFrom,
|
||||
SmtpFromName: req.SmtpFromName,
|
||||
SmtpUseTLS: req.SmtpUseTLS,
|
||||
TurnstileEnabled: req.TurnstileEnabled,
|
||||
TurnstileSiteKey: req.TurnstileSiteKey,
|
||||
TurnstileSecretKey: req.TurnstileSecretKey,
|
||||
SiteName: req.SiteName,
|
||||
SiteLogo: req.SiteLogo,
|
||||
SiteSubtitle: req.SiteSubtitle,
|
||||
ApiBaseUrl: req.ApiBaseUrl,
|
||||
ContactInfo: req.ContactInfo,
|
||||
DefaultConcurrency: req.DefaultConcurrency,
|
||||
DefaultBalance: req.DefaultBalance,
|
||||
}
|
||||
|
||||
if err := h.settingService.UpdateSettings(c.Request.Context(), settings); err != nil {
|
||||
response.InternalError(c, "Failed to update settings: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 重新获取设置返回
|
||||
updatedSettings, err := h.settingService.GetAllSettings(c.Request.Context())
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to get updated settings: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, updatedSettings)
|
||||
}
|
||||
|
||||
// TestSmtpRequest 测试SMTP连接请求
|
||||
type TestSmtpRequest struct {
|
||||
SmtpHost string `json:"smtp_host" binding:"required"`
|
||||
SmtpPort int `json:"smtp_port"`
|
||||
SmtpUsername string `json:"smtp_username"`
|
||||
SmtpPassword string `json:"smtp_password"`
|
||||
SmtpUseTLS bool `json:"smtp_use_tls"`
|
||||
}
|
||||
|
||||
// TestSmtpConnection 测试SMTP连接
|
||||
// POST /api/v1/admin/settings/test-smtp
|
||||
func (h *SettingHandler) TestSmtpConnection(c *gin.Context) {
|
||||
var req TestSmtpRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if req.SmtpPort <= 0 {
|
||||
req.SmtpPort = 587
|
||||
}
|
||||
|
||||
// 如果未提供密码,从数据库获取已保存的密码
|
||||
password := req.SmtpPassword
|
||||
if password == "" {
|
||||
savedConfig, err := h.emailService.GetSmtpConfig(c.Request.Context())
|
||||
if err == nil && savedConfig != nil {
|
||||
password = savedConfig.Password
|
||||
}
|
||||
}
|
||||
|
||||
config := &service.SmtpConfig{
|
||||
Host: req.SmtpHost,
|
||||
Port: req.SmtpPort,
|
||||
Username: req.SmtpUsername,
|
||||
Password: password,
|
||||
UseTLS: req.SmtpUseTLS,
|
||||
}
|
||||
|
||||
err := h.emailService.TestSmtpConnectionWithConfig(config)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "SMTP connection test failed: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"message": "SMTP connection successful"})
|
||||
}
|
||||
|
||||
// SendTestEmailRequest 发送测试邮件请求
|
||||
type SendTestEmailRequest struct {
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
SmtpHost string `json:"smtp_host" binding:"required"`
|
||||
SmtpPort int `json:"smtp_port"`
|
||||
SmtpUsername string `json:"smtp_username"`
|
||||
SmtpPassword string `json:"smtp_password"`
|
||||
SmtpFrom string `json:"smtp_from_email"`
|
||||
SmtpFromName string `json:"smtp_from_name"`
|
||||
SmtpUseTLS bool `json:"smtp_use_tls"`
|
||||
}
|
||||
|
||||
// SendTestEmail 发送测试邮件
|
||||
// POST /api/v1/admin/settings/send-test-email
|
||||
func (h *SettingHandler) SendTestEmail(c *gin.Context) {
|
||||
var req SendTestEmailRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if req.SmtpPort <= 0 {
|
||||
req.SmtpPort = 587
|
||||
}
|
||||
|
||||
// 如果未提供密码,从数据库获取已保存的密码
|
||||
password := req.SmtpPassword
|
||||
if password == "" {
|
||||
savedConfig, err := h.emailService.GetSmtpConfig(c.Request.Context())
|
||||
if err == nil && savedConfig != nil {
|
||||
password = savedConfig.Password
|
||||
}
|
||||
}
|
||||
|
||||
config := &service.SmtpConfig{
|
||||
Host: req.SmtpHost,
|
||||
Port: req.SmtpPort,
|
||||
Username: req.SmtpUsername,
|
||||
Password: password,
|
||||
From: req.SmtpFrom,
|
||||
FromName: req.SmtpFromName,
|
||||
UseTLS: req.SmtpUseTLS,
|
||||
}
|
||||
|
||||
siteName := h.settingService.GetSiteName(c.Request.Context())
|
||||
subject := "[" + siteName + "] Test Email"
|
||||
body := `
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<style>
|
||||
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; background-color: #f5f5f5; margin: 0; padding: 20px; }
|
||||
.container { max-width: 600px; margin: 0 auto; background-color: #ffffff; border-radius: 8px; overflow: hidden; box-shadow: 0 2px 8px rgba(0,0,0,0.1); }
|
||||
.header { background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); color: white; padding: 30px; text-align: center; }
|
||||
.content { padding: 40px 30px; text-align: center; }
|
||||
.success { color: #10b981; font-size: 48px; margin-bottom: 20px; }
|
||||
.footer { background-color: #f8f9fa; padding: 20px; text-align: center; color: #999; font-size: 12px; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<div class="header">
|
||||
<h1>` + siteName + `</h1>
|
||||
</div>
|
||||
<div class="content">
|
||||
<div class="success">✓</div>
|
||||
<h2>Email Configuration Successful!</h2>
|
||||
<p>This is a test email to verify your SMTP settings are working correctly.</p>
|
||||
</div>
|
||||
<div class="footer">
|
||||
<p>This is an automated test message.</p>
|
||||
</div>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
`
|
||||
|
||||
if err := h.emailService.SendEmailWithConfig(config, req.Email, subject, body); err != nil {
|
||||
response.BadRequest(c, "Failed to send test email: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"message": "Test email sent successfully"})
|
||||
}
|
||||
266
backend/internal/handler/admin/subscription_handler.go
Normal file
266
backend/internal/handler/admin/subscription_handler.go
Normal file
@@ -0,0 +1,266 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/response"
|
||||
"sub2api/internal/repository"
|
||||
"sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// toResponsePagination converts repository.PaginationResult to response.PaginationResult
|
||||
func toResponsePagination(p *repository.PaginationResult) *response.PaginationResult {
|
||||
if p == nil {
|
||||
return nil
|
||||
}
|
||||
return &response.PaginationResult{
|
||||
Total: p.Total,
|
||||
Page: p.Page,
|
||||
PageSize: p.PageSize,
|
||||
Pages: p.Pages,
|
||||
}
|
||||
}
|
||||
|
||||
// SubscriptionHandler handles admin subscription management
|
||||
type SubscriptionHandler struct {
|
||||
subscriptionService *service.SubscriptionService
|
||||
}
|
||||
|
||||
// NewSubscriptionHandler creates a new admin subscription handler
|
||||
func NewSubscriptionHandler(subscriptionService *service.SubscriptionService) *SubscriptionHandler {
|
||||
return &SubscriptionHandler{
|
||||
subscriptionService: subscriptionService,
|
||||
}
|
||||
}
|
||||
|
||||
// AssignSubscriptionRequest represents assign subscription request
|
||||
type AssignSubscriptionRequest struct {
|
||||
UserID int64 `json:"user_id" binding:"required"`
|
||||
GroupID int64 `json:"group_id" binding:"required"`
|
||||
ValidityDays int `json:"validity_days"`
|
||||
Notes string `json:"notes"`
|
||||
}
|
||||
|
||||
// BulkAssignSubscriptionRequest represents bulk assign subscription request
|
||||
type BulkAssignSubscriptionRequest struct {
|
||||
UserIDs []int64 `json:"user_ids" binding:"required,min=1"`
|
||||
GroupID int64 `json:"group_id" binding:"required"`
|
||||
ValidityDays int `json:"validity_days"`
|
||||
Notes string `json:"notes"`
|
||||
}
|
||||
|
||||
// ExtendSubscriptionRequest represents extend subscription request
|
||||
type ExtendSubscriptionRequest struct {
|
||||
Days int `json:"days" binding:"required,min=1"`
|
||||
}
|
||||
|
||||
// List handles listing all subscriptions with pagination and filters
|
||||
// GET /api/v1/admin/subscriptions
|
||||
func (h *SubscriptionHandler) List(c *gin.Context) {
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
|
||||
// Parse optional filters
|
||||
var userID, groupID *int64
|
||||
if userIDStr := c.Query("user_id"); userIDStr != "" {
|
||||
if id, err := strconv.ParseInt(userIDStr, 10, 64); err == nil {
|
||||
userID = &id
|
||||
}
|
||||
}
|
||||
if groupIDStr := c.Query("group_id"); groupIDStr != "" {
|
||||
if id, err := strconv.ParseInt(groupIDStr, 10, 64); err == nil {
|
||||
groupID = &id
|
||||
}
|
||||
}
|
||||
status := c.Query("status")
|
||||
|
||||
subscriptions, pagination, err := h.subscriptionService.List(c.Request.Context(), page, pageSize, userID, groupID, status)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to list subscriptions: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.PaginatedWithResult(c, subscriptions, toResponsePagination(pagination))
|
||||
}
|
||||
|
||||
// GetByID handles getting a subscription by ID
|
||||
// GET /api/v1/admin/subscriptions/:id
|
||||
func (h *SubscriptionHandler) GetByID(c *gin.Context) {
|
||||
subscriptionID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid subscription ID")
|
||||
return
|
||||
}
|
||||
|
||||
subscription, err := h.subscriptionService.GetByID(c.Request.Context(), subscriptionID)
|
||||
if err != nil {
|
||||
response.NotFound(c, "Subscription not found")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, subscription)
|
||||
}
|
||||
|
||||
// GetProgress handles getting subscription usage progress
|
||||
// GET /api/v1/admin/subscriptions/:id/progress
|
||||
func (h *SubscriptionHandler) GetProgress(c *gin.Context) {
|
||||
subscriptionID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid subscription ID")
|
||||
return
|
||||
}
|
||||
|
||||
progress, err := h.subscriptionService.GetSubscriptionProgress(c.Request.Context(), subscriptionID)
|
||||
if err != nil {
|
||||
response.NotFound(c, "Subscription not found")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, progress)
|
||||
}
|
||||
|
||||
// Assign handles assigning a subscription to a user
|
||||
// POST /api/v1/admin/subscriptions/assign
|
||||
func (h *SubscriptionHandler) Assign(c *gin.Context) {
|
||||
var req AssignSubscriptionRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Get admin user ID from context
|
||||
adminID := getAdminIDFromContext(c)
|
||||
|
||||
subscription, err := h.subscriptionService.AssignSubscription(c.Request.Context(), &service.AssignSubscriptionInput{
|
||||
UserID: req.UserID,
|
||||
GroupID: req.GroupID,
|
||||
ValidityDays: req.ValidityDays,
|
||||
AssignedBy: adminID,
|
||||
Notes: req.Notes,
|
||||
})
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Failed to assign subscription: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, subscription)
|
||||
}
|
||||
|
||||
// BulkAssign handles bulk assigning subscriptions to multiple users
|
||||
// POST /api/v1/admin/subscriptions/bulk-assign
|
||||
func (h *SubscriptionHandler) BulkAssign(c *gin.Context) {
|
||||
var req BulkAssignSubscriptionRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Get admin user ID from context
|
||||
adminID := getAdminIDFromContext(c)
|
||||
|
||||
result, err := h.subscriptionService.BulkAssignSubscription(c.Request.Context(), &service.BulkAssignSubscriptionInput{
|
||||
UserIDs: req.UserIDs,
|
||||
GroupID: req.GroupID,
|
||||
ValidityDays: req.ValidityDays,
|
||||
AssignedBy: adminID,
|
||||
Notes: req.Notes,
|
||||
})
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to bulk assign subscriptions: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
// Extend handles extending a subscription
|
||||
// POST /api/v1/admin/subscriptions/:id/extend
|
||||
func (h *SubscriptionHandler) Extend(c *gin.Context) {
|
||||
subscriptionID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid subscription ID")
|
||||
return
|
||||
}
|
||||
|
||||
var req ExtendSubscriptionRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
subscription, err := h.subscriptionService.ExtendSubscription(c.Request.Context(), subscriptionID, req.Days)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to extend subscription: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, subscription)
|
||||
}
|
||||
|
||||
// Revoke handles revoking a subscription
|
||||
// DELETE /api/v1/admin/subscriptions/:id
|
||||
func (h *SubscriptionHandler) Revoke(c *gin.Context) {
|
||||
subscriptionID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid subscription ID")
|
||||
return
|
||||
}
|
||||
|
||||
err = h.subscriptionService.RevokeSubscription(c.Request.Context(), subscriptionID)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to revoke subscription: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"message": "Subscription revoked successfully"})
|
||||
}
|
||||
|
||||
// ListByGroup handles listing subscriptions for a specific group
|
||||
// GET /api/v1/admin/groups/:id/subscriptions
|
||||
func (h *SubscriptionHandler) ListByGroup(c *gin.Context) {
|
||||
groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid group ID")
|
||||
return
|
||||
}
|
||||
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
|
||||
subscriptions, pagination, err := h.subscriptionService.ListGroupSubscriptions(c.Request.Context(), groupID, page, pageSize)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to list group subscriptions: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.PaginatedWithResult(c, subscriptions, toResponsePagination(pagination))
|
||||
}
|
||||
|
||||
// ListByUser handles listing subscriptions for a specific user
|
||||
// GET /api/v1/admin/users/:id/subscriptions
|
||||
func (h *SubscriptionHandler) ListByUser(c *gin.Context) {
|
||||
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid user ID")
|
||||
return
|
||||
}
|
||||
|
||||
subscriptions, err := h.subscriptionService.ListUserSubscriptions(c.Request.Context(), userID)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to list user subscriptions: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, subscriptions)
|
||||
}
|
||||
|
||||
// Helper function to get admin ID from context
|
||||
func getAdminIDFromContext(c *gin.Context) int64 {
|
||||
if user, exists := c.Get("user"); exists {
|
||||
if u, ok := user.(*model.User); ok && u != nil {
|
||||
return u.ID
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
82
backend/internal/handler/admin/system_handler.go
Normal file
82
backend/internal/handler/admin/system_handler.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"sub2api/internal/pkg/response"
|
||||
"sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// SystemHandler handles system-related operations
|
||||
type SystemHandler struct {
|
||||
updateSvc *service.UpdateService
|
||||
}
|
||||
|
||||
// NewSystemHandler creates a new SystemHandler
|
||||
func NewSystemHandler(rdb *redis.Client, version, buildType string) *SystemHandler {
|
||||
return &SystemHandler{
|
||||
updateSvc: service.NewUpdateService(rdb, version, buildType),
|
||||
}
|
||||
}
|
||||
|
||||
// GetVersion returns the current version
|
||||
// GET /api/v1/admin/system/version
|
||||
func (h *SystemHandler) GetVersion(c *gin.Context) {
|
||||
info, _ := h.updateSvc.CheckUpdate(c.Request.Context(), false)
|
||||
response.Success(c, gin.H{
|
||||
"version": info.CurrentVersion,
|
||||
})
|
||||
}
|
||||
|
||||
// CheckUpdates checks for available updates
|
||||
// GET /api/v1/admin/system/check-updates
|
||||
func (h *SystemHandler) CheckUpdates(c *gin.Context) {
|
||||
force := c.Query("force") == "true"
|
||||
info, err := h.updateSvc.CheckUpdate(c.Request.Context(), force)
|
||||
if err != nil {
|
||||
response.Error(c, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
response.Success(c, info)
|
||||
}
|
||||
|
||||
// PerformUpdate downloads and applies the update
|
||||
// POST /api/v1/admin/system/update
|
||||
func (h *SystemHandler) PerformUpdate(c *gin.Context) {
|
||||
if err := h.updateSvc.PerformUpdate(c.Request.Context()); err != nil {
|
||||
response.Error(c, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{
|
||||
"message": "Update completed. Please restart the service.",
|
||||
"need_restart": true,
|
||||
})
|
||||
}
|
||||
|
||||
// Rollback restores the previous version
|
||||
// POST /api/v1/admin/system/rollback
|
||||
func (h *SystemHandler) Rollback(c *gin.Context) {
|
||||
if err := h.updateSvc.Rollback(); err != nil {
|
||||
response.Error(c, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{
|
||||
"message": "Rollback completed. Please restart the service.",
|
||||
"need_restart": true,
|
||||
})
|
||||
}
|
||||
|
||||
// RestartService restarts the systemd service
|
||||
// POST /api/v1/admin/system/restart
|
||||
func (h *SystemHandler) RestartService(c *gin.Context) {
|
||||
if err := h.updateSvc.RestartService(); err != nil {
|
||||
response.Error(c, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{
|
||||
"message": "Service restart initiated",
|
||||
})
|
||||
}
|
||||
262
backend/internal/handler/admin/usage_handler.go
Normal file
262
backend/internal/handler/admin/usage_handler.go
Normal file
@@ -0,0 +1,262 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/pkg/response"
|
||||
"sub2api/internal/pkg/timezone"
|
||||
"sub2api/internal/repository"
|
||||
"sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// UsageHandler handles admin usage-related requests
|
||||
type UsageHandler struct {
|
||||
usageRepo *repository.UsageLogRepository
|
||||
apiKeyRepo *repository.ApiKeyRepository
|
||||
usageService *service.UsageService
|
||||
adminService service.AdminService
|
||||
}
|
||||
|
||||
// NewUsageHandler creates a new admin usage handler
|
||||
func NewUsageHandler(
|
||||
usageRepo *repository.UsageLogRepository,
|
||||
apiKeyRepo *repository.ApiKeyRepository,
|
||||
usageService *service.UsageService,
|
||||
adminService service.AdminService,
|
||||
) *UsageHandler {
|
||||
return &UsageHandler{
|
||||
usageRepo: usageRepo,
|
||||
apiKeyRepo: apiKeyRepo,
|
||||
usageService: usageService,
|
||||
adminService: adminService,
|
||||
}
|
||||
}
|
||||
|
||||
// List handles listing all usage records with filters
|
||||
// GET /api/v1/admin/usage
|
||||
func (h *UsageHandler) List(c *gin.Context) {
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
|
||||
// Parse filters
|
||||
var userID, apiKeyID int64
|
||||
if userIDStr := c.Query("user_id"); userIDStr != "" {
|
||||
id, err := strconv.ParseInt(userIDStr, 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid user_id")
|
||||
return
|
||||
}
|
||||
userID = id
|
||||
}
|
||||
|
||||
if apiKeyIDStr := c.Query("api_key_id"); apiKeyIDStr != "" {
|
||||
id, err := strconv.ParseInt(apiKeyIDStr, 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid api_key_id")
|
||||
return
|
||||
}
|
||||
apiKeyID = id
|
||||
}
|
||||
|
||||
// Parse date range
|
||||
var startTime, endTime *time.Time
|
||||
if startDateStr := c.Query("start_date"); startDateStr != "" {
|
||||
t, err := timezone.ParseInLocation("2006-01-02", startDateStr)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid start_date format, use YYYY-MM-DD")
|
||||
return
|
||||
}
|
||||
startTime = &t
|
||||
}
|
||||
|
||||
if endDateStr := c.Query("end_date"); endDateStr != "" {
|
||||
t, err := timezone.ParseInLocation("2006-01-02", endDateStr)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
|
||||
return
|
||||
}
|
||||
// Set end time to end of day
|
||||
t = t.Add(24*time.Hour - time.Nanosecond)
|
||||
endTime = &t
|
||||
}
|
||||
|
||||
params := repository.PaginationParams{Page: page, PageSize: pageSize}
|
||||
filters := repository.UsageLogFilters{
|
||||
UserID: userID,
|
||||
ApiKeyID: apiKeyID,
|
||||
StartTime: startTime,
|
||||
EndTime: endTime,
|
||||
}
|
||||
|
||||
records, result, err := h.usageRepo.ListWithFilters(c.Request.Context(), params, filters)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to list usage records: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Paginated(c, records, result.Total, page, pageSize)
|
||||
}
|
||||
|
||||
// Stats handles getting usage statistics with filters
|
||||
// GET /api/v1/admin/usage/stats
|
||||
func (h *UsageHandler) Stats(c *gin.Context) {
|
||||
// Parse filters
|
||||
var userID, apiKeyID int64
|
||||
if userIDStr := c.Query("user_id"); userIDStr != "" {
|
||||
id, err := strconv.ParseInt(userIDStr, 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid user_id")
|
||||
return
|
||||
}
|
||||
userID = id
|
||||
}
|
||||
|
||||
if apiKeyIDStr := c.Query("api_key_id"); apiKeyIDStr != "" {
|
||||
id, err := strconv.ParseInt(apiKeyIDStr, 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid api_key_id")
|
||||
return
|
||||
}
|
||||
apiKeyID = id
|
||||
}
|
||||
|
||||
// Parse date range
|
||||
now := timezone.Now()
|
||||
var startTime, endTime time.Time
|
||||
|
||||
startDateStr := c.Query("start_date")
|
||||
endDateStr := c.Query("end_date")
|
||||
|
||||
if startDateStr != "" && endDateStr != "" {
|
||||
var err error
|
||||
startTime, err = timezone.ParseInLocation("2006-01-02", startDateStr)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid start_date format, use YYYY-MM-DD")
|
||||
return
|
||||
}
|
||||
endTime, err = timezone.ParseInLocation("2006-01-02", endDateStr)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
|
||||
return
|
||||
}
|
||||
endTime = endTime.Add(24*time.Hour - time.Nanosecond)
|
||||
} else {
|
||||
period := c.DefaultQuery("period", "today")
|
||||
switch period {
|
||||
case "today":
|
||||
startTime = timezone.StartOfDay(now)
|
||||
case "week":
|
||||
startTime = now.AddDate(0, 0, -7)
|
||||
case "month":
|
||||
startTime = now.AddDate(0, -1, 0)
|
||||
default:
|
||||
startTime = timezone.StartOfDay(now)
|
||||
}
|
||||
endTime = now
|
||||
}
|
||||
|
||||
if apiKeyID > 0 {
|
||||
stats, err := h.usageService.GetStatsByApiKey(c.Request.Context(), apiKeyID, startTime, endTime)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to get usage statistics: "+err.Error())
|
||||
return
|
||||
}
|
||||
response.Success(c, stats)
|
||||
return
|
||||
}
|
||||
|
||||
if userID > 0 {
|
||||
stats, err := h.usageService.GetStatsByUser(c.Request.Context(), userID, startTime, endTime)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to get usage statistics: "+err.Error())
|
||||
return
|
||||
}
|
||||
response.Success(c, stats)
|
||||
return
|
||||
}
|
||||
|
||||
// Get global stats
|
||||
stats, err := h.usageRepo.GetGlobalStats(c.Request.Context(), startTime, endTime)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to get usage statistics: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, stats)
|
||||
}
|
||||
|
||||
// SearchUsers handles searching users by email keyword
|
||||
// GET /api/v1/admin/usage/search-users
|
||||
func (h *UsageHandler) SearchUsers(c *gin.Context) {
|
||||
keyword := c.Query("q")
|
||||
if keyword == "" {
|
||||
response.Success(c, []interface{}{})
|
||||
return
|
||||
}
|
||||
|
||||
// Limit to 30 results
|
||||
users, _, err := h.adminService.ListUsers(c.Request.Context(), 1, 30, "", "", keyword)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to search users: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Return simplified user list (only id and email)
|
||||
type SimpleUser struct {
|
||||
ID int64 `json:"id"`
|
||||
Email string `json:"email"`
|
||||
}
|
||||
|
||||
result := make([]SimpleUser, len(users))
|
||||
for i, u := range users {
|
||||
result[i] = SimpleUser{
|
||||
ID: u.ID,
|
||||
Email: u.Email,
|
||||
}
|
||||
}
|
||||
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
// SearchApiKeys handles searching API keys by user
|
||||
// GET /api/v1/admin/usage/search-api-keys
|
||||
func (h *UsageHandler) SearchApiKeys(c *gin.Context) {
|
||||
userIDStr := c.Query("user_id")
|
||||
keyword := c.Query("q")
|
||||
|
||||
var userID int64
|
||||
if userIDStr != "" {
|
||||
id, err := strconv.ParseInt(userIDStr, 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid user_id")
|
||||
return
|
||||
}
|
||||
userID = id
|
||||
}
|
||||
|
||||
keys, err := h.apiKeyRepo.SearchApiKeys(c.Request.Context(), userID, keyword, 30)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to search API keys: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Return simplified API key list (only id and name)
|
||||
type SimpleApiKey struct {
|
||||
ID int64 `json:"id"`
|
||||
Name string `json:"name"`
|
||||
UserID int64 `json:"user_id"`
|
||||
}
|
||||
|
||||
result := make([]SimpleApiKey, len(keys))
|
||||
for i, k := range keys {
|
||||
result[i] = SimpleApiKey{
|
||||
ID: k.ID,
|
||||
Name: k.Name,
|
||||
UserID: k.UserID,
|
||||
}
|
||||
}
|
||||
|
||||
response.Success(c, result)
|
||||
}
|
||||
221
backend/internal/handler/admin/user_handler.go
Normal file
221
backend/internal/handler/admin/user_handler.go
Normal file
@@ -0,0 +1,221 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"sub2api/internal/pkg/response"
|
||||
"sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// UserHandler handles admin user management
|
||||
type UserHandler struct {
|
||||
adminService service.AdminService
|
||||
}
|
||||
|
||||
// NewUserHandler creates a new admin user handler
|
||||
func NewUserHandler(adminService service.AdminService) *UserHandler {
|
||||
return &UserHandler{
|
||||
adminService: adminService,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateUserRequest represents admin create user request
|
||||
type CreateUserRequest struct {
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
Password string `json:"password" binding:"required,min=6"`
|
||||
Balance float64 `json:"balance"`
|
||||
Concurrency int `json:"concurrency"`
|
||||
AllowedGroups []int64 `json:"allowed_groups"`
|
||||
}
|
||||
|
||||
// UpdateUserRequest represents admin update user request
|
||||
// 使用指针类型来区分"未提供"和"设置为0"
|
||||
type UpdateUserRequest struct {
|
||||
Email string `json:"email" binding:"omitempty,email"`
|
||||
Password string `json:"password" binding:"omitempty,min=6"`
|
||||
Balance *float64 `json:"balance"`
|
||||
Concurrency *int `json:"concurrency"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active disabled"`
|
||||
AllowedGroups *[]int64 `json:"allowed_groups"`
|
||||
}
|
||||
|
||||
// UpdateBalanceRequest represents balance update request
|
||||
type UpdateBalanceRequest struct {
|
||||
Balance float64 `json:"balance" binding:"required"`
|
||||
Operation string `json:"operation" binding:"required,oneof=set add subtract"`
|
||||
}
|
||||
|
||||
// List handles listing all users with pagination
|
||||
// GET /api/v1/admin/users
|
||||
func (h *UserHandler) List(c *gin.Context) {
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
status := c.Query("status")
|
||||
role := c.Query("role")
|
||||
search := c.Query("search")
|
||||
|
||||
users, total, err := h.adminService.ListUsers(c.Request.Context(), page, pageSize, status, role, search)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to list users: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Paginated(c, users, total, page, pageSize)
|
||||
}
|
||||
|
||||
// GetByID handles getting a user by ID
|
||||
// GET /api/v1/admin/users/:id
|
||||
func (h *UserHandler) GetByID(c *gin.Context) {
|
||||
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid user ID")
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.adminService.GetUser(c.Request.Context(), userID)
|
||||
if err != nil {
|
||||
response.NotFound(c, "User not found")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, user)
|
||||
}
|
||||
|
||||
// Create handles creating a new user
|
||||
// POST /api/v1/admin/users
|
||||
func (h *UserHandler) Create(c *gin.Context) {
|
||||
var req CreateUserRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.adminService.CreateUser(c.Request.Context(), &service.CreateUserInput{
|
||||
Email: req.Email,
|
||||
Password: req.Password,
|
||||
Balance: req.Balance,
|
||||
Concurrency: req.Concurrency,
|
||||
AllowedGroups: req.AllowedGroups,
|
||||
})
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Failed to create user: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, user)
|
||||
}
|
||||
|
||||
// Update handles updating a user
|
||||
// PUT /api/v1/admin/users/:id
|
||||
func (h *UserHandler) Update(c *gin.Context) {
|
||||
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid user ID")
|
||||
return
|
||||
}
|
||||
|
||||
var req UpdateUserRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 使用指针类型直接传递,nil 表示未提供该字段
|
||||
user, err := h.adminService.UpdateUser(c.Request.Context(), userID, &service.UpdateUserInput{
|
||||
Email: req.Email,
|
||||
Password: req.Password,
|
||||
Balance: req.Balance,
|
||||
Concurrency: req.Concurrency,
|
||||
Status: req.Status,
|
||||
AllowedGroups: req.AllowedGroups,
|
||||
})
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to update user: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, user)
|
||||
}
|
||||
|
||||
// Delete handles deleting a user
|
||||
// DELETE /api/v1/admin/users/:id
|
||||
func (h *UserHandler) Delete(c *gin.Context) {
|
||||
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid user ID")
|
||||
return
|
||||
}
|
||||
|
||||
err = h.adminService.DeleteUser(c.Request.Context(), userID)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to delete user: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"message": "User deleted successfully"})
|
||||
}
|
||||
|
||||
// UpdateBalance handles updating user balance
|
||||
// POST /api/v1/admin/users/:id/balance
|
||||
func (h *UserHandler) UpdateBalance(c *gin.Context) {
|
||||
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid user ID")
|
||||
return
|
||||
}
|
||||
|
||||
var req UpdateBalanceRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.adminService.UpdateUserBalance(c.Request.Context(), userID, req.Balance, req.Operation)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to update balance: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, user)
|
||||
}
|
||||
|
||||
// GetUserAPIKeys handles getting user's API keys
|
||||
// GET /api/v1/admin/users/:id/api-keys
|
||||
func (h *UserHandler) GetUserAPIKeys(c *gin.Context) {
|
||||
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid user ID")
|
||||
return
|
||||
}
|
||||
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
|
||||
keys, total, err := h.adminService.GetUserAPIKeys(c.Request.Context(), userID, page, pageSize)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to get user API keys: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Paginated(c, keys, total, page, pageSize)
|
||||
}
|
||||
|
||||
// GetUserUsage handles getting user's usage statistics
|
||||
// GET /api/v1/admin/users/:id/usage
|
||||
func (h *UserHandler) GetUserUsage(c *gin.Context) {
|
||||
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid user ID")
|
||||
return
|
||||
}
|
||||
|
||||
period := c.DefaultQuery("period", "month")
|
||||
|
||||
stats, err := h.adminService.GetUserUsageStats(c.Request.Context(), userID, period)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to get user usage: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, stats)
|
||||
}
|
||||
235
backend/internal/handler/api_key_handler.go
Normal file
235
backend/internal/handler/api_key_handler.go
Normal file
@@ -0,0 +1,235 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/response"
|
||||
"sub2api/internal/repository"
|
||||
"sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// APIKeyHandler handles API key-related requests
|
||||
type APIKeyHandler struct {
|
||||
apiKeyService *service.ApiKeyService
|
||||
}
|
||||
|
||||
// NewAPIKeyHandler creates a new APIKeyHandler
|
||||
func NewAPIKeyHandler(apiKeyService *service.ApiKeyService) *APIKeyHandler {
|
||||
return &APIKeyHandler{
|
||||
apiKeyService: apiKeyService,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateAPIKeyRequest represents the create API key request payload
|
||||
type CreateAPIKeyRequest struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
GroupID *int64 `json:"group_id"` // nullable
|
||||
CustomKey *string `json:"custom_key"` // 可选的自定义key
|
||||
}
|
||||
|
||||
// UpdateAPIKeyRequest represents the update API key request payload
|
||||
type UpdateAPIKeyRequest struct {
|
||||
Name string `json:"name"`
|
||||
GroupID *int64 `json:"group_id"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
|
||||
}
|
||||
|
||||
// List handles listing user's API keys with pagination
|
||||
// GET /api/v1/api-keys
|
||||
func (h *APIKeyHandler) List(c *gin.Context) {
|
||||
userValue, exists := c.Get("user")
|
||||
if !exists {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
user, ok := userValue.(*model.User)
|
||||
if !ok {
|
||||
response.InternalError(c, "Invalid user context")
|
||||
return
|
||||
}
|
||||
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
params := repository.PaginationParams{Page: page, PageSize: pageSize}
|
||||
|
||||
keys, result, err := h.apiKeyService.List(c.Request.Context(), user.ID, params)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to list API keys: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Paginated(c, keys, result.Total, page, pageSize)
|
||||
}
|
||||
|
||||
// GetByID handles getting a single API key
|
||||
// GET /api/v1/api-keys/:id
|
||||
func (h *APIKeyHandler) GetByID(c *gin.Context) {
|
||||
userValue, exists := c.Get("user")
|
||||
if !exists {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
user, ok := userValue.(*model.User)
|
||||
if !ok {
|
||||
response.InternalError(c, "Invalid user context")
|
||||
return
|
||||
}
|
||||
|
||||
keyID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid key ID")
|
||||
return
|
||||
}
|
||||
|
||||
key, err := h.apiKeyService.GetByID(c.Request.Context(), keyID)
|
||||
if err != nil {
|
||||
response.NotFound(c, "API key not found")
|
||||
return
|
||||
}
|
||||
|
||||
// 验证所有权
|
||||
if key.UserID != user.ID {
|
||||
response.Forbidden(c, "Not authorized to access this key")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, key)
|
||||
}
|
||||
|
||||
// Create handles creating a new API key
|
||||
// POST /api/v1/api-keys
|
||||
func (h *APIKeyHandler) Create(c *gin.Context) {
|
||||
userValue, exists := c.Get("user")
|
||||
if !exists {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
user, ok := userValue.(*model.User)
|
||||
if !ok {
|
||||
response.InternalError(c, "Invalid user context")
|
||||
return
|
||||
}
|
||||
|
||||
var req CreateAPIKeyRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
svcReq := service.CreateApiKeyRequest{
|
||||
Name: req.Name,
|
||||
GroupID: req.GroupID,
|
||||
CustomKey: req.CustomKey,
|
||||
}
|
||||
key, err := h.apiKeyService.Create(c.Request.Context(), user.ID, svcReq)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to create API key: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, key)
|
||||
}
|
||||
|
||||
// Update handles updating an API key
|
||||
// PUT /api/v1/api-keys/:id
|
||||
func (h *APIKeyHandler) Update(c *gin.Context) {
|
||||
userValue, exists := c.Get("user")
|
||||
if !exists {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
user, ok := userValue.(*model.User)
|
||||
if !ok {
|
||||
response.InternalError(c, "Invalid user context")
|
||||
return
|
||||
}
|
||||
|
||||
keyID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid key ID")
|
||||
return
|
||||
}
|
||||
|
||||
var req UpdateAPIKeyRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
svcReq := service.UpdateApiKeyRequest{}
|
||||
if req.Name != "" {
|
||||
svcReq.Name = &req.Name
|
||||
}
|
||||
svcReq.GroupID = req.GroupID
|
||||
if req.Status != "" {
|
||||
svcReq.Status = &req.Status
|
||||
}
|
||||
|
||||
key, err := h.apiKeyService.Update(c.Request.Context(), keyID, user.ID, svcReq)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to update API key: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, key)
|
||||
}
|
||||
|
||||
// Delete handles deleting an API key
|
||||
// DELETE /api/v1/api-keys/:id
|
||||
func (h *APIKeyHandler) Delete(c *gin.Context) {
|
||||
userValue, exists := c.Get("user")
|
||||
if !exists {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
user, ok := userValue.(*model.User)
|
||||
if !ok {
|
||||
response.InternalError(c, "Invalid user context")
|
||||
return
|
||||
}
|
||||
|
||||
keyID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid key ID")
|
||||
return
|
||||
}
|
||||
|
||||
err = h.apiKeyService.Delete(c.Request.Context(), keyID, user.ID)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to delete API key: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"message": "API key deleted successfully"})
|
||||
}
|
||||
|
||||
// GetAvailableGroups 获取用户可以绑定的分组列表
|
||||
// GET /api/v1/groups/available
|
||||
func (h *APIKeyHandler) GetAvailableGroups(c *gin.Context) {
|
||||
userValue, exists := c.Get("user")
|
||||
if !exists {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
user, ok := userValue.(*model.User)
|
||||
if !ok {
|
||||
response.InternalError(c, "Invalid user context")
|
||||
return
|
||||
}
|
||||
|
||||
groups, err := h.apiKeyService.GetAvailableGroups(c.Request.Context(), user.ID)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to get available groups: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, groups)
|
||||
}
|
||||
158
backend/internal/handler/auth_handler.go
Normal file
158
backend/internal/handler/auth_handler.go
Normal file
@@ -0,0 +1,158 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/response"
|
||||
"sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// AuthHandler handles authentication-related requests
|
||||
type AuthHandler struct {
|
||||
authService *service.AuthService
|
||||
}
|
||||
|
||||
// NewAuthHandler creates a new AuthHandler
|
||||
func NewAuthHandler(authService *service.AuthService) *AuthHandler {
|
||||
return &AuthHandler{
|
||||
authService: authService,
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterRequest represents the registration request payload
|
||||
type RegisterRequest struct {
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
Password string `json:"password" binding:"required,min=6"`
|
||||
VerifyCode string `json:"verify_code"`
|
||||
TurnstileToken string `json:"turnstile_token"`
|
||||
}
|
||||
|
||||
// SendVerifyCodeRequest 发送验证码请求
|
||||
type SendVerifyCodeRequest struct {
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
TurnstileToken string `json:"turnstile_token"`
|
||||
}
|
||||
|
||||
// SendVerifyCodeResponse 发送验证码响应
|
||||
type SendVerifyCodeResponse struct {
|
||||
Message string `json:"message"`
|
||||
Countdown int `json:"countdown"` // 倒计时秒数
|
||||
}
|
||||
|
||||
// LoginRequest represents the login request payload
|
||||
type LoginRequest struct {
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
Password string `json:"password" binding:"required"`
|
||||
TurnstileToken string `json:"turnstile_token"`
|
||||
}
|
||||
|
||||
// AuthResponse 认证响应格式(匹配前端期望)
|
||||
type AuthResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
User *model.User `json:"user"`
|
||||
}
|
||||
|
||||
// Register handles user registration
|
||||
// POST /api/v1/auth/register
|
||||
func (h *AuthHandler) Register(c *gin.Context) {
|
||||
var req RegisterRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Turnstile 验证(当提供了邮箱验证码时跳过,因为发送验证码时已验证过)
|
||||
if req.VerifyCode == "" {
|
||||
if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, c.ClientIP()); err != nil {
|
||||
response.BadRequest(c, "Turnstile verification failed: "+err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
token, user, err := h.authService.RegisterWithVerification(c.Request.Context(), req.Email, req.Password, req.VerifyCode)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Registration failed: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, AuthResponse{
|
||||
AccessToken: token,
|
||||
TokenType: "Bearer",
|
||||
User: user,
|
||||
})
|
||||
}
|
||||
|
||||
// SendVerifyCode 发送邮箱验证码
|
||||
// POST /api/v1/auth/send-verify-code
|
||||
func (h *AuthHandler) SendVerifyCode(c *gin.Context) {
|
||||
var req SendVerifyCodeRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Turnstile 验证
|
||||
if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, c.ClientIP()); err != nil {
|
||||
response.BadRequest(c, "Turnstile verification failed: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.authService.SendVerifyCodeAsync(c.Request.Context(), req.Email)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Failed to send verification code: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, SendVerifyCodeResponse{
|
||||
Message: "Verification code sent successfully",
|
||||
Countdown: result.Countdown,
|
||||
})
|
||||
}
|
||||
|
||||
// Login handles user login
|
||||
// POST /api/v1/auth/login
|
||||
func (h *AuthHandler) Login(c *gin.Context) {
|
||||
var req LoginRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Turnstile 验证
|
||||
if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, c.ClientIP()); err != nil {
|
||||
response.BadRequest(c, "Turnstile verification failed: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
token, user, err := h.authService.Login(c.Request.Context(), req.Email, req.Password)
|
||||
if err != nil {
|
||||
response.Unauthorized(c, "Login failed: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, AuthResponse{
|
||||
AccessToken: token,
|
||||
TokenType: "Bearer",
|
||||
User: user,
|
||||
})
|
||||
}
|
||||
|
||||
// GetCurrentUser handles getting current authenticated user
|
||||
// GET /api/v1/auth/me
|
||||
func (h *AuthHandler) GetCurrentUser(c *gin.Context) {
|
||||
userValue, exists := c.Get("user")
|
||||
if !exists {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
user, ok := userValue.(*model.User)
|
||||
if !ok {
|
||||
response.InternalError(c, "Invalid user context")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, user)
|
||||
}
|
||||
445
backend/internal/handler/gateway_handler.go
Normal file
445
backend/internal/handler/gateway_handler.go
Normal file
@@ -0,0 +1,445 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/middleware"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
const (
|
||||
// Maximum wait time for concurrency slot
|
||||
maxConcurrencyWait = 60 * time.Second
|
||||
// Ping interval during wait
|
||||
pingInterval = 5 * time.Second
|
||||
)
|
||||
|
||||
// GatewayHandler handles API gateway requests
|
||||
type GatewayHandler struct {
|
||||
gatewayService *service.GatewayService
|
||||
userService *service.UserService
|
||||
concurrencyService *service.ConcurrencyService
|
||||
billingCacheService *service.BillingCacheService
|
||||
}
|
||||
|
||||
// NewGatewayHandler creates a new GatewayHandler
|
||||
func NewGatewayHandler(gatewayService *service.GatewayService, userService *service.UserService, concurrencyService *service.ConcurrencyService, billingCacheService *service.BillingCacheService) *GatewayHandler {
|
||||
return &GatewayHandler{
|
||||
gatewayService: gatewayService,
|
||||
userService: userService,
|
||||
concurrencyService: concurrencyService,
|
||||
billingCacheService: billingCacheService,
|
||||
}
|
||||
}
|
||||
|
||||
// Messages handles Claude API compatible messages endpoint
|
||||
// POST /v1/messages
|
||||
func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
// 从context获取apiKey和user(ApiKeyAuth中间件已设置)
|
||||
apiKey, ok := middleware.GetApiKeyFromContext(c)
|
||||
if !ok {
|
||||
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
|
||||
return
|
||||
}
|
||||
|
||||
user, ok := middleware.GetUserFromContext(c)
|
||||
if !ok {
|
||||
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
|
||||
return
|
||||
}
|
||||
|
||||
// 读取请求体
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
|
||||
return
|
||||
}
|
||||
|
||||
if len(body) == 0 {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
|
||||
return
|
||||
}
|
||||
|
||||
// 解析请求获取模型名和stream
|
||||
var req struct {
|
||||
Model string `json:"model"`
|
||||
Stream bool `json:"stream"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
|
||||
return
|
||||
}
|
||||
|
||||
// Track if we've started streaming (for error handling)
|
||||
streamStarted := false
|
||||
|
||||
// 获取订阅信息(可能为nil)- 提前获取用于后续检查
|
||||
subscription, _ := middleware.GetSubscriptionFromContext(c)
|
||||
|
||||
// 0. 检查wait队列是否已满
|
||||
maxWait := service.CalculateMaxWait(user.Concurrency)
|
||||
canWait, err := h.concurrencyService.IncrementWaitCount(c.Request.Context(), user.ID, maxWait)
|
||||
if err != nil {
|
||||
log.Printf("Increment wait count failed: %v", err)
|
||||
// On error, allow request to proceed
|
||||
} else if !canWait {
|
||||
h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
|
||||
return
|
||||
}
|
||||
// 确保在函数退出时减少wait计数
|
||||
defer h.concurrencyService.DecrementWaitCount(c.Request.Context(), user.ID)
|
||||
|
||||
// 1. 首先获取用户并发槽位
|
||||
userReleaseFunc, err := h.acquireUserSlotWithWait(c, user, req.Stream, &streamStarted)
|
||||
if err != nil {
|
||||
log.Printf("User concurrency acquire failed: %v", err)
|
||||
h.handleConcurrencyError(c, err, "user", streamStarted)
|
||||
return
|
||||
}
|
||||
if userReleaseFunc != nil {
|
||||
defer userReleaseFunc()
|
||||
}
|
||||
|
||||
// 2. 【新增】Wait后二次检查余额/订阅
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), user, apiKey, apiKey.Group, subscription); err != nil {
|
||||
log.Printf("Billing eligibility check failed after wait: %v", err)
|
||||
h.handleStreamingAwareError(c, http.StatusForbidden, "billing_error", err.Error(), streamStarted)
|
||||
return
|
||||
}
|
||||
|
||||
// 计算粘性会话hash
|
||||
sessionHash := h.gatewayService.GenerateSessionHash(body)
|
||||
|
||||
// 选择支持该模型的账号
|
||||
account, err := h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model)
|
||||
if err != nil {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
||||
return
|
||||
}
|
||||
|
||||
// 3. 获取账号并发槽位
|
||||
accountReleaseFunc, err := h.acquireAccountSlotWithWait(c, account, req.Stream, &streamStarted)
|
||||
if err != nil {
|
||||
log.Printf("Account concurrency acquire failed: %v", err)
|
||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||
return
|
||||
}
|
||||
if accountReleaseFunc != nil {
|
||||
defer accountReleaseFunc()
|
||||
}
|
||||
|
||||
// 转发请求
|
||||
result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body)
|
||||
if err != nil {
|
||||
// 错误响应已在Forward中处理,这里只记录日志
|
||||
log.Printf("Forward request failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 异步记录使用量(subscription已在函数开头获取)
|
||||
go func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
Result: result,
|
||||
ApiKey: apiKey,
|
||||
User: user,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
}); err != nil {
|
||||
log.Printf("Record usage failed: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// acquireUserSlotWithWait acquires a user concurrency slot, waiting if necessary
|
||||
// For streaming requests, sends ping events during the wait
|
||||
// streamStarted is updated if streaming response has begun
|
||||
func (h *GatewayHandler) acquireUserSlotWithWait(c *gin.Context, user *model.User, isStream bool, streamStarted *bool) (func(), error) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// Try to acquire immediately
|
||||
result, err := h.concurrencyService.AcquireUserSlot(ctx, user.ID, user.Concurrency)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if result.Acquired {
|
||||
return result.ReleaseFunc, nil
|
||||
}
|
||||
|
||||
// Need to wait - handle streaming ping if needed
|
||||
return h.waitForSlotWithPing(c, "user", user.ID, user.Concurrency, isStream, streamStarted)
|
||||
}
|
||||
|
||||
// acquireAccountSlotWithWait acquires an account concurrency slot, waiting if necessary
|
||||
// For streaming requests, sends ping events during the wait
|
||||
// streamStarted is updated if streaming response has begun
|
||||
func (h *GatewayHandler) acquireAccountSlotWithWait(c *gin.Context, account *model.Account, isStream bool, streamStarted *bool) (func(), error) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// Try to acquire immediately
|
||||
result, err := h.concurrencyService.AcquireAccountSlot(ctx, account.ID, account.Concurrency)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if result.Acquired {
|
||||
return result.ReleaseFunc, nil
|
||||
}
|
||||
|
||||
// Need to wait - handle streaming ping if needed
|
||||
return h.waitForSlotWithPing(c, "account", account.ID, account.Concurrency, isStream, streamStarted)
|
||||
}
|
||||
|
||||
// concurrencyError represents a concurrency limit error with context
|
||||
type concurrencyError struct {
|
||||
SlotType string
|
||||
IsTimeout bool
|
||||
}
|
||||
|
||||
func (e *concurrencyError) Error() string {
|
||||
if e.IsTimeout {
|
||||
return fmt.Sprintf("timeout waiting for %s concurrency slot", e.SlotType)
|
||||
}
|
||||
return fmt.Sprintf("%s concurrency limit reached", e.SlotType)
|
||||
}
|
||||
|
||||
// waitForSlotWithPing waits for a concurrency slot, sending ping events for streaming requests
|
||||
// Note: For streaming requests, we send ping to keep the connection alive.
|
||||
// streamStarted pointer is updated when streaming begins (for proper error handling by caller)
|
||||
func (h *GatewayHandler) waitForSlotWithPing(c *gin.Context, slotType string, id int64, maxConcurrency int, isStream bool, streamStarted *bool) (func(), error) {
|
||||
ctx, cancel := context.WithTimeout(c.Request.Context(), maxConcurrencyWait)
|
||||
defer cancel()
|
||||
|
||||
// For streaming requests, set up SSE headers for ping
|
||||
var flusher http.Flusher
|
||||
if isStream {
|
||||
var ok bool
|
||||
flusher, ok = c.Writer.(http.Flusher)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("streaming not supported")
|
||||
}
|
||||
}
|
||||
|
||||
pingTicker := time.NewTicker(pingInterval)
|
||||
defer pingTicker.Stop()
|
||||
|
||||
pollTicker := time.NewTicker(100 * time.Millisecond)
|
||||
defer pollTicker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, &concurrencyError{
|
||||
SlotType: slotType,
|
||||
IsTimeout: true,
|
||||
}
|
||||
|
||||
case <-pingTicker.C:
|
||||
// Send ping for streaming requests to keep connection alive
|
||||
if isStream && flusher != nil {
|
||||
// Set headers on first ping (lazy initialization)
|
||||
if !*streamStarted {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("X-Accel-Buffering", "no")
|
||||
*streamStarted = true
|
||||
}
|
||||
fmt.Fprintf(c.Writer, "data: {\"type\": \"ping\"}\n\n")
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
case <-pollTicker.C:
|
||||
// Try to acquire slot
|
||||
var result *service.AcquireResult
|
||||
var err error
|
||||
|
||||
if slotType == "user" {
|
||||
result, err = h.concurrencyService.AcquireUserSlot(ctx, id, maxConcurrency)
|
||||
} else {
|
||||
result, err = h.concurrencyService.AcquireAccountSlot(ctx, id, maxConcurrency)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if result.Acquired {
|
||||
return result.ReleaseFunc, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Models handles listing available models
|
||||
// GET /v1/models
|
||||
func (h *GatewayHandler) Models(c *gin.Context) {
|
||||
models := []gin.H{
|
||||
{
|
||||
"id": "claude-opus-4-5-20251101",
|
||||
"type": "model",
|
||||
"display_name": "Claude Opus 4.5",
|
||||
"created_at": "2025-11-01T00:00:00Z",
|
||||
},
|
||||
{
|
||||
"id": "claude-sonnet-4-5-20250929",
|
||||
"type": "model",
|
||||
"display_name": "Claude Sonnet 4.5",
|
||||
"created_at": "2025-09-29T00:00:00Z",
|
||||
},
|
||||
{
|
||||
"id": "claude-haiku-4-5-20251001",
|
||||
"type": "model",
|
||||
"display_name": "Claude Haiku 4.5",
|
||||
"created_at": "2025-10-01T00:00:00Z",
|
||||
},
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"data": models,
|
||||
"object": "list",
|
||||
})
|
||||
}
|
||||
|
||||
// Usage handles getting account balance for CC Switch integration
|
||||
// GET /v1/usage
|
||||
func (h *GatewayHandler) Usage(c *gin.Context) {
|
||||
apiKey, ok := middleware.GetApiKeyFromContext(c)
|
||||
if !ok {
|
||||
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
|
||||
return
|
||||
}
|
||||
|
||||
user, ok := middleware.GetUserFromContext(c)
|
||||
if !ok {
|
||||
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
|
||||
return
|
||||
}
|
||||
|
||||
// 订阅模式:返回订阅限额信息
|
||||
if apiKey.Group != nil && apiKey.Group.IsSubscriptionType() {
|
||||
subscription, ok := middleware.GetSubscriptionFromContext(c)
|
||||
if !ok {
|
||||
h.errorResponse(c, http.StatusForbidden, "subscription_error", "No active subscription")
|
||||
return
|
||||
}
|
||||
|
||||
remaining := h.calculateSubscriptionRemaining(apiKey.Group, subscription)
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"isValid": true,
|
||||
"planName": apiKey.Group.Name,
|
||||
"remaining": remaining,
|
||||
"unit": "USD",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 余额模式:返回钱包余额
|
||||
latestUser, err := h.userService.GetByID(c.Request.Context(), user.ID)
|
||||
if err != nil {
|
||||
h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to get user info")
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"isValid": true,
|
||||
"planName": "钱包余额",
|
||||
"remaining": latestUser.Balance,
|
||||
"unit": "USD",
|
||||
})
|
||||
}
|
||||
|
||||
// calculateSubscriptionRemaining 计算订阅剩余可用额度
|
||||
// 逻辑:
|
||||
// 1. 如果日/周/月任一限额达到100%,返回0
|
||||
// 2. 否则返回所有已配置周期中剩余额度的最小值
|
||||
func (h *GatewayHandler) calculateSubscriptionRemaining(group *model.Group, sub *model.UserSubscription) float64 {
|
||||
var remainingValues []float64
|
||||
|
||||
// 检查日限额
|
||||
if group.HasDailyLimit() {
|
||||
remaining := *group.DailyLimitUSD - sub.DailyUsageUSD
|
||||
if remaining <= 0 {
|
||||
return 0
|
||||
}
|
||||
remainingValues = append(remainingValues, remaining)
|
||||
}
|
||||
|
||||
// 检查周限额
|
||||
if group.HasWeeklyLimit() {
|
||||
remaining := *group.WeeklyLimitUSD - sub.WeeklyUsageUSD
|
||||
if remaining <= 0 {
|
||||
return 0
|
||||
}
|
||||
remainingValues = append(remainingValues, remaining)
|
||||
}
|
||||
|
||||
// 检查月限额
|
||||
if group.HasMonthlyLimit() {
|
||||
remaining := *group.MonthlyLimitUSD - sub.MonthlyUsageUSD
|
||||
if remaining <= 0 {
|
||||
return 0
|
||||
}
|
||||
remainingValues = append(remainingValues, remaining)
|
||||
}
|
||||
|
||||
// 如果没有配置任何限额,返回-1表示无限制
|
||||
if len(remainingValues) == 0 {
|
||||
return -1
|
||||
}
|
||||
|
||||
// 返回最小值
|
||||
min := remainingValues[0]
|
||||
for _, v := range remainingValues[1:] {
|
||||
if v < min {
|
||||
min = v
|
||||
}
|
||||
}
|
||||
return min
|
||||
}
|
||||
|
||||
// handleConcurrencyError handles concurrency-related errors with proper 429 response
|
||||
func (h *GatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotType string, streamStarted bool) {
|
||||
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error",
|
||||
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
|
||||
}
|
||||
|
||||
// handleStreamingAwareError handles errors that may occur after streaming has started
|
||||
func (h *GatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) {
|
||||
if streamStarted {
|
||||
// Stream already started, send error as SSE event then close
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if ok {
|
||||
// Send error event in SSE format
|
||||
errorEvent := fmt.Sprintf(`data: {"type": "error", "error": {"type": "%s", "message": "%s"}}`+"\n\n", errType, message)
|
||||
fmt.Fprint(c.Writer, errorEvent)
|
||||
flusher.Flush()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Normal case: return JSON response with proper status code
|
||||
h.errorResponse(c, status, errType, message)
|
||||
}
|
||||
|
||||
// errorResponse 返回Claude API格式的错误响应
|
||||
func (h *GatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) {
|
||||
c.JSON(status, gin.H{
|
||||
"type": "error",
|
||||
"error": gin.H{
|
||||
"type": errType,
|
||||
"message": message,
|
||||
},
|
||||
})
|
||||
}
|
||||
70
backend/internal/handler/handler.go
Normal file
70
backend/internal/handler/handler.go
Normal file
@@ -0,0 +1,70 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"sub2api/internal/handler/admin"
|
||||
"sub2api/internal/repository"
|
||||
"sub2api/internal/service"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// AdminHandlers contains all admin-related HTTP handlers
|
||||
type AdminHandlers struct {
|
||||
Dashboard *admin.DashboardHandler
|
||||
User *admin.UserHandler
|
||||
Group *admin.GroupHandler
|
||||
Account *admin.AccountHandler
|
||||
OAuth *admin.OAuthHandler
|
||||
Proxy *admin.ProxyHandler
|
||||
Redeem *admin.RedeemHandler
|
||||
Setting *admin.SettingHandler
|
||||
System *admin.SystemHandler
|
||||
Subscription *admin.SubscriptionHandler
|
||||
Usage *admin.UsageHandler
|
||||
}
|
||||
|
||||
// Handlers contains all HTTP handlers
|
||||
type Handlers struct {
|
||||
Auth *AuthHandler
|
||||
User *UserHandler
|
||||
APIKey *APIKeyHandler
|
||||
Usage *UsageHandler
|
||||
Redeem *RedeemHandler
|
||||
Subscription *SubscriptionHandler
|
||||
Admin *AdminHandlers
|
||||
Gateway *GatewayHandler
|
||||
Setting *SettingHandler
|
||||
}
|
||||
|
||||
// BuildInfo contains build-time information
|
||||
type BuildInfo struct {
|
||||
Version string
|
||||
BuildType string // "source" for manual builds, "release" for CI builds
|
||||
}
|
||||
|
||||
// NewHandlers creates a new Handlers instance with all handlers initialized
|
||||
func NewHandlers(services *service.Services, repos *repository.Repositories, rdb *redis.Client, buildInfo BuildInfo) *Handlers {
|
||||
return &Handlers{
|
||||
Auth: NewAuthHandler(services.Auth),
|
||||
User: NewUserHandler(services.User),
|
||||
APIKey: NewAPIKeyHandler(services.ApiKey),
|
||||
Usage: NewUsageHandler(services.Usage, repos.UsageLog, services.ApiKey),
|
||||
Redeem: NewRedeemHandler(services.Redeem),
|
||||
Subscription: NewSubscriptionHandler(services.Subscription),
|
||||
Admin: &AdminHandlers{
|
||||
Dashboard: admin.NewDashboardHandler(services.Admin, repos.UsageLog),
|
||||
User: admin.NewUserHandler(services.Admin),
|
||||
Group: admin.NewGroupHandler(services.Admin),
|
||||
Account: admin.NewAccountHandler(services.Admin, services.OAuth, services.RateLimit, services.AccountUsage, services.AccountTest),
|
||||
OAuth: admin.NewOAuthHandler(services.OAuth, services.Admin),
|
||||
Proxy: admin.NewProxyHandler(services.Admin),
|
||||
Redeem: admin.NewRedeemHandler(services.Admin),
|
||||
Setting: admin.NewSettingHandler(services.Setting, services.Email),
|
||||
System: admin.NewSystemHandler(rdb, buildInfo.Version, buildInfo.BuildType),
|
||||
Subscription: admin.NewSubscriptionHandler(services.Subscription),
|
||||
Usage: admin.NewUsageHandler(repos.UsageLog, repos.ApiKey, services.Usage, services.Admin),
|
||||
},
|
||||
Gateway: NewGatewayHandler(services.Gateway, services.User, services.Concurrency, services.BillingCache),
|
||||
Setting: NewSettingHandler(services.Setting, buildInfo.Version),
|
||||
}
|
||||
}
|
||||
92
backend/internal/handler/redeem_handler.go
Normal file
92
backend/internal/handler/redeem_handler.go
Normal file
@@ -0,0 +1,92 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/response"
|
||||
"sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// RedeemHandler handles redeem code-related requests
|
||||
type RedeemHandler struct {
|
||||
redeemService *service.RedeemService
|
||||
}
|
||||
|
||||
// NewRedeemHandler creates a new RedeemHandler
|
||||
func NewRedeemHandler(redeemService *service.RedeemService) *RedeemHandler {
|
||||
return &RedeemHandler{
|
||||
redeemService: redeemService,
|
||||
}
|
||||
}
|
||||
|
||||
// RedeemRequest represents the redeem code request payload
|
||||
type RedeemRequest struct {
|
||||
Code string `json:"code" binding:"required"`
|
||||
}
|
||||
|
||||
// RedeemResponse represents the redeem response
|
||||
type RedeemResponse struct {
|
||||
Message string `json:"message"`
|
||||
Type string `json:"type"`
|
||||
Value float64 `json:"value"`
|
||||
NewBalance *float64 `json:"new_balance,omitempty"`
|
||||
NewConcurrency *int `json:"new_concurrency,omitempty"`
|
||||
}
|
||||
|
||||
// Redeem handles redeeming a code
|
||||
// POST /api/v1/redeem
|
||||
func (h *RedeemHandler) Redeem(c *gin.Context) {
|
||||
userValue, exists := c.Get("user")
|
||||
if !exists {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
user, ok := userValue.(*model.User)
|
||||
if !ok {
|
||||
response.InternalError(c, "Invalid user context")
|
||||
return
|
||||
}
|
||||
|
||||
var req RedeemRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.redeemService.Redeem(c.Request.Context(), user.ID, req.Code)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Failed to redeem code: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
// GetHistory returns the user's redemption history
|
||||
// GET /api/v1/redeem/history
|
||||
func (h *RedeemHandler) GetHistory(c *gin.Context) {
|
||||
userValue, exists := c.Get("user")
|
||||
if !exists {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
user, ok := userValue.(*model.User)
|
||||
if !ok {
|
||||
response.InternalError(c, "Invalid user context")
|
||||
return
|
||||
}
|
||||
|
||||
// Default limit is 25
|
||||
limit := 25
|
||||
|
||||
codes, err := h.redeemService.GetUserHistory(c.Request.Context(), user.ID, limit)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to get history: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, codes)
|
||||
}
|
||||
35
backend/internal/handler/setting_handler.go
Normal file
35
backend/internal/handler/setting_handler.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"sub2api/internal/pkg/response"
|
||||
"sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// SettingHandler 公开设置处理器(无需认证)
|
||||
type SettingHandler struct {
|
||||
settingService *service.SettingService
|
||||
version string
|
||||
}
|
||||
|
||||
// NewSettingHandler 创建公开设置处理器
|
||||
func NewSettingHandler(settingService *service.SettingService, version string) *SettingHandler {
|
||||
return &SettingHandler{
|
||||
settingService: settingService,
|
||||
version: version,
|
||||
}
|
||||
}
|
||||
|
||||
// GetPublicSettings 获取公开设置
|
||||
// GET /api/v1/settings/public
|
||||
func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
|
||||
settings, err := h.settingService.GetPublicSettings(c.Request.Context())
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to get settings: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
settings.Version = h.version
|
||||
response.Success(c, settings)
|
||||
}
|
||||
203
backend/internal/handler/subscription_handler.go
Normal file
203
backend/internal/handler/subscription_handler.go
Normal file
@@ -0,0 +1,203 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/response"
|
||||
"sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// SubscriptionSummaryItem represents a subscription item in summary
|
||||
type SubscriptionSummaryItem struct {
|
||||
ID int64 `json:"id"`
|
||||
GroupID int64 `json:"group_id"`
|
||||
GroupName string `json:"group_name"`
|
||||
Status string `json:"status"`
|
||||
DailyUsedUSD float64 `json:"daily_used_usd,omitempty"`
|
||||
DailyLimitUSD float64 `json:"daily_limit_usd,omitempty"`
|
||||
WeeklyUsedUSD float64 `json:"weekly_used_usd,omitempty"`
|
||||
WeeklyLimitUSD float64 `json:"weekly_limit_usd,omitempty"`
|
||||
MonthlyUsedUSD float64 `json:"monthly_used_usd,omitempty"`
|
||||
MonthlyLimitUSD float64 `json:"monthly_limit_usd,omitempty"`
|
||||
ExpiresAt *string `json:"expires_at,omitempty"`
|
||||
}
|
||||
|
||||
// SubscriptionProgressInfo represents subscription with progress info
|
||||
type SubscriptionProgressInfo struct {
|
||||
Subscription *model.UserSubscription `json:"subscription"`
|
||||
Progress *service.SubscriptionProgress `json:"progress"`
|
||||
}
|
||||
|
||||
// SubscriptionHandler handles user subscription operations
|
||||
type SubscriptionHandler struct {
|
||||
subscriptionService *service.SubscriptionService
|
||||
}
|
||||
|
||||
// NewSubscriptionHandler creates a new user subscription handler
|
||||
func NewSubscriptionHandler(subscriptionService *service.SubscriptionService) *SubscriptionHandler {
|
||||
return &SubscriptionHandler{
|
||||
subscriptionService: subscriptionService,
|
||||
}
|
||||
}
|
||||
|
||||
// List handles listing current user's subscriptions
|
||||
// GET /api/v1/subscriptions
|
||||
func (h *SubscriptionHandler) List(c *gin.Context) {
|
||||
user, exists := c.Get("user")
|
||||
if !exists {
|
||||
response.Unauthorized(c, "User not found in context")
|
||||
return
|
||||
}
|
||||
|
||||
u, ok := user.(*model.User)
|
||||
if !ok {
|
||||
response.InternalError(c, "Invalid user in context")
|
||||
return
|
||||
}
|
||||
|
||||
subscriptions, err := h.subscriptionService.ListUserSubscriptions(c.Request.Context(), u.ID)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to list subscriptions: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, subscriptions)
|
||||
}
|
||||
|
||||
// GetActive handles getting current user's active subscriptions
|
||||
// GET /api/v1/subscriptions/active
|
||||
func (h *SubscriptionHandler) GetActive(c *gin.Context) {
|
||||
user, exists := c.Get("user")
|
||||
if !exists {
|
||||
response.Unauthorized(c, "User not found in context")
|
||||
return
|
||||
}
|
||||
|
||||
u, ok := user.(*model.User)
|
||||
if !ok {
|
||||
response.InternalError(c, "Invalid user in context")
|
||||
return
|
||||
}
|
||||
|
||||
subscriptions, err := h.subscriptionService.ListActiveUserSubscriptions(c.Request.Context(), u.ID)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to get active subscriptions: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, subscriptions)
|
||||
}
|
||||
|
||||
// GetProgress handles getting subscription progress for current user
|
||||
// GET /api/v1/subscriptions/progress
|
||||
func (h *SubscriptionHandler) GetProgress(c *gin.Context) {
|
||||
user, exists := c.Get("user")
|
||||
if !exists {
|
||||
response.Unauthorized(c, "User not found in context")
|
||||
return
|
||||
}
|
||||
|
||||
u, ok := user.(*model.User)
|
||||
if !ok {
|
||||
response.InternalError(c, "Invalid user in context")
|
||||
return
|
||||
}
|
||||
|
||||
// Get all active subscriptions with progress
|
||||
subscriptions, err := h.subscriptionService.ListActiveUserSubscriptions(c.Request.Context(), u.ID)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to get subscriptions: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
result := make([]SubscriptionProgressInfo, 0, len(subscriptions))
|
||||
for i := range subscriptions {
|
||||
sub := &subscriptions[i]
|
||||
progress, err := h.subscriptionService.GetSubscriptionProgress(c.Request.Context(), sub.ID)
|
||||
if err != nil {
|
||||
// Skip subscriptions with errors
|
||||
continue
|
||||
}
|
||||
result = append(result, SubscriptionProgressInfo{
|
||||
Subscription: sub,
|
||||
Progress: progress,
|
||||
})
|
||||
}
|
||||
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
// GetSummary handles getting a summary of current user's subscription status
|
||||
// GET /api/v1/subscriptions/summary
|
||||
func (h *SubscriptionHandler) GetSummary(c *gin.Context) {
|
||||
user, exists := c.Get("user")
|
||||
if !exists {
|
||||
response.Unauthorized(c, "User not found in context")
|
||||
return
|
||||
}
|
||||
|
||||
u, ok := user.(*model.User)
|
||||
if !ok {
|
||||
response.InternalError(c, "Invalid user in context")
|
||||
return
|
||||
}
|
||||
|
||||
// Get all active subscriptions
|
||||
subscriptions, err := h.subscriptionService.ListActiveUserSubscriptions(c.Request.Context(), u.ID)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to get subscriptions: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
var totalUsed float64
|
||||
items := make([]SubscriptionSummaryItem, 0, len(subscriptions))
|
||||
|
||||
for _, sub := range subscriptions {
|
||||
item := SubscriptionSummaryItem{
|
||||
ID: sub.ID,
|
||||
GroupID: sub.GroupID,
|
||||
Status: sub.Status,
|
||||
DailyUsedUSD: sub.DailyUsageUSD,
|
||||
WeeklyUsedUSD: sub.WeeklyUsageUSD,
|
||||
MonthlyUsedUSD: sub.MonthlyUsageUSD,
|
||||
}
|
||||
|
||||
// Add group info if preloaded
|
||||
if sub.Group != nil {
|
||||
item.GroupName = sub.Group.Name
|
||||
if sub.Group.DailyLimitUSD != nil {
|
||||
item.DailyLimitUSD = *sub.Group.DailyLimitUSD
|
||||
}
|
||||
if sub.Group.WeeklyLimitUSD != nil {
|
||||
item.WeeklyLimitUSD = *sub.Group.WeeklyLimitUSD
|
||||
}
|
||||
if sub.Group.MonthlyLimitUSD != nil {
|
||||
item.MonthlyLimitUSD = *sub.Group.MonthlyLimitUSD
|
||||
}
|
||||
}
|
||||
|
||||
// Format expiration time
|
||||
if !sub.ExpiresAt.IsZero() {
|
||||
formatted := sub.ExpiresAt.Format("2006-01-02T15:04:05Z07:00")
|
||||
item.ExpiresAt = &formatted
|
||||
}
|
||||
|
||||
// Track total usage (use monthly as the most comprehensive)
|
||||
totalUsed += sub.MonthlyUsageUSD
|
||||
|
||||
items = append(items, item)
|
||||
}
|
||||
|
||||
summary := struct {
|
||||
ActiveCount int `json:"active_count"`
|
||||
TotalUsedUSD float64 `json:"total_used_usd"`
|
||||
Subscriptions []SubscriptionSummaryItem `json:"subscriptions"`
|
||||
}{
|
||||
ActiveCount: len(subscriptions),
|
||||
TotalUsedUSD: totalUsed,
|
||||
Subscriptions: items,
|
||||
}
|
||||
|
||||
response.Success(c, summary)
|
||||
}
|
||||
396
backend/internal/handler/usage_handler.go
Normal file
396
backend/internal/handler/usage_handler.go
Normal file
@@ -0,0 +1,396 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/response"
|
||||
"sub2api/internal/pkg/timezone"
|
||||
"sub2api/internal/repository"
|
||||
"sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// UsageHandler handles usage-related requests
|
||||
type UsageHandler struct {
|
||||
usageService *service.UsageService
|
||||
usageRepo *repository.UsageLogRepository
|
||||
apiKeyService *service.ApiKeyService
|
||||
}
|
||||
|
||||
// NewUsageHandler creates a new UsageHandler
|
||||
func NewUsageHandler(usageService *service.UsageService, usageRepo *repository.UsageLogRepository, apiKeyService *service.ApiKeyService) *UsageHandler {
|
||||
return &UsageHandler{
|
||||
usageService: usageService,
|
||||
usageRepo: usageRepo,
|
||||
apiKeyService: apiKeyService,
|
||||
}
|
||||
}
|
||||
|
||||
// List handles listing usage records with pagination
|
||||
// GET /api/v1/usage
|
||||
func (h *UsageHandler) List(c *gin.Context) {
|
||||
userValue, exists := c.Get("user")
|
||||
if !exists {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
user, ok := userValue.(*model.User)
|
||||
if !ok {
|
||||
response.InternalError(c, "Invalid user context")
|
||||
return
|
||||
}
|
||||
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
|
||||
var apiKeyID int64
|
||||
if apiKeyIDStr := c.Query("api_key_id"); apiKeyIDStr != "" {
|
||||
id, err := strconv.ParseInt(apiKeyIDStr, 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid api_key_id")
|
||||
return
|
||||
}
|
||||
|
||||
// [Security Fix] Verify API Key ownership to prevent horizontal privilege escalation
|
||||
apiKey, err := h.apiKeyService.GetByID(c.Request.Context(), id)
|
||||
if err != nil {
|
||||
response.NotFound(c, "API key not found")
|
||||
return
|
||||
}
|
||||
if apiKey.UserID != user.ID {
|
||||
response.Forbidden(c, "Not authorized to access this API key's usage records")
|
||||
return
|
||||
}
|
||||
|
||||
apiKeyID = id
|
||||
}
|
||||
|
||||
params := repository.PaginationParams{Page: page, PageSize: pageSize}
|
||||
var records []model.UsageLog
|
||||
var result *repository.PaginationResult
|
||||
var err error
|
||||
|
||||
if apiKeyID > 0 {
|
||||
records, result, err = h.usageService.ListByApiKey(c.Request.Context(), apiKeyID, params)
|
||||
} else {
|
||||
records, result, err = h.usageService.ListByUser(c.Request.Context(), user.ID, params)
|
||||
}
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to list usage records: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Paginated(c, records, result.Total, page, pageSize)
|
||||
}
|
||||
|
||||
// GetByID handles getting a single usage record
|
||||
// GET /api/v1/usage/:id
|
||||
func (h *UsageHandler) GetByID(c *gin.Context) {
|
||||
userValue, exists := c.Get("user")
|
||||
if !exists {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
user, ok := userValue.(*model.User)
|
||||
if !ok {
|
||||
response.InternalError(c, "Invalid user context")
|
||||
return
|
||||
}
|
||||
|
||||
usageID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid usage ID")
|
||||
return
|
||||
}
|
||||
|
||||
record, err := h.usageService.GetByID(c.Request.Context(), usageID)
|
||||
if err != nil {
|
||||
response.NotFound(c, "Usage record not found")
|
||||
return
|
||||
}
|
||||
|
||||
// 验证所有权
|
||||
if record.UserID != user.ID {
|
||||
response.Forbidden(c, "Not authorized to access this record")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, record)
|
||||
}
|
||||
|
||||
// Stats handles getting usage statistics
|
||||
// GET /api/v1/usage/stats
|
||||
func (h *UsageHandler) Stats(c *gin.Context) {
|
||||
userValue, exists := c.Get("user")
|
||||
if !exists {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
user, ok := userValue.(*model.User)
|
||||
if !ok {
|
||||
response.InternalError(c, "Invalid user context")
|
||||
return
|
||||
}
|
||||
|
||||
var apiKeyID int64
|
||||
if apiKeyIDStr := c.Query("api_key_id"); apiKeyIDStr != "" {
|
||||
id, err := strconv.ParseInt(apiKeyIDStr, 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid api_key_id")
|
||||
return
|
||||
}
|
||||
|
||||
// [Security Fix] Verify API Key ownership to prevent horizontal privilege escalation
|
||||
apiKey, err := h.apiKeyService.GetByID(c.Request.Context(), id)
|
||||
if err != nil {
|
||||
response.NotFound(c, "API key not found")
|
||||
return
|
||||
}
|
||||
if apiKey.UserID != user.ID {
|
||||
response.Forbidden(c, "Not authorized to access this API key's statistics")
|
||||
return
|
||||
}
|
||||
|
||||
apiKeyID = id
|
||||
}
|
||||
|
||||
// 获取时间范围参数
|
||||
now := timezone.Now()
|
||||
var startTime, endTime time.Time
|
||||
|
||||
// 优先使用 start_date 和 end_date 参数
|
||||
startDateStr := c.Query("start_date")
|
||||
endDateStr := c.Query("end_date")
|
||||
|
||||
if startDateStr != "" && endDateStr != "" {
|
||||
// 使用自定义日期范围
|
||||
var err error
|
||||
startTime, err = timezone.ParseInLocation("2006-01-02", startDateStr)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid start_date format, use YYYY-MM-DD")
|
||||
return
|
||||
}
|
||||
endTime, err = timezone.ParseInLocation("2006-01-02", endDateStr)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
|
||||
return
|
||||
}
|
||||
// 设置结束时间为当天结束
|
||||
endTime = endTime.Add(24*time.Hour - time.Nanosecond)
|
||||
} else {
|
||||
// 使用 period 参数
|
||||
period := c.DefaultQuery("period", "today")
|
||||
switch period {
|
||||
case "today":
|
||||
startTime = timezone.StartOfDay(now)
|
||||
case "week":
|
||||
startTime = now.AddDate(0, 0, -7)
|
||||
case "month":
|
||||
startTime = now.AddDate(0, -1, 0)
|
||||
default:
|
||||
startTime = timezone.StartOfDay(now)
|
||||
}
|
||||
endTime = now
|
||||
}
|
||||
|
||||
var stats *service.UsageStats
|
||||
var err error
|
||||
if apiKeyID > 0 {
|
||||
stats, err = h.usageService.GetStatsByApiKey(c.Request.Context(), apiKeyID, startTime, endTime)
|
||||
} else {
|
||||
stats, err = h.usageService.GetStatsByUser(c.Request.Context(), user.ID, startTime, endTime)
|
||||
}
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to get usage statistics: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, stats)
|
||||
}
|
||||
|
||||
// parseUserTimeRange parses start_date, end_date query parameters for user dashboard
|
||||
func parseUserTimeRange(c *gin.Context) (time.Time, time.Time) {
|
||||
now := timezone.Now()
|
||||
startDate := c.Query("start_date")
|
||||
endDate := c.Query("end_date")
|
||||
|
||||
var startTime, endTime time.Time
|
||||
|
||||
if startDate != "" {
|
||||
if t, err := timezone.ParseInLocation("2006-01-02", startDate); err == nil {
|
||||
startTime = t
|
||||
} else {
|
||||
startTime = timezone.StartOfDay(now.AddDate(0, 0, -7))
|
||||
}
|
||||
} else {
|
||||
startTime = timezone.StartOfDay(now.AddDate(0, 0, -7))
|
||||
}
|
||||
|
||||
if endDate != "" {
|
||||
if t, err := timezone.ParseInLocation("2006-01-02", endDate); err == nil {
|
||||
endTime = t.Add(24 * time.Hour) // Include the end date
|
||||
} else {
|
||||
endTime = timezone.StartOfDay(now.AddDate(0, 0, 1))
|
||||
}
|
||||
} else {
|
||||
endTime = timezone.StartOfDay(now.AddDate(0, 0, 1))
|
||||
}
|
||||
|
||||
return startTime, endTime
|
||||
}
|
||||
|
||||
// DashboardStats handles getting user dashboard statistics
|
||||
// GET /api/v1/usage/dashboard/stats
|
||||
func (h *UsageHandler) DashboardStats(c *gin.Context) {
|
||||
userValue, exists := c.Get("user")
|
||||
if !exists {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
user, ok := userValue.(*model.User)
|
||||
if !ok {
|
||||
response.InternalError(c, "Invalid user context")
|
||||
return
|
||||
}
|
||||
|
||||
stats, err := h.usageRepo.GetUserDashboardStats(c.Request.Context(), user.ID)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to get dashboard statistics")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, stats)
|
||||
}
|
||||
|
||||
// DashboardTrend handles getting user usage trend data
|
||||
// GET /api/v1/usage/dashboard/trend
|
||||
func (h *UsageHandler) DashboardTrend(c *gin.Context) {
|
||||
userValue, exists := c.Get("user")
|
||||
if !exists {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
user, ok := userValue.(*model.User)
|
||||
if !ok {
|
||||
response.InternalError(c, "Invalid user context")
|
||||
return
|
||||
}
|
||||
|
||||
startTime, endTime := parseUserTimeRange(c)
|
||||
granularity := c.DefaultQuery("granularity", "day")
|
||||
|
||||
trend, err := h.usageRepo.GetUserUsageTrendByUserID(c.Request.Context(), user.ID, startTime, endTime, granularity)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to get usage trend")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"trend": trend,
|
||||
"start_date": startTime.Format("2006-01-02"),
|
||||
"end_date": endTime.Add(-24 * time.Hour).Format("2006-01-02"),
|
||||
"granularity": granularity,
|
||||
})
|
||||
}
|
||||
|
||||
// DashboardModels handles getting user model usage statistics
|
||||
// GET /api/v1/usage/dashboard/models
|
||||
func (h *UsageHandler) DashboardModels(c *gin.Context) {
|
||||
userValue, exists := c.Get("user")
|
||||
if !exists {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
user, ok := userValue.(*model.User)
|
||||
if !ok {
|
||||
response.InternalError(c, "Invalid user context")
|
||||
return
|
||||
}
|
||||
|
||||
startTime, endTime := parseUserTimeRange(c)
|
||||
|
||||
stats, err := h.usageRepo.GetUserModelStats(c.Request.Context(), user.ID, startTime, endTime)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to get model statistics")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"models": stats,
|
||||
"start_date": startTime.Format("2006-01-02"),
|
||||
"end_date": endTime.Add(-24 * time.Hour).Format("2006-01-02"),
|
||||
})
|
||||
}
|
||||
|
||||
// BatchApiKeysUsageRequest represents the request for batch API keys usage
|
||||
type BatchApiKeysUsageRequest struct {
|
||||
ApiKeyIDs []int64 `json:"api_key_ids" binding:"required"`
|
||||
}
|
||||
|
||||
// DashboardApiKeysUsage handles getting usage stats for user's own API keys
|
||||
// POST /api/v1/usage/dashboard/api-keys-usage
|
||||
func (h *UsageHandler) DashboardApiKeysUsage(c *gin.Context) {
|
||||
userValue, exists := c.Get("user")
|
||||
if !exists {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
user, ok := userValue.(*model.User)
|
||||
if !ok {
|
||||
response.InternalError(c, "Invalid user context")
|
||||
return
|
||||
}
|
||||
|
||||
var req BatchApiKeysUsageRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.ApiKeyIDs) == 0 {
|
||||
response.Success(c, gin.H{"stats": map[string]interface{}{}})
|
||||
return
|
||||
}
|
||||
|
||||
// Verify ownership of all requested API keys
|
||||
userApiKeys, _, err := h.apiKeyService.List(c.Request.Context(), user.ID, repository.PaginationParams{Page: 1, PageSize: 1000})
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to verify API key ownership")
|
||||
return
|
||||
}
|
||||
|
||||
userApiKeyIDs := make(map[int64]bool)
|
||||
for _, key := range userApiKeys {
|
||||
userApiKeyIDs[key.ID] = true
|
||||
}
|
||||
|
||||
// Filter to only include user's own API keys
|
||||
validApiKeyIDs := make([]int64, 0)
|
||||
for _, id := range req.ApiKeyIDs {
|
||||
if userApiKeyIDs[id] {
|
||||
validApiKeyIDs = append(validApiKeyIDs, id)
|
||||
}
|
||||
}
|
||||
|
||||
if len(validApiKeyIDs) == 0 {
|
||||
response.Success(c, gin.H{"stats": map[string]interface{}{}})
|
||||
return
|
||||
}
|
||||
|
||||
stats, err := h.usageRepo.GetBatchApiKeyUsageStats(c.Request.Context(), validApiKeyIDs)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to get API key usage stats")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"stats": stats})
|
||||
}
|
||||
85
backend/internal/handler/user_handler.go
Normal file
85
backend/internal/handler/user_handler.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/response"
|
||||
"sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// UserHandler handles user-related requests
|
||||
type UserHandler struct {
|
||||
userService *service.UserService
|
||||
}
|
||||
|
||||
// NewUserHandler creates a new UserHandler
|
||||
func NewUserHandler(userService *service.UserService) *UserHandler {
|
||||
return &UserHandler{
|
||||
userService: userService,
|
||||
}
|
||||
}
|
||||
|
||||
// ChangePasswordRequest represents the change password request payload
|
||||
type ChangePasswordRequest struct {
|
||||
OldPassword string `json:"old_password" binding:"required"`
|
||||
NewPassword string `json:"new_password" binding:"required,min=6"`
|
||||
}
|
||||
|
||||
// GetProfile handles getting user profile
|
||||
// GET /api/v1/users/me
|
||||
func (h *UserHandler) GetProfile(c *gin.Context) {
|
||||
userValue, exists := c.Get("user")
|
||||
if !exists {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
user, ok := userValue.(*model.User)
|
||||
if !ok {
|
||||
response.InternalError(c, "Invalid user context")
|
||||
return
|
||||
}
|
||||
|
||||
userData, err := h.userService.GetByID(c.Request.Context(), user.ID)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to get user profile: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, userData)
|
||||
}
|
||||
|
||||
// ChangePassword handles changing user password
|
||||
// POST /api/v1/users/me/password
|
||||
func (h *UserHandler) ChangePassword(c *gin.Context) {
|
||||
userValue, exists := c.Get("user")
|
||||
if !exists {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
user, ok := userValue.(*model.User)
|
||||
if !ok {
|
||||
response.InternalError(c, "Invalid user context")
|
||||
return
|
||||
}
|
||||
|
||||
var req ChangePasswordRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
svcReq := service.ChangePasswordRequest{
|
||||
CurrentPassword: req.OldPassword,
|
||||
NewPassword: req.NewPassword,
|
||||
}
|
||||
err := h.userService.ChangePassword(c.Request.Context(), user.ID, svcReq)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Failed to change password: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"message": "Password changed successfully"})
|
||||
}
|
||||
28
backend/internal/middleware/admin_only.go
Normal file
28
backend/internal/middleware/admin_only.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"sub2api/internal/model"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// AdminOnly 管理员权限中间件
|
||||
// 必须在JWTAuth中间件之后使用
|
||||
func AdminOnly() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 从上下文获取用户
|
||||
user, exists := GetUserFromContext(c)
|
||||
if !exists {
|
||||
AbortWithError(c, 401, "UNAUTHORIZED", "User not found in context")
|
||||
return
|
||||
}
|
||||
|
||||
// 检查是否为管理员
|
||||
if user.Role != model.RoleAdmin {
|
||||
AbortWithError(c, 403, "FORBIDDEN", "Admin access required")
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
161
backend/internal/middleware/api_key_auth.go
Normal file
161
backend/internal/middleware/api_key_auth.go
Normal file
@@ -0,0 +1,161 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log"
|
||||
"strings"
|
||||
"sub2api/internal/model"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// ApiKeyAuthService 定义API Key认证服务需要的接口
|
||||
type ApiKeyAuthService interface {
|
||||
GetByKey(ctx context.Context, key string) (*model.ApiKey, error)
|
||||
}
|
||||
|
||||
// SubscriptionAuthService 定义订阅认证服务需要的接口
|
||||
type SubscriptionAuthService interface {
|
||||
GetActiveSubscription(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error)
|
||||
ValidateSubscription(ctx context.Context, sub *model.UserSubscription) error
|
||||
CheckAndActivateWindow(ctx context.Context, sub *model.UserSubscription) error
|
||||
CheckAndResetWindows(ctx context.Context, sub *model.UserSubscription) error
|
||||
CheckUsageLimits(ctx context.Context, sub *model.UserSubscription, group *model.Group, additionalCost float64) error
|
||||
}
|
||||
|
||||
// ApiKeyAuth API Key认证中间件
|
||||
func ApiKeyAuth(apiKeyRepo ApiKeyAuthService) gin.HandlerFunc {
|
||||
return ApiKeyAuthWithSubscription(apiKeyRepo, nil)
|
||||
}
|
||||
|
||||
// ApiKeyAuthWithSubscription API Key认证中间件(支持订阅验证)
|
||||
func ApiKeyAuthWithSubscription(apiKeyRepo ApiKeyAuthService, subscriptionService SubscriptionAuthService) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 尝试从Authorization header中提取API key (Bearer scheme)
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
var apiKeyString string
|
||||
|
||||
if authHeader != "" {
|
||||
// 验证Bearer scheme
|
||||
parts := strings.SplitN(authHeader, " ", 2)
|
||||
if len(parts) == 2 && parts[0] == "Bearer" {
|
||||
apiKeyString = parts[1]
|
||||
}
|
||||
}
|
||||
|
||||
// 如果Authorization header中没有,尝试从x-api-key header中提取
|
||||
if apiKeyString == "" {
|
||||
apiKeyString = c.GetHeader("x-api-key")
|
||||
}
|
||||
|
||||
// 如果两个header都没有API key
|
||||
if apiKeyString == "" {
|
||||
AbortWithError(c, 401, "API_KEY_REQUIRED", "API key is required in Authorization header (Bearer scheme) or x-api-key header")
|
||||
return
|
||||
}
|
||||
|
||||
// 从数据库验证API key
|
||||
apiKey, err := apiKeyRepo.GetByKey(c.Request.Context(), apiKeyString)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
AbortWithError(c, 401, "INVALID_API_KEY", "Invalid API key")
|
||||
return
|
||||
}
|
||||
AbortWithError(c, 500, "INTERNAL_ERROR", "Failed to validate API key")
|
||||
return
|
||||
}
|
||||
|
||||
// 检查API key是否激活
|
||||
if !apiKey.IsActive() {
|
||||
AbortWithError(c, 401, "API_KEY_DISABLED", "API key is disabled")
|
||||
return
|
||||
}
|
||||
|
||||
// 检查关联的用户
|
||||
if apiKey.User == nil {
|
||||
AbortWithError(c, 401, "USER_NOT_FOUND", "User associated with API key not found")
|
||||
return
|
||||
}
|
||||
|
||||
// 检查用户状态
|
||||
if !apiKey.User.IsActive() {
|
||||
AbortWithError(c, 401, "USER_INACTIVE", "User account is not active")
|
||||
return
|
||||
}
|
||||
|
||||
// 判断计费方式:订阅模式 vs 余额模式
|
||||
isSubscriptionType := apiKey.Group != nil && apiKey.Group.IsSubscriptionType()
|
||||
|
||||
if isSubscriptionType && subscriptionService != nil {
|
||||
// 订阅模式:验证订阅
|
||||
subscription, err := subscriptionService.GetActiveSubscription(
|
||||
c.Request.Context(),
|
||||
apiKey.User.ID,
|
||||
apiKey.Group.ID,
|
||||
)
|
||||
if err != nil {
|
||||
AbortWithError(c, 403, "SUBSCRIPTION_NOT_FOUND", "No active subscription found for this group")
|
||||
return
|
||||
}
|
||||
|
||||
// 验证订阅状态(是否过期、暂停等)
|
||||
if err := subscriptionService.ValidateSubscription(c.Request.Context(), subscription); err != nil {
|
||||
AbortWithError(c, 403, "SUBSCRIPTION_INVALID", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 激活滑动窗口(首次使用时)
|
||||
if err := subscriptionService.CheckAndActivateWindow(c.Request.Context(), subscription); err != nil {
|
||||
log.Printf("Failed to activate subscription windows: %v", err)
|
||||
}
|
||||
|
||||
// 检查并重置过期窗口
|
||||
if err := subscriptionService.CheckAndResetWindows(c.Request.Context(), subscription); err != nil {
|
||||
log.Printf("Failed to reset subscription windows: %v", err)
|
||||
}
|
||||
|
||||
// 预检查用量限制(使用0作为额外费用进行预检查)
|
||||
if err := subscriptionService.CheckUsageLimits(c.Request.Context(), subscription, apiKey.Group, 0); err != nil {
|
||||
AbortWithError(c, 429, "USAGE_LIMIT_EXCEEDED", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 将订阅信息存入上下文
|
||||
c.Set(string(ContextKeySubscription), subscription)
|
||||
} else {
|
||||
// 余额模式:检查用户余额
|
||||
if apiKey.User.Balance <= 0 {
|
||||
AbortWithError(c, 403, "INSUFFICIENT_BALANCE", "Insufficient account balance")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 将API key和用户信息存入上下文
|
||||
c.Set(string(ContextKeyApiKey), apiKey)
|
||||
c.Set(string(ContextKeyUser), apiKey.User)
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// GetApiKeyFromContext 从上下文中获取API key
|
||||
func GetApiKeyFromContext(c *gin.Context) (*model.ApiKey, bool) {
|
||||
value, exists := c.Get(string(ContextKeyApiKey))
|
||||
if !exists {
|
||||
return nil, false
|
||||
}
|
||||
apiKey, ok := value.(*model.ApiKey)
|
||||
return apiKey, ok
|
||||
}
|
||||
|
||||
// GetSubscriptionFromContext 从上下文中获取订阅信息
|
||||
func GetSubscriptionFromContext(c *gin.Context) (*model.UserSubscription, bool) {
|
||||
value, exists := c.Get(string(ContextKeySubscription))
|
||||
if !exists {
|
||||
return nil, false
|
||||
}
|
||||
subscription, ok := value.(*model.UserSubscription)
|
||||
return subscription, ok
|
||||
}
|
||||
24
backend/internal/middleware/cors.go
Normal file
24
backend/internal/middleware/cors.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// CORS 跨域中间件
|
||||
func CORS() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 设置允许跨域的响应头
|
||||
c.Writer.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
c.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With, X-API-Key")
|
||||
c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE, PATCH")
|
||||
|
||||
// 处理预检请求
|
||||
if c.Request.Method == "OPTIONS" {
|
||||
c.AbortWithStatus(204)
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
76
backend/internal/middleware/jwt_auth.go
Normal file
76
backend/internal/middleware/jwt_auth.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// JWTAuth JWT认证中间件
|
||||
func JWTAuth(authService *service.AuthService, userRepo interface {
|
||||
GetByID(ctx context.Context, id int64) (*model.User, error)
|
||||
}) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 从Authorization header中提取token
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader == "" {
|
||||
AbortWithError(c, 401, "UNAUTHORIZED", "Authorization header is required")
|
||||
return
|
||||
}
|
||||
|
||||
// 验证Bearer scheme
|
||||
parts := strings.SplitN(authHeader, " ", 2)
|
||||
if len(parts) != 2 || parts[0] != "Bearer" {
|
||||
AbortWithError(c, 401, "INVALID_AUTH_HEADER", "Authorization header format must be 'Bearer {token}'")
|
||||
return
|
||||
}
|
||||
|
||||
tokenString := parts[1]
|
||||
if tokenString == "" {
|
||||
AbortWithError(c, 401, "EMPTY_TOKEN", "Token cannot be empty")
|
||||
return
|
||||
}
|
||||
|
||||
// 验证token
|
||||
claims, err := authService.ValidateToken(tokenString)
|
||||
if err != nil {
|
||||
if err == service.ErrTokenExpired {
|
||||
AbortWithError(c, 401, "TOKEN_EXPIRED", "Token has expired")
|
||||
return
|
||||
}
|
||||
AbortWithError(c, 401, "INVALID_TOKEN", "Invalid token")
|
||||
return
|
||||
}
|
||||
|
||||
// 从数据库获取最新的用户信息
|
||||
user, err := userRepo.GetByID(c.Request.Context(), claims.UserID)
|
||||
if err != nil {
|
||||
AbortWithError(c, 401, "USER_NOT_FOUND", "User not found")
|
||||
return
|
||||
}
|
||||
|
||||
// 检查用户状态
|
||||
if !user.IsActive() {
|
||||
AbortWithError(c, 401, "USER_INACTIVE", "User account is not active")
|
||||
return
|
||||
}
|
||||
|
||||
// 将用户信息存入上下文
|
||||
c.Set(string(ContextKeyUser), user)
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// GetUserFromContext 从上下文中获取用户
|
||||
func GetUserFromContext(c *gin.Context) (*model.User, bool) {
|
||||
value, exists := c.Get(string(ContextKeyUser))
|
||||
if !exists {
|
||||
return nil, false
|
||||
}
|
||||
user, ok := value.(*model.User)
|
||||
return user, ok
|
||||
}
|
||||
52
backend/internal/middleware/logger.go
Normal file
52
backend/internal/middleware/logger.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// Logger 请求日志中间件
|
||||
func Logger() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 开始时间
|
||||
startTime := time.Now()
|
||||
|
||||
// 处理请求
|
||||
c.Next()
|
||||
|
||||
// 结束时间
|
||||
endTime := time.Now()
|
||||
|
||||
// 执行时间
|
||||
latency := endTime.Sub(startTime)
|
||||
|
||||
// 请求方法
|
||||
method := c.Request.Method
|
||||
|
||||
// 请求路径
|
||||
path := c.Request.URL.Path
|
||||
|
||||
// 状态码
|
||||
statusCode := c.Writer.Status()
|
||||
|
||||
// 客户端IP
|
||||
clientIP := c.ClientIP()
|
||||
|
||||
// 日志格式: [时间] 状态码 | 延迟 | IP | 方法 路径
|
||||
log.Printf("[GIN] %v | %3d | %13v | %15s | %-7s %s",
|
||||
endTime.Format("2006/01/02 - 15:04:05"),
|
||||
statusCode,
|
||||
latency,
|
||||
clientIP,
|
||||
method,
|
||||
path,
|
||||
)
|
||||
|
||||
// 如果有错误,额外记录错误信息
|
||||
if len(c.Errors) > 0 {
|
||||
log.Printf("[GIN] Errors: %v", c.Errors.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
35
backend/internal/middleware/middleware.go
Normal file
35
backend/internal/middleware/middleware.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package middleware
|
||||
|
||||
import "github.com/gin-gonic/gin"
|
||||
|
||||
// ContextKey 定义上下文键类型
|
||||
type ContextKey string
|
||||
|
||||
const (
|
||||
// ContextKeyUser 用户上下文键
|
||||
ContextKeyUser ContextKey = "user"
|
||||
// ContextKeyApiKey API密钥上下文键
|
||||
ContextKeyApiKey ContextKey = "api_key"
|
||||
// ContextKeySubscription 订阅上下文键
|
||||
ContextKeySubscription ContextKey = "subscription"
|
||||
)
|
||||
|
||||
// ErrorResponse 标准错误响应结构
|
||||
type ErrorResponse struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// NewErrorResponse 创建错误响应
|
||||
func NewErrorResponse(code, message string) ErrorResponse {
|
||||
return ErrorResponse{
|
||||
Code: code,
|
||||
Message: message,
|
||||
}
|
||||
}
|
||||
|
||||
// AbortWithError 中断请求并返回JSON错误
|
||||
func AbortWithError(c *gin.Context, statusCode int, code, message string) {
|
||||
c.JSON(statusCode, NewErrorResponse(code, message))
|
||||
c.Abort()
|
||||
}
|
||||
265
backend/internal/model/account.go
Normal file
265
backend/internal/model/account.go
Normal file
@@ -0,0 +1,265 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// JSONB 用于存储JSONB数据
|
||||
type JSONB map[string]interface{}
|
||||
|
||||
func (j JSONB) Value() (driver.Value, error) {
|
||||
if j == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return json.Marshal(j)
|
||||
}
|
||||
|
||||
func (j *JSONB) Scan(value interface{}) error {
|
||||
if value == nil {
|
||||
*j = nil
|
||||
return nil
|
||||
}
|
||||
bytes, ok := value.([]byte)
|
||||
if !ok {
|
||||
return errors.New("type assertion to []byte failed")
|
||||
}
|
||||
return json.Unmarshal(bytes, j)
|
||||
}
|
||||
|
||||
type Account struct {
|
||||
ID int64 `gorm:"primaryKey" json:"id"`
|
||||
Name string `gorm:"size:100;not null" json:"name"`
|
||||
Platform string `gorm:"size:50;not null" json:"platform"` // anthropic/openai/gemini
|
||||
Type string `gorm:"size:20;not null" json:"type"` // oauth/apikey
|
||||
Credentials JSONB `gorm:"type:jsonb;default:'{}'" json:"credentials"` // 凭证(加密存储)
|
||||
Extra JSONB `gorm:"type:jsonb;default:'{}'" json:"extra"` // 扩展信息
|
||||
ProxyID *int64 `gorm:"index" json:"proxy_id"`
|
||||
Concurrency int `gorm:"default:3;not null" json:"concurrency"`
|
||||
Priority int `gorm:"default:50;not null" json:"priority"` // 1-100,越小越高
|
||||
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled/error
|
||||
ErrorMessage string `gorm:"type:text" json:"error_message"`
|
||||
LastUsedAt *time.Time `gorm:"index" json:"last_used_at"`
|
||||
CreatedAt time.Time `gorm:"not null" json:"created_at"`
|
||||
UpdatedAt time.Time `gorm:"not null" json:"updated_at"`
|
||||
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
|
||||
|
||||
// 调度控制
|
||||
Schedulable bool `gorm:"default:true;not null" json:"schedulable"`
|
||||
|
||||
// 限流状态 (429)
|
||||
RateLimitedAt *time.Time `gorm:"index" json:"rate_limited_at"`
|
||||
RateLimitResetAt *time.Time `gorm:"index" json:"rate_limit_reset_at"`
|
||||
|
||||
// 过载状态 (529)
|
||||
OverloadUntil *time.Time `gorm:"index" json:"overload_until"`
|
||||
|
||||
// 5小时时间窗口
|
||||
SessionWindowStart *time.Time `json:"session_window_start"`
|
||||
SessionWindowEnd *time.Time `json:"session_window_end"`
|
||||
SessionWindowStatus string `gorm:"size:20" json:"session_window_status"` // allowed/allowed_warning/rejected
|
||||
|
||||
// 关联
|
||||
Proxy *Proxy `gorm:"foreignKey:ProxyID" json:"proxy,omitempty"`
|
||||
AccountGroups []AccountGroup `gorm:"foreignKey:AccountID" json:"account_groups,omitempty"`
|
||||
|
||||
// 虚拟字段 (不存储到数据库)
|
||||
GroupIDs []int64 `gorm:"-" json:"group_ids,omitempty"`
|
||||
}
|
||||
|
||||
func (Account) TableName() string {
|
||||
return "accounts"
|
||||
}
|
||||
|
||||
// IsActive 检查是否激活
|
||||
func (a *Account) IsActive() bool {
|
||||
return a.Status == "active"
|
||||
}
|
||||
|
||||
// IsSchedulable 检查账号是否可调度
|
||||
func (a *Account) IsSchedulable() bool {
|
||||
if !a.IsActive() || !a.Schedulable {
|
||||
return false
|
||||
}
|
||||
now := time.Now()
|
||||
if a.OverloadUntil != nil && now.Before(*a.OverloadUntil) {
|
||||
return false
|
||||
}
|
||||
if a.RateLimitResetAt != nil && now.Before(*a.RateLimitResetAt) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// IsRateLimited 检查是否处于限流状态
|
||||
func (a *Account) IsRateLimited() bool {
|
||||
if a.RateLimitResetAt == nil {
|
||||
return false
|
||||
}
|
||||
return time.Now().Before(*a.RateLimitResetAt)
|
||||
}
|
||||
|
||||
// IsOverloaded 检查是否处于过载状态
|
||||
func (a *Account) IsOverloaded() bool {
|
||||
if a.OverloadUntil == nil {
|
||||
return false
|
||||
}
|
||||
return time.Now().Before(*a.OverloadUntil)
|
||||
}
|
||||
|
||||
// IsOAuth 检查是否为OAuth类型账号(包括oauth和setup-token)
|
||||
func (a *Account) IsOAuth() bool {
|
||||
return a.Type == AccountTypeOAuth || a.Type == AccountTypeSetupToken
|
||||
}
|
||||
|
||||
// CanGetUsage 检查账号是否可以获取usage信息(只有oauth类型可以,setup-token没有profile权限)
|
||||
func (a *Account) CanGetUsage() bool {
|
||||
return a.Type == AccountTypeOAuth
|
||||
}
|
||||
|
||||
// GetCredential 获取凭证字段
|
||||
func (a *Account) GetCredential(key string) string {
|
||||
if a.Credentials == nil {
|
||||
return ""
|
||||
}
|
||||
if v, ok := a.Credentials[key]; ok {
|
||||
if s, ok := v.(string); ok {
|
||||
return s
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetModelMapping 获取模型映射配置
|
||||
// 返回格式: map[请求模型名]实际模型名
|
||||
func (a *Account) GetModelMapping() map[string]string {
|
||||
if a.Credentials == nil {
|
||||
return nil
|
||||
}
|
||||
raw, ok := a.Credentials["model_mapping"]
|
||||
if !ok || raw == nil {
|
||||
return nil
|
||||
}
|
||||
// 处理map[string]interface{}类型
|
||||
if m, ok := raw.(map[string]interface{}); ok {
|
||||
result := make(map[string]string)
|
||||
for k, v := range m {
|
||||
if s, ok := v.(string); ok {
|
||||
result[k] = s
|
||||
}
|
||||
}
|
||||
if len(result) > 0 {
|
||||
return result
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsModelSupported 检查请求的模型是否被该账号支持
|
||||
// 如果没有设置模型映射,则支持所有模型
|
||||
func (a *Account) IsModelSupported(requestedModel string) bool {
|
||||
mapping := a.GetModelMapping()
|
||||
if mapping == nil || len(mapping) == 0 {
|
||||
return true // 没有映射配置,支持所有模型
|
||||
}
|
||||
_, exists := mapping[requestedModel]
|
||||
return exists
|
||||
}
|
||||
|
||||
// GetMappedModel 获取映射后的实际模型名
|
||||
// 如果没有映射,返回原始模型名
|
||||
func (a *Account) GetMappedModel(requestedModel string) string {
|
||||
mapping := a.GetModelMapping()
|
||||
if mapping == nil || len(mapping) == 0 {
|
||||
return requestedModel
|
||||
}
|
||||
if mappedModel, exists := mapping[requestedModel]; exists {
|
||||
return mappedModel
|
||||
}
|
||||
return requestedModel
|
||||
}
|
||||
|
||||
// GetBaseURL 获取API基础URL(用于apikey类型账号)
|
||||
func (a *Account) GetBaseURL() string {
|
||||
if a.Type != AccountTypeApiKey {
|
||||
return ""
|
||||
}
|
||||
baseURL := a.GetCredential("base_url")
|
||||
if baseURL == "" {
|
||||
return "https://api.anthropic.com" // 默认URL
|
||||
}
|
||||
return baseURL
|
||||
}
|
||||
|
||||
// GetExtraString 从Extra字段获取字符串值
|
||||
func (a *Account) GetExtraString(key string) string {
|
||||
if a.Extra == nil {
|
||||
return ""
|
||||
}
|
||||
if v, ok := a.Extra[key]; ok {
|
||||
if s, ok := v.(string); ok {
|
||||
return s
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// IsCustomErrorCodesEnabled 检查是否启用自定义错误码功能(仅适用于 apikey 类型)
|
||||
func (a *Account) IsCustomErrorCodesEnabled() bool {
|
||||
if a.Type != AccountTypeApiKey || a.Credentials == nil {
|
||||
return false
|
||||
}
|
||||
if v, ok := a.Credentials["custom_error_codes_enabled"]; ok {
|
||||
if enabled, ok := v.(bool); ok {
|
||||
return enabled
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// GetCustomErrorCodes 获取自定义错误码列表
|
||||
func (a *Account) GetCustomErrorCodes() []int {
|
||||
if a.Credentials == nil {
|
||||
return nil
|
||||
}
|
||||
raw, ok := a.Credentials["custom_error_codes"]
|
||||
if !ok || raw == nil {
|
||||
return nil
|
||||
}
|
||||
// 处理 []interface{} 类型(JSON反序列化后的格式)
|
||||
if arr, ok := raw.([]interface{}); ok {
|
||||
result := make([]int, 0, len(arr))
|
||||
for _, v := range arr {
|
||||
// JSON 数字默认解析为 float64
|
||||
if f, ok := v.(float64); ok {
|
||||
result = append(result, int(f))
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ShouldHandleErrorCode 检查指定错误码是否应该被处理(停止调度/标记限流等)
|
||||
// 如果未启用自定义错误码或列表为空,返回 true(使用默认策略)
|
||||
// 如果启用且列表非空,只有在列表中的错误码才返回 true
|
||||
func (a *Account) ShouldHandleErrorCode(statusCode int) bool {
|
||||
if !a.IsCustomErrorCodesEnabled() {
|
||||
return true // 未启用,使用默认策略
|
||||
}
|
||||
codes := a.GetCustomErrorCodes()
|
||||
if len(codes) == 0 {
|
||||
return true // 启用但列表为空,fallback到默认策略
|
||||
}
|
||||
// 检查是否在自定义列表中
|
||||
for _, code := range codes {
|
||||
if code == statusCode {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
20
backend/internal/model/account_group.go
Normal file
20
backend/internal/model/account_group.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
type AccountGroup struct {
|
||||
AccountID int64 `gorm:"primaryKey" json:"account_id"`
|
||||
GroupID int64 `gorm:"primaryKey" json:"group_id"`
|
||||
Priority int `gorm:"default:50;not null" json:"priority"` // 分组内优先级
|
||||
CreatedAt time.Time `gorm:"not null" json:"created_at"`
|
||||
|
||||
// 关联
|
||||
Account *Account `gorm:"foreignKey:AccountID" json:"account,omitempty"`
|
||||
Group *Group `gorm:"foreignKey:GroupID" json:"group,omitempty"`
|
||||
}
|
||||
|
||||
func (AccountGroup) TableName() string {
|
||||
return "account_groups"
|
||||
}
|
||||
32
backend/internal/model/api_key.go
Normal file
32
backend/internal/model/api_key.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type ApiKey struct {
|
||||
ID int64 `gorm:"primaryKey" json:"id"`
|
||||
UserID int64 `gorm:"index;not null" json:"user_id"`
|
||||
Key string `gorm:"uniqueIndex;size:128;not null" json:"key"` // sk-xxx
|
||||
Name string `gorm:"size:100;not null" json:"name"`
|
||||
GroupID *int64 `gorm:"index" json:"group_id"`
|
||||
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled
|
||||
CreatedAt time.Time `gorm:"not null" json:"created_at"`
|
||||
UpdatedAt time.Time `gorm:"not null" json:"updated_at"`
|
||||
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
|
||||
|
||||
// 关联
|
||||
User *User `gorm:"foreignKey:UserID" json:"user,omitempty"`
|
||||
Group *Group `gorm:"foreignKey:GroupID" json:"group,omitempty"`
|
||||
}
|
||||
|
||||
func (ApiKey) TableName() string {
|
||||
return "api_keys"
|
||||
}
|
||||
|
||||
// IsActive 检查是否激活
|
||||
func (k *ApiKey) IsActive() bool {
|
||||
return k.Status == "active"
|
||||
}
|
||||
73
backend/internal/model/group.go
Normal file
73
backend/internal/model/group.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// 订阅类型常量
|
||||
const (
|
||||
SubscriptionTypeStandard = "standard" // 标准计费模式(按余额扣费)
|
||||
SubscriptionTypeSubscription = "subscription" // 订阅模式(按限额控制)
|
||||
)
|
||||
|
||||
type Group struct {
|
||||
ID int64 `gorm:"primaryKey" json:"id"`
|
||||
Name string `gorm:"uniqueIndex;size:100;not null" json:"name"`
|
||||
Description string `gorm:"type:text" json:"description"`
|
||||
Platform string `gorm:"size:50;default:anthropic;not null" json:"platform"` // anthropic/openai/gemini
|
||||
RateMultiplier float64 `gorm:"type:decimal(10,4);default:1.0;not null" json:"rate_multiplier"`
|
||||
IsExclusive bool `gorm:"default:false;not null" json:"is_exclusive"`
|
||||
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled
|
||||
|
||||
// 订阅功能字段
|
||||
SubscriptionType string `gorm:"size:20;default:standard;not null" json:"subscription_type"` // standard/subscription
|
||||
DailyLimitUSD *float64 `gorm:"type:decimal(20,8)" json:"daily_limit_usd"`
|
||||
WeeklyLimitUSD *float64 `gorm:"type:decimal(20,8)" json:"weekly_limit_usd"`
|
||||
MonthlyLimitUSD *float64 `gorm:"type:decimal(20,8)" json:"monthly_limit_usd"`
|
||||
|
||||
CreatedAt time.Time `gorm:"not null" json:"created_at"`
|
||||
UpdatedAt time.Time `gorm:"not null" json:"updated_at"`
|
||||
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
|
||||
|
||||
// 关联
|
||||
AccountGroups []AccountGroup `gorm:"foreignKey:GroupID" json:"account_groups,omitempty"`
|
||||
|
||||
// 虚拟字段 (不存储到数据库)
|
||||
AccountCount int64 `gorm:"-" json:"account_count,omitempty"`
|
||||
}
|
||||
|
||||
func (Group) TableName() string {
|
||||
return "groups"
|
||||
}
|
||||
|
||||
// IsActive 检查是否激活
|
||||
func (g *Group) IsActive() bool {
|
||||
return g.Status == "active"
|
||||
}
|
||||
|
||||
// IsSubscriptionType 检查是否为订阅类型分组
|
||||
func (g *Group) IsSubscriptionType() bool {
|
||||
return g.SubscriptionType == SubscriptionTypeSubscription
|
||||
}
|
||||
|
||||
// IsFreeSubscription 检查是否为免费订阅(不扣余额但有限额)
|
||||
func (g *Group) IsFreeSubscription() bool {
|
||||
return g.IsSubscriptionType() && g.RateMultiplier == 0
|
||||
}
|
||||
|
||||
// HasDailyLimit 检查是否有日限额
|
||||
func (g *Group) HasDailyLimit() bool {
|
||||
return g.DailyLimitUSD != nil && *g.DailyLimitUSD > 0
|
||||
}
|
||||
|
||||
// HasWeeklyLimit 检查是否有周限额
|
||||
func (g *Group) HasWeeklyLimit() bool {
|
||||
return g.WeeklyLimitUSD != nil && *g.WeeklyLimitUSD > 0
|
||||
}
|
||||
|
||||
// HasMonthlyLimit 检查是否有月限额
|
||||
func (g *Group) HasMonthlyLimit() bool {
|
||||
return g.MonthlyLimitUSD != nil && *g.MonthlyLimitUSD > 0
|
||||
}
|
||||
64
backend/internal/model/model.go
Normal file
64
backend/internal/model/model.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// AutoMigrate 自动迁移所有模型
|
||||
func AutoMigrate(db *gorm.DB) error {
|
||||
return db.AutoMigrate(
|
||||
&User{},
|
||||
&ApiKey{},
|
||||
&Group{},
|
||||
&Account{},
|
||||
&AccountGroup{},
|
||||
&Proxy{},
|
||||
&RedeemCode{},
|
||||
&UsageLog{},
|
||||
&Setting{},
|
||||
&UserSubscription{},
|
||||
)
|
||||
}
|
||||
|
||||
// 状态常量
|
||||
const (
|
||||
StatusActive = "active"
|
||||
StatusDisabled = "disabled"
|
||||
StatusError = "error"
|
||||
StatusUnused = "unused"
|
||||
StatusUsed = "used"
|
||||
StatusExpired = "expired"
|
||||
)
|
||||
|
||||
// 角色常量
|
||||
const (
|
||||
RoleAdmin = "admin"
|
||||
RoleUser = "user"
|
||||
)
|
||||
|
||||
// 平台常量
|
||||
const (
|
||||
PlatformAnthropic = "anthropic"
|
||||
PlatformOpenAI = "openai"
|
||||
PlatformGemini = "gemini"
|
||||
)
|
||||
|
||||
// 账号类型常量
|
||||
const (
|
||||
AccountTypeOAuth = "oauth" // OAuth类型账号(full scope: profile + inference)
|
||||
AccountTypeSetupToken = "setup-token" // Setup Token类型账号(inference only scope)
|
||||
AccountTypeApiKey = "apikey" // API Key类型账号
|
||||
)
|
||||
|
||||
// 卡密类型常量
|
||||
const (
|
||||
RedeemTypeBalance = "balance"
|
||||
RedeemTypeConcurrency = "concurrency"
|
||||
RedeemTypeSubscription = "subscription"
|
||||
)
|
||||
|
||||
// 管理员调整类型常量
|
||||
const (
|
||||
AdjustmentTypeAdminBalance = "admin_balance" // 管理员调整余额
|
||||
AdjustmentTypeAdminConcurrency = "admin_concurrency" // 管理员调整并发数
|
||||
)
|
||||
45
backend/internal/model/proxy.go
Normal file
45
backend/internal/model/proxy.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type Proxy struct {
|
||||
ID int64 `gorm:"primaryKey" json:"id"`
|
||||
Name string `gorm:"size:100;not null" json:"name"`
|
||||
Protocol string `gorm:"size:20;not null" json:"protocol"` // http/https/socks5
|
||||
Host string `gorm:"size:255;not null" json:"host"`
|
||||
Port int `gorm:"not null" json:"port"`
|
||||
Username string `gorm:"size:100" json:"username"`
|
||||
Password string `gorm:"size:100" json:"-"`
|
||||
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled
|
||||
CreatedAt time.Time `gorm:"not null" json:"created_at"`
|
||||
UpdatedAt time.Time `gorm:"not null" json:"updated_at"`
|
||||
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
|
||||
}
|
||||
|
||||
func (Proxy) TableName() string {
|
||||
return "proxies"
|
||||
}
|
||||
|
||||
// IsActive 检查是否激活
|
||||
func (p *Proxy) IsActive() bool {
|
||||
return p.Status == "active"
|
||||
}
|
||||
|
||||
// URL 返回代理URL
|
||||
func (p *Proxy) URL() string {
|
||||
if p.Username != "" && p.Password != "" {
|
||||
return fmt.Sprintf("%s://%s:%s@%s:%d", p.Protocol, p.Username, p.Password, p.Host, p.Port)
|
||||
}
|
||||
return fmt.Sprintf("%s://%s:%d", p.Protocol, p.Host, p.Port)
|
||||
}
|
||||
|
||||
// ProxyWithAccountCount extends Proxy with account count information
|
||||
type ProxyWithAccountCount struct {
|
||||
Proxy
|
||||
AccountCount int64 `json:"account_count"`
|
||||
}
|
||||
47
backend/internal/model/redeem_code.go
Normal file
47
backend/internal/model/redeem_code.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"time"
|
||||
)
|
||||
|
||||
type RedeemCode struct {
|
||||
ID int64 `gorm:"primaryKey" json:"id"`
|
||||
Code string `gorm:"uniqueIndex;size:32;not null" json:"code"`
|
||||
Type string `gorm:"size:20;default:balance;not null" json:"type"` // balance/concurrency/subscription
|
||||
Value float64 `gorm:"type:decimal(20,8);not null" json:"value"` // 面值(USD)或并发数或有效天数
|
||||
Status string `gorm:"size:20;default:unused;not null" json:"status"` // unused/used
|
||||
UsedBy *int64 `gorm:"index" json:"used_by"`
|
||||
UsedAt *time.Time `json:"used_at"`
|
||||
CreatedAt time.Time `gorm:"not null" json:"created_at"`
|
||||
|
||||
// 订阅类型专用字段
|
||||
GroupID *int64 `gorm:"index" json:"group_id"` // 订阅分组ID (仅subscription类型使用)
|
||||
ValidityDays int `gorm:"default:30" json:"validity_days"` // 订阅有效天数 (仅subscription类型使用)
|
||||
|
||||
// 关联
|
||||
User *User `gorm:"foreignKey:UsedBy" json:"user,omitempty"`
|
||||
Group *Group `gorm:"foreignKey:GroupID" json:"group,omitempty"`
|
||||
}
|
||||
|
||||
func (RedeemCode) TableName() string {
|
||||
return "redeem_codes"
|
||||
}
|
||||
|
||||
// IsUsed 检查是否已使用
|
||||
func (r *RedeemCode) IsUsed() bool {
|
||||
return r.Status == "used"
|
||||
}
|
||||
|
||||
// CanUse 检查是否可以使用
|
||||
func (r *RedeemCode) CanUse() bool {
|
||||
return r.Status == "unused"
|
||||
}
|
||||
|
||||
// GenerateRedeemCode 生成唯一的兑换码
|
||||
func GenerateRedeemCode() string {
|
||||
b := make([]byte, 16)
|
||||
rand.Read(b)
|
||||
return hex.EncodeToString(b)
|
||||
}
|
||||
95
backend/internal/model/setting.go
Normal file
95
backend/internal/model/setting.go
Normal file
@@ -0,0 +1,95 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// Setting 系统设置模型(Key-Value存储)
|
||||
type Setting struct {
|
||||
ID int64 `gorm:"primaryKey" json:"id"`
|
||||
Key string `gorm:"uniqueIndex;size:100;not null" json:"key"`
|
||||
Value string `gorm:"type:text;not null" json:"value"`
|
||||
UpdatedAt time.Time `gorm:"not null" json:"updated_at"`
|
||||
}
|
||||
|
||||
func (Setting) TableName() string {
|
||||
return "settings"
|
||||
}
|
||||
|
||||
// 设置Key常量
|
||||
const (
|
||||
// 注册设置
|
||||
SettingKeyRegistrationEnabled = "registration_enabled" // 是否开放注册
|
||||
SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证
|
||||
|
||||
// 邮件服务设置
|
||||
SettingKeySmtpHost = "smtp_host" // SMTP服务器地址
|
||||
SettingKeySmtpPort = "smtp_port" // SMTP端口
|
||||
SettingKeySmtpUsername = "smtp_username" // SMTP用户名
|
||||
SettingKeySmtpPassword = "smtp_password" // SMTP密码(加密存储)
|
||||
SettingKeySmtpFrom = "smtp_from" // 发件人地址
|
||||
SettingKeySmtpFromName = "smtp_from_name" // 发件人名称
|
||||
SettingKeySmtpUseTLS = "smtp_use_tls" // 是否使用TLS
|
||||
|
||||
// Cloudflare Turnstile 设置
|
||||
SettingKeyTurnstileEnabled = "turnstile_enabled" // 是否启用 Turnstile 验证
|
||||
SettingKeyTurnstileSiteKey = "turnstile_site_key" // Turnstile Site Key
|
||||
SettingKeyTurnstileSecretKey = "turnstile_secret_key" // Turnstile Secret Key
|
||||
|
||||
// OEM设置
|
||||
SettingKeySiteName = "site_name" // 网站名称
|
||||
SettingKeySiteLogo = "site_logo" // 网站Logo (base64)
|
||||
SettingKeySiteSubtitle = "site_subtitle" // 网站副标题
|
||||
SettingKeyApiBaseUrl = "api_base_url" // API端点地址(用于客户端配置和导入)
|
||||
SettingKeyContactInfo = "contact_info" // 客服联系方式
|
||||
|
||||
// 默认配置
|
||||
SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量
|
||||
SettingKeyDefaultBalance = "default_balance" // 新用户默认余额
|
||||
)
|
||||
|
||||
// SystemSettings 系统设置结构体(用于API响应)
|
||||
type SystemSettings struct {
|
||||
// 注册设置
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
|
||||
// 邮件服务设置
|
||||
SmtpHost string `json:"smtp_host"`
|
||||
SmtpPort int `json:"smtp_port"`
|
||||
SmtpUsername string `json:"smtp_username"`
|
||||
SmtpPassword string `json:"smtp_password,omitempty"` // 不返回明文密码
|
||||
SmtpFrom string `json:"smtp_from_email"`
|
||||
SmtpFromName string `json:"smtp_from_name"`
|
||||
SmtpUseTLS bool `json:"smtp_use_tls"`
|
||||
|
||||
// Cloudflare Turnstile 设置
|
||||
TurnstileEnabled bool `json:"turnstile_enabled"`
|
||||
TurnstileSiteKey string `json:"turnstile_site_key"`
|
||||
TurnstileSecretKey string `json:"turnstile_secret_key,omitempty"` // 不返回明文密钥
|
||||
|
||||
// OEM设置
|
||||
SiteName string `json:"site_name"`
|
||||
SiteLogo string `json:"site_logo"`
|
||||
SiteSubtitle string `json:"site_subtitle"`
|
||||
ApiBaseUrl string `json:"api_base_url"`
|
||||
ContactInfo string `json:"contact_info"`
|
||||
|
||||
// 默认配置
|
||||
DefaultConcurrency int `json:"default_concurrency"`
|
||||
DefaultBalance float64 `json:"default_balance"`
|
||||
}
|
||||
|
||||
// PublicSettings 公开设置(无需登录即可获取)
|
||||
type PublicSettings struct {
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
TurnstileEnabled bool `json:"turnstile_enabled"`
|
||||
TurnstileSiteKey string `json:"turnstile_site_key"`
|
||||
SiteName string `json:"site_name"`
|
||||
SiteLogo string `json:"site_logo"`
|
||||
SiteSubtitle string `json:"site_subtitle"`
|
||||
ApiBaseUrl string `json:"api_base_url"`
|
||||
ContactInfo string `json:"contact_info"`
|
||||
Version string `json:"version"`
|
||||
}
|
||||
67
backend/internal/model/usage_log.go
Normal file
67
backend/internal/model/usage_log.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// 消费类型常量
|
||||
const (
|
||||
BillingTypeBalance int8 = 0 // 钱包余额
|
||||
BillingTypeSubscription int8 = 1 // 订阅套餐
|
||||
)
|
||||
|
||||
type UsageLog struct {
|
||||
ID int64 `gorm:"primaryKey" json:"id"`
|
||||
UserID int64 `gorm:"index;not null" json:"user_id"`
|
||||
ApiKeyID int64 `gorm:"index;not null" json:"api_key_id"`
|
||||
AccountID int64 `gorm:"index;not null" json:"account_id"`
|
||||
RequestID string `gorm:"size:64" json:"request_id"`
|
||||
Model string `gorm:"size:100;index;not null" json:"model"`
|
||||
|
||||
// 订阅关联(可选)
|
||||
GroupID *int64 `gorm:"index" json:"group_id"`
|
||||
SubscriptionID *int64 `gorm:"index" json:"subscription_id"`
|
||||
|
||||
// Token使用量(4类)
|
||||
InputTokens int `gorm:"default:0;not null" json:"input_tokens"`
|
||||
OutputTokens int `gorm:"default:0;not null" json:"output_tokens"`
|
||||
CacheCreationTokens int `gorm:"default:0;not null" json:"cache_creation_tokens"`
|
||||
CacheReadTokens int `gorm:"default:0;not null" json:"cache_read_tokens"`
|
||||
|
||||
// 详细的缓存创建分类
|
||||
CacheCreation5mTokens int `gorm:"default:0;not null" json:"cache_creation_5m_tokens"`
|
||||
CacheCreation1hTokens int `gorm:"default:0;not null" json:"cache_creation_1h_tokens"`
|
||||
|
||||
// 费用(USD)
|
||||
InputCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"input_cost"`
|
||||
OutputCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"output_cost"`
|
||||
CacheCreationCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"cache_creation_cost"`
|
||||
CacheReadCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"cache_read_cost"`
|
||||
TotalCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"total_cost"` // 原始总费用
|
||||
ActualCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"actual_cost"` // 实际扣除费用
|
||||
RateMultiplier float64 `gorm:"type:decimal(10,4);default:1;not null" json:"rate_multiplier"` // 计费倍率
|
||||
|
||||
// 元数据
|
||||
BillingType int8 `gorm:"type:smallint;default:0;not null" json:"billing_type"` // 0=余额 1=订阅
|
||||
Stream bool `gorm:"default:false;not null" json:"stream"`
|
||||
DurationMs *int `json:"duration_ms"`
|
||||
FirstTokenMs *int `json:"first_token_ms"` // 首字时间(流式请求)
|
||||
|
||||
CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
|
||||
|
||||
// 关联
|
||||
User *User `gorm:"foreignKey:UserID" json:"user,omitempty"`
|
||||
ApiKey *ApiKey `gorm:"foreignKey:ApiKeyID" json:"api_key,omitempty"`
|
||||
Account *Account `gorm:"foreignKey:AccountID" json:"account,omitempty"`
|
||||
Group *Group `gorm:"foreignKey:GroupID" json:"group,omitempty"`
|
||||
Subscription *UserSubscription `gorm:"foreignKey:SubscriptionID" json:"subscription,omitempty"`
|
||||
}
|
||||
|
||||
func (UsageLog) TableName() string {
|
||||
return "usage_logs"
|
||||
}
|
||||
|
||||
// TotalTokens 总token数
|
||||
func (u *UsageLog) TotalTokens() int {
|
||||
return u.InputTokens + u.OutputTokens + u.CacheCreationTokens + u.CacheReadTokens
|
||||
}
|
||||
74
backend/internal/model/user.go
Normal file
74
backend/internal/model/user.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/lib/pq"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type User struct {
|
||||
ID int64 `gorm:"primaryKey" json:"id"`
|
||||
Email string `gorm:"uniqueIndex;size:255;not null" json:"email"`
|
||||
PasswordHash string `gorm:"size:255;not null" json:"-"`
|
||||
Role string `gorm:"size:20;default:user;not null" json:"role"` // admin/user
|
||||
Balance float64 `gorm:"type:decimal(20,8);default:0;not null" json:"balance"`
|
||||
Concurrency int `gorm:"default:5;not null" json:"concurrency"`
|
||||
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled
|
||||
AllowedGroups pq.Int64Array `gorm:"type:bigint[]" json:"allowed_groups"`
|
||||
CreatedAt time.Time `gorm:"not null" json:"created_at"`
|
||||
UpdatedAt time.Time `gorm:"not null" json:"updated_at"`
|
||||
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
|
||||
|
||||
// 关联
|
||||
ApiKeys []ApiKey `gorm:"foreignKey:UserID" json:"api_keys,omitempty"`
|
||||
}
|
||||
|
||||
func (User) TableName() string {
|
||||
return "users"
|
||||
}
|
||||
|
||||
// IsAdmin 检查是否管理员
|
||||
func (u *User) IsAdmin() bool {
|
||||
return u.Role == "admin"
|
||||
}
|
||||
|
||||
// IsActive 检查是否激活
|
||||
func (u *User) IsActive() bool {
|
||||
return u.Status == "active"
|
||||
}
|
||||
|
||||
// CanBindGroup 检查是否可以绑定指定分组
|
||||
// 对于标准类型分组:
|
||||
// - 如果 AllowedGroups 设置了值(非空数组),只能绑定列表中的分组
|
||||
// - 如果 AllowedGroups 为 nil 或空数组,可以绑定所有非专属分组
|
||||
func (u *User) CanBindGroup(groupID int64, isExclusive bool) bool {
|
||||
// 如果设置了 allowed_groups 且不为空,只能绑定指定的分组
|
||||
if len(u.AllowedGroups) > 0 {
|
||||
for _, id := range u.AllowedGroups {
|
||||
if id == groupID {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
// 如果没有设置 allowed_groups 或为空数组,可以绑定所有非专属分组
|
||||
return !isExclusive
|
||||
}
|
||||
|
||||
// SetPassword 设置密码(哈希存储)
|
||||
func (u *User) SetPassword(password string) error {
|
||||
hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
u.PasswordHash = string(hash)
|
||||
return nil
|
||||
}
|
||||
|
||||
// CheckPassword 验证密码
|
||||
func (u *User) CheckPassword(password string) bool {
|
||||
err := bcrypt.CompareHashAndPassword([]byte(u.PasswordHash), []byte(password))
|
||||
return err == nil
|
||||
}
|
||||
157
backend/internal/model/user_subscription.go
Normal file
157
backend/internal/model/user_subscription.go
Normal file
@@ -0,0 +1,157 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// 订阅状态常量
|
||||
const (
|
||||
SubscriptionStatusActive = "active"
|
||||
SubscriptionStatusExpired = "expired"
|
||||
SubscriptionStatusSuspended = "suspended"
|
||||
)
|
||||
|
||||
// UserSubscription 用户订阅模型
|
||||
type UserSubscription struct {
|
||||
ID int64 `gorm:"primaryKey" json:"id"`
|
||||
UserID int64 `gorm:"index;not null" json:"user_id"`
|
||||
GroupID int64 `gorm:"index;not null" json:"group_id"`
|
||||
|
||||
// 订阅有效期
|
||||
StartsAt time.Time `gorm:"not null" json:"starts_at"`
|
||||
ExpiresAt time.Time `gorm:"not null" json:"expires_at"`
|
||||
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/expired/suspended
|
||||
|
||||
// 滑动窗口起始时间(nil = 未激活)
|
||||
DailyWindowStart *time.Time `json:"daily_window_start"`
|
||||
WeeklyWindowStart *time.Time `json:"weekly_window_start"`
|
||||
MonthlyWindowStart *time.Time `json:"monthly_window_start"`
|
||||
|
||||
// 当前窗口已用额度(USD,基于 total_cost 计算)
|
||||
DailyUsageUSD float64 `gorm:"type:decimal(20,10);default:0;not null" json:"daily_usage_usd"`
|
||||
WeeklyUsageUSD float64 `gorm:"type:decimal(20,10);default:0;not null" json:"weekly_usage_usd"`
|
||||
MonthlyUsageUSD float64 `gorm:"type:decimal(20,10);default:0;not null" json:"monthly_usage_usd"`
|
||||
|
||||
// 管理员分配信息
|
||||
AssignedBy *int64 `gorm:"index" json:"assigned_by"`
|
||||
AssignedAt time.Time `gorm:"not null" json:"assigned_at"`
|
||||
Notes string `gorm:"type:text" json:"notes"`
|
||||
|
||||
CreatedAt time.Time `gorm:"not null" json:"created_at"`
|
||||
UpdatedAt time.Time `gorm:"not null" json:"updated_at"`
|
||||
|
||||
// 关联
|
||||
User *User `gorm:"foreignKey:UserID" json:"user,omitempty"`
|
||||
Group *Group `gorm:"foreignKey:GroupID" json:"group,omitempty"`
|
||||
AssignedByUser *User `gorm:"foreignKey:AssignedBy" json:"assigned_by_user,omitempty"`
|
||||
}
|
||||
|
||||
func (UserSubscription) TableName() string {
|
||||
return "user_subscriptions"
|
||||
}
|
||||
|
||||
// IsActive 检查订阅是否有效(状态为active且未过期)
|
||||
func (s *UserSubscription) IsActive() bool {
|
||||
return s.Status == SubscriptionStatusActive && time.Now().Before(s.ExpiresAt)
|
||||
}
|
||||
|
||||
// IsExpired 检查订阅是否已过期
|
||||
func (s *UserSubscription) IsExpired() bool {
|
||||
return time.Now().After(s.ExpiresAt)
|
||||
}
|
||||
|
||||
// DaysRemaining 返回订阅剩余天数
|
||||
func (s *UserSubscription) DaysRemaining() int {
|
||||
if s.IsExpired() {
|
||||
return 0
|
||||
}
|
||||
return int(time.Until(s.ExpiresAt).Hours() / 24)
|
||||
}
|
||||
|
||||
// IsWindowActivated 检查窗口是否已激活
|
||||
func (s *UserSubscription) IsWindowActivated() bool {
|
||||
return s.DailyWindowStart != nil || s.WeeklyWindowStart != nil || s.MonthlyWindowStart != nil
|
||||
}
|
||||
|
||||
// NeedsDailyReset 检查日窗口是否需要重置
|
||||
func (s *UserSubscription) NeedsDailyReset() bool {
|
||||
if s.DailyWindowStart == nil {
|
||||
return false
|
||||
}
|
||||
return time.Since(*s.DailyWindowStart) >= 24*time.Hour
|
||||
}
|
||||
|
||||
// NeedsWeeklyReset 检查周窗口是否需要重置
|
||||
func (s *UserSubscription) NeedsWeeklyReset() bool {
|
||||
if s.WeeklyWindowStart == nil {
|
||||
return false
|
||||
}
|
||||
return time.Since(*s.WeeklyWindowStart) >= 7*24*time.Hour
|
||||
}
|
||||
|
||||
// NeedsMonthlyReset 检查月窗口是否需要重置
|
||||
func (s *UserSubscription) NeedsMonthlyReset() bool {
|
||||
if s.MonthlyWindowStart == nil {
|
||||
return false
|
||||
}
|
||||
return time.Since(*s.MonthlyWindowStart) >= 30*24*time.Hour
|
||||
}
|
||||
|
||||
// DailyResetTime 返回日窗口重置时间
|
||||
func (s *UserSubscription) DailyResetTime() *time.Time {
|
||||
if s.DailyWindowStart == nil {
|
||||
return nil
|
||||
}
|
||||
t := s.DailyWindowStart.Add(24 * time.Hour)
|
||||
return &t
|
||||
}
|
||||
|
||||
// WeeklyResetTime 返回周窗口重置时间
|
||||
func (s *UserSubscription) WeeklyResetTime() *time.Time {
|
||||
if s.WeeklyWindowStart == nil {
|
||||
return nil
|
||||
}
|
||||
t := s.WeeklyWindowStart.Add(7 * 24 * time.Hour)
|
||||
return &t
|
||||
}
|
||||
|
||||
// MonthlyResetTime 返回月窗口重置时间
|
||||
func (s *UserSubscription) MonthlyResetTime() *time.Time {
|
||||
if s.MonthlyWindowStart == nil {
|
||||
return nil
|
||||
}
|
||||
t := s.MonthlyWindowStart.Add(30 * 24 * time.Hour)
|
||||
return &t
|
||||
}
|
||||
|
||||
// CheckDailyLimit 检查是否超出日限额
|
||||
func (s *UserSubscription) CheckDailyLimit(group *Group, additionalCost float64) bool {
|
||||
if !group.HasDailyLimit() {
|
||||
return true // 无限制
|
||||
}
|
||||
return s.DailyUsageUSD+additionalCost <= *group.DailyLimitUSD
|
||||
}
|
||||
|
||||
// CheckWeeklyLimit 检查是否超出周限额
|
||||
func (s *UserSubscription) CheckWeeklyLimit(group *Group, additionalCost float64) bool {
|
||||
if !group.HasWeeklyLimit() {
|
||||
return true // 无限制
|
||||
}
|
||||
return s.WeeklyUsageUSD+additionalCost <= *group.WeeklyLimitUSD
|
||||
}
|
||||
|
||||
// CheckMonthlyLimit 检查是否超出月限额
|
||||
func (s *UserSubscription) CheckMonthlyLimit(group *Group, additionalCost float64) bool {
|
||||
if !group.HasMonthlyLimit() {
|
||||
return true // 无限制
|
||||
}
|
||||
return s.MonthlyUsageUSD+additionalCost <= *group.MonthlyLimitUSD
|
||||
}
|
||||
|
||||
// CheckAllLimits 检查所有限额
|
||||
func (s *UserSubscription) CheckAllLimits(group *Group, additionalCost float64) (daily, weekly, monthly bool) {
|
||||
daily = s.CheckDailyLimit(group, additionalCost)
|
||||
weekly = s.CheckWeeklyLimit(group, additionalCost)
|
||||
monthly = s.CheckMonthlyLimit(group, additionalCost)
|
||||
return
|
||||
}
|
||||
223
backend/internal/pkg/oauth/oauth.go
Normal file
223
backend/internal/pkg/oauth/oauth.go
Normal file
@@ -0,0 +1,223 @@
|
||||
package oauth
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Claude OAuth Constants (from CRS project)
|
||||
const (
|
||||
// OAuth Client ID for Claude
|
||||
ClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
|
||||
|
||||
// OAuth endpoints
|
||||
AuthorizeURL = "https://claude.ai/oauth/authorize"
|
||||
TokenURL = "https://console.anthropic.com/v1/oauth/token"
|
||||
RedirectURI = "https://console.anthropic.com/oauth/code/callback"
|
||||
|
||||
// Scopes
|
||||
ScopeProfile = "user:profile"
|
||||
ScopeInference = "user:inference"
|
||||
|
||||
// Session TTL
|
||||
SessionTTL = 30 * time.Minute
|
||||
)
|
||||
|
||||
// OAuthSession stores OAuth flow state
|
||||
type OAuthSession struct {
|
||||
State string `json:"state"`
|
||||
CodeVerifier string `json:"code_verifier"`
|
||||
Scope string `json:"scope"`
|
||||
ProxyURL string `json:"proxy_url,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
// SessionStore manages OAuth sessions in memory
|
||||
type SessionStore struct {
|
||||
mu sync.RWMutex
|
||||
sessions map[string]*OAuthSession
|
||||
}
|
||||
|
||||
// NewSessionStore creates a new session store
|
||||
func NewSessionStore() *SessionStore {
|
||||
store := &SessionStore{
|
||||
sessions: make(map[string]*OAuthSession),
|
||||
}
|
||||
// Start cleanup goroutine
|
||||
go store.cleanup()
|
||||
return store
|
||||
}
|
||||
|
||||
// Set stores a session
|
||||
func (s *SessionStore) Set(sessionID string, session *OAuthSession) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.sessions[sessionID] = session
|
||||
}
|
||||
|
||||
// Get retrieves a session
|
||||
func (s *SessionStore) Get(sessionID string) (*OAuthSession, bool) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
session, ok := s.sessions[sessionID]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
// Check if expired
|
||||
if time.Since(session.CreatedAt) > SessionTTL {
|
||||
return nil, false
|
||||
}
|
||||
return session, true
|
||||
}
|
||||
|
||||
// Delete removes a session
|
||||
func (s *SessionStore) Delete(sessionID string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
delete(s.sessions, sessionID)
|
||||
}
|
||||
|
||||
// cleanup removes expired sessions periodically
|
||||
func (s *SessionStore) cleanup() {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
for range ticker.C {
|
||||
s.mu.Lock()
|
||||
for id, session := range s.sessions {
|
||||
if time.Since(session.CreatedAt) > SessionTTL {
|
||||
delete(s.sessions, id)
|
||||
}
|
||||
}
|
||||
s.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateRandomBytes generates cryptographically secure random bytes
|
||||
func GenerateRandomBytes(n int) ([]byte, error) {
|
||||
b := make([]byte, n)
|
||||
_, err := rand.Read(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return b, nil
|
||||
}
|
||||
|
||||
// GenerateState generates a random state string for OAuth
|
||||
func GenerateState() (string, error) {
|
||||
bytes, err := GenerateRandomBytes(32)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(bytes), nil
|
||||
}
|
||||
|
||||
// GenerateSessionID generates a unique session ID
|
||||
func GenerateSessionID() (string, error) {
|
||||
bytes, err := GenerateRandomBytes(16)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(bytes), nil
|
||||
}
|
||||
|
||||
// GenerateCodeVerifier generates a PKCE code verifier (32 bytes -> base64url)
|
||||
func GenerateCodeVerifier() (string, error) {
|
||||
bytes, err := GenerateRandomBytes(32)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64URLEncode(bytes), nil
|
||||
}
|
||||
|
||||
// GenerateCodeChallenge generates a PKCE code challenge using S256 method
|
||||
func GenerateCodeChallenge(verifier string) string {
|
||||
hash := sha256.Sum256([]byte(verifier))
|
||||
return base64URLEncode(hash[:])
|
||||
}
|
||||
|
||||
// base64URLEncode encodes bytes to base64url without padding
|
||||
func base64URLEncode(data []byte) string {
|
||||
encoded := base64.URLEncoding.EncodeToString(data)
|
||||
// Remove padding
|
||||
return strings.TrimRight(encoded, "=")
|
||||
}
|
||||
|
||||
// BuildAuthorizationURL builds the OAuth authorization URL
|
||||
func BuildAuthorizationURL(state, codeChallenge, scope string) string {
|
||||
params := url.Values{}
|
||||
params.Set("response_type", "code")
|
||||
params.Set("client_id", ClientID)
|
||||
params.Set("redirect_uri", RedirectURI)
|
||||
params.Set("scope", scope)
|
||||
params.Set("state", state)
|
||||
params.Set("code_challenge", codeChallenge)
|
||||
params.Set("code_challenge_method", "S256")
|
||||
|
||||
return fmt.Sprintf("%s?%s", AuthorizeURL, params.Encode())
|
||||
}
|
||||
|
||||
// TokenRequest represents the token exchange request body
|
||||
type TokenRequest struct {
|
||||
GrantType string `json:"grant_type"`
|
||||
ClientID string `json:"client_id"`
|
||||
Code string `json:"code"`
|
||||
RedirectURI string `json:"redirect_uri"`
|
||||
CodeVerifier string `json:"code_verifier"`
|
||||
State string `json:"state"`
|
||||
}
|
||||
|
||||
// TokenResponse represents the token response from OAuth provider
|
||||
type TokenResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
RefreshToken string `json:"refresh_token,omitempty"`
|
||||
Scope string `json:"scope,omitempty"`
|
||||
// Organization and Account info from OAuth response
|
||||
Organization *OrgInfo `json:"organization,omitempty"`
|
||||
Account *AccountInfo `json:"account,omitempty"`
|
||||
}
|
||||
|
||||
// OrgInfo represents organization info from OAuth response
|
||||
type OrgInfo struct {
|
||||
UUID string `json:"uuid"`
|
||||
}
|
||||
|
||||
// AccountInfo represents account info from OAuth response
|
||||
type AccountInfo struct {
|
||||
UUID string `json:"uuid"`
|
||||
}
|
||||
|
||||
// RefreshTokenRequest represents the refresh token request
|
||||
type RefreshTokenRequest struct {
|
||||
GrantType string `json:"grant_type"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ClientID string `json:"client_id"`
|
||||
}
|
||||
|
||||
// BuildTokenRequest creates a token exchange request
|
||||
func BuildTokenRequest(code, codeVerifier, state string) *TokenRequest {
|
||||
return &TokenRequest{
|
||||
GrantType: "authorization_code",
|
||||
ClientID: ClientID,
|
||||
Code: code,
|
||||
RedirectURI: RedirectURI,
|
||||
CodeVerifier: codeVerifier,
|
||||
State: state,
|
||||
}
|
||||
}
|
||||
|
||||
// BuildRefreshTokenRequest creates a refresh token request
|
||||
func BuildRefreshTokenRequest(refreshToken string) *RefreshTokenRequest {
|
||||
return &RefreshTokenRequest{
|
||||
GrantType: "refresh_token",
|
||||
RefreshToken: refreshToken,
|
||||
ClientID: ClientID,
|
||||
}
|
||||
}
|
||||
157
backend/internal/pkg/response/response.go
Normal file
157
backend/internal/pkg/response/response.go
Normal file
@@ -0,0 +1,157 @@
|
||||
package response
|
||||
|
||||
import (
|
||||
"math"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// Response 标准API响应格式
|
||||
type Response struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data interface{} `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
// PaginatedData 分页数据格式(匹配前端期望)
|
||||
type PaginatedData struct {
|
||||
Items interface{} `json:"items"`
|
||||
Total int64 `json:"total"`
|
||||
Page int `json:"page"`
|
||||
PageSize int `json:"page_size"`
|
||||
Pages int `json:"pages"`
|
||||
}
|
||||
|
||||
// Success 返回成功响应
|
||||
func Success(c *gin.Context, data interface{}) {
|
||||
c.JSON(http.StatusOK, Response{
|
||||
Code: 0,
|
||||
Message: "success",
|
||||
Data: data,
|
||||
})
|
||||
}
|
||||
|
||||
// Created 返回创建成功响应
|
||||
func Created(c *gin.Context, data interface{}) {
|
||||
c.JSON(http.StatusCreated, Response{
|
||||
Code: 0,
|
||||
Message: "success",
|
||||
Data: data,
|
||||
})
|
||||
}
|
||||
|
||||
// Error 返回错误响应
|
||||
func Error(c *gin.Context, statusCode int, message string) {
|
||||
c.JSON(statusCode, Response{
|
||||
Code: statusCode,
|
||||
Message: message,
|
||||
})
|
||||
}
|
||||
|
||||
// BadRequest 返回400错误
|
||||
func BadRequest(c *gin.Context, message string) {
|
||||
Error(c, http.StatusBadRequest, message)
|
||||
}
|
||||
|
||||
// Unauthorized 返回401错误
|
||||
func Unauthorized(c *gin.Context, message string) {
|
||||
Error(c, http.StatusUnauthorized, message)
|
||||
}
|
||||
|
||||
// Forbidden 返回403错误
|
||||
func Forbidden(c *gin.Context, message string) {
|
||||
Error(c, http.StatusForbidden, message)
|
||||
}
|
||||
|
||||
// NotFound 返回404错误
|
||||
func NotFound(c *gin.Context, message string) {
|
||||
Error(c, http.StatusNotFound, message)
|
||||
}
|
||||
|
||||
// InternalError 返回500错误
|
||||
func InternalError(c *gin.Context, message string) {
|
||||
Error(c, http.StatusInternalServerError, message)
|
||||
}
|
||||
|
||||
// Paginated 返回分页数据
|
||||
func Paginated(c *gin.Context, items interface{}, total int64, page, pageSize int) {
|
||||
pages := int(math.Ceil(float64(total) / float64(pageSize)))
|
||||
if pages < 1 {
|
||||
pages = 1
|
||||
}
|
||||
|
||||
Success(c, PaginatedData{
|
||||
Items: items,
|
||||
Total: total,
|
||||
Page: page,
|
||||
PageSize: pageSize,
|
||||
Pages: pages,
|
||||
})
|
||||
}
|
||||
|
||||
// PaginationResult 分页结果(与repository.PaginationResult兼容)
|
||||
type PaginationResult struct {
|
||||
Total int64
|
||||
Page int
|
||||
PageSize int
|
||||
Pages int
|
||||
}
|
||||
|
||||
// PaginatedWithResult 使用PaginationResult返回分页数据
|
||||
func PaginatedWithResult(c *gin.Context, items interface{}, pagination *PaginationResult) {
|
||||
if pagination == nil {
|
||||
Success(c, PaginatedData{
|
||||
Items: items,
|
||||
Total: 0,
|
||||
Page: 1,
|
||||
PageSize: 20,
|
||||
Pages: 1,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
Success(c, PaginatedData{
|
||||
Items: items,
|
||||
Total: pagination.Total,
|
||||
Page: pagination.Page,
|
||||
PageSize: pagination.PageSize,
|
||||
Pages: pagination.Pages,
|
||||
})
|
||||
}
|
||||
|
||||
// ParsePagination 解析分页参数
|
||||
func ParsePagination(c *gin.Context) (page, pageSize int) {
|
||||
page = 1
|
||||
pageSize = 20
|
||||
|
||||
if p := c.Query("page"); p != "" {
|
||||
if val, err := parseInt(p); err == nil && val > 0 {
|
||||
page = val
|
||||
}
|
||||
}
|
||||
|
||||
// 支持 page_size 和 limit 两种参数名
|
||||
if ps := c.Query("page_size"); ps != "" {
|
||||
if val, err := parseInt(ps); err == nil && val > 0 && val <= 100 {
|
||||
pageSize = val
|
||||
}
|
||||
} else if l := c.Query("limit"); l != "" {
|
||||
if val, err := parseInt(l); err == nil && val > 0 && val <= 100 {
|
||||
pageSize = val
|
||||
}
|
||||
}
|
||||
|
||||
return page, pageSize
|
||||
}
|
||||
|
||||
func parseInt(s string) (int, error) {
|
||||
var result int
|
||||
for _, c := range s {
|
||||
if c < '0' || c > '9' {
|
||||
return 0, nil
|
||||
}
|
||||
result = result*10 + int(c-'0')
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
124
backend/internal/pkg/timezone/timezone.go
Normal file
124
backend/internal/pkg/timezone/timezone.go
Normal file
@@ -0,0 +1,124 @@
|
||||
// Package timezone provides global timezone management for the application.
|
||||
// Similar to PHP's date_default_timezone_set, this package allows setting
|
||||
// a global timezone that affects all time.Now() calls.
|
||||
package timezone
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
// location is the global timezone location
|
||||
location *time.Location
|
||||
// tzName stores the timezone name for logging/debugging
|
||||
tzName string
|
||||
)
|
||||
|
||||
// Init initializes the global timezone setting.
|
||||
// This should be called once at application startup.
|
||||
// Example timezone values: "Asia/Shanghai", "America/New_York", "UTC"
|
||||
func Init(tz string) error {
|
||||
if tz == "" {
|
||||
tz = "Asia/Shanghai" // Default timezone
|
||||
}
|
||||
|
||||
loc, err := time.LoadLocation(tz)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid timezone %q: %w", tz, err)
|
||||
}
|
||||
|
||||
// Set the global Go time.Local to our timezone
|
||||
// This affects time.Now() throughout the application
|
||||
time.Local = loc
|
||||
location = loc
|
||||
tzName = tz
|
||||
|
||||
log.Printf("Timezone initialized: %s (UTC offset: %s)", tz, getUTCOffset(loc))
|
||||
return nil
|
||||
}
|
||||
|
||||
// getUTCOffset returns the current UTC offset for a location
|
||||
func getUTCOffset(loc *time.Location) string {
|
||||
_, offset := time.Now().In(loc).Zone()
|
||||
hours := offset / 3600
|
||||
minutes := (offset % 3600) / 60
|
||||
if minutes < 0 {
|
||||
minutes = -minutes
|
||||
}
|
||||
sign := "+"
|
||||
if hours < 0 {
|
||||
sign = "-"
|
||||
hours = -hours
|
||||
}
|
||||
return fmt.Sprintf("%s%02d:%02d", sign, hours, minutes)
|
||||
}
|
||||
|
||||
// Now returns the current time in the configured timezone.
|
||||
// This is equivalent to time.Now() after Init() is called,
|
||||
// but provided for explicit timezone-aware code.
|
||||
func Now() time.Time {
|
||||
if location == nil {
|
||||
return time.Now()
|
||||
}
|
||||
return time.Now().In(location)
|
||||
}
|
||||
|
||||
// Location returns the configured timezone location.
|
||||
func Location() *time.Location {
|
||||
if location == nil {
|
||||
return time.Local
|
||||
}
|
||||
return location
|
||||
}
|
||||
|
||||
// Name returns the configured timezone name.
|
||||
func Name() string {
|
||||
if tzName == "" {
|
||||
return "Local"
|
||||
}
|
||||
return tzName
|
||||
}
|
||||
|
||||
// StartOfDay returns the start of the given day (00:00:00) in the configured timezone.
|
||||
func StartOfDay(t time.Time) time.Time {
|
||||
loc := Location()
|
||||
t = t.In(loc)
|
||||
return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, loc)
|
||||
}
|
||||
|
||||
// Today returns the start of today (00:00:00) in the configured timezone.
|
||||
func Today() time.Time {
|
||||
return StartOfDay(Now())
|
||||
}
|
||||
|
||||
// EndOfDay returns the end of the given day (23:59:59.999999999) in the configured timezone.
|
||||
func EndOfDay(t time.Time) time.Time {
|
||||
loc := Location()
|
||||
t = t.In(loc)
|
||||
return time.Date(t.Year(), t.Month(), t.Day(), 23, 59, 59, 999999999, loc)
|
||||
}
|
||||
|
||||
// StartOfWeek returns the start of the week (Monday 00:00:00) for the given time.
|
||||
func StartOfWeek(t time.Time) time.Time {
|
||||
loc := Location()
|
||||
t = t.In(loc)
|
||||
weekday := int(t.Weekday())
|
||||
if weekday == 0 {
|
||||
weekday = 7 // Sunday is day 7
|
||||
}
|
||||
return time.Date(t.Year(), t.Month(), t.Day()-weekday+1, 0, 0, 0, 0, loc)
|
||||
}
|
||||
|
||||
// StartOfMonth returns the start of the month (1st day 00:00:00) for the given time.
|
||||
func StartOfMonth(t time.Time) time.Time {
|
||||
loc := Location()
|
||||
t = t.In(loc)
|
||||
return time.Date(t.Year(), t.Month(), 1, 0, 0, 0, 0, loc)
|
||||
}
|
||||
|
||||
// ParseInLocation parses a time string in the configured timezone.
|
||||
func ParseInLocation(layout, value string) (time.Time, error) {
|
||||
return time.ParseInLocation(layout, value, Location())
|
||||
}
|
||||
127
backend/internal/pkg/timezone/timezone_test.go
Normal file
127
backend/internal/pkg/timezone/timezone_test.go
Normal file
@@ -0,0 +1,127 @@
|
||||
package timezone
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestInit(t *testing.T) {
|
||||
// Test with valid timezone
|
||||
err := Init("Asia/Shanghai")
|
||||
if err != nil {
|
||||
t.Fatalf("Init failed with valid timezone: %v", err)
|
||||
}
|
||||
|
||||
// Verify time.Local was set
|
||||
if time.Local.String() != "Asia/Shanghai" {
|
||||
t.Errorf("time.Local not set correctly, got %s", time.Local.String())
|
||||
}
|
||||
|
||||
// Verify our location variable
|
||||
if Location().String() != "Asia/Shanghai" {
|
||||
t.Errorf("Location() not set correctly, got %s", Location().String())
|
||||
}
|
||||
|
||||
// Test Name()
|
||||
if Name() != "Asia/Shanghai" {
|
||||
t.Errorf("Name() not set correctly, got %s", Name())
|
||||
}
|
||||
}
|
||||
|
||||
func TestInitInvalidTimezone(t *testing.T) {
|
||||
err := Init("Invalid/Timezone")
|
||||
if err == nil {
|
||||
t.Error("Init should fail with invalid timezone")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTimeNowAffected(t *testing.T) {
|
||||
// Reset to UTC first
|
||||
Init("UTC")
|
||||
utcNow := time.Now()
|
||||
|
||||
// Switch to Shanghai (UTC+8)
|
||||
Init("Asia/Shanghai")
|
||||
shanghaiNow := time.Now()
|
||||
|
||||
// The times should be the same instant, but different timezone representation
|
||||
// Shanghai should be 8 hours ahead in display
|
||||
_, utcOffset := utcNow.Zone()
|
||||
_, shanghaiOffset := shanghaiNow.Zone()
|
||||
|
||||
expectedDiff := 8 * 3600 // 8 hours in seconds
|
||||
actualDiff := shanghaiOffset - utcOffset
|
||||
|
||||
if actualDiff != expectedDiff {
|
||||
t.Errorf("Timezone offset difference incorrect: expected %d, got %d", expectedDiff, actualDiff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToday(t *testing.T) {
|
||||
Init("Asia/Shanghai")
|
||||
|
||||
today := Today()
|
||||
now := Now()
|
||||
|
||||
// Today should be at 00:00:00
|
||||
if today.Hour() != 0 || today.Minute() != 0 || today.Second() != 0 {
|
||||
t.Errorf("Today() not at start of day: %v", today)
|
||||
}
|
||||
|
||||
// Today should be same date as now
|
||||
if today.Year() != now.Year() || today.Month() != now.Month() || today.Day() != now.Day() {
|
||||
t.Errorf("Today() date mismatch: today=%v, now=%v", today, now)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartOfDay(t *testing.T) {
|
||||
Init("Asia/Shanghai")
|
||||
|
||||
// Create a time at 15:30:45
|
||||
testTime := time.Date(2024, 6, 15, 15, 30, 45, 123456789, Location())
|
||||
startOfDay := StartOfDay(testTime)
|
||||
|
||||
expected := time.Date(2024, 6, 15, 0, 0, 0, 0, Location())
|
||||
if !startOfDay.Equal(expected) {
|
||||
t.Errorf("StartOfDay incorrect: expected %v, got %v", expected, startOfDay)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTruncateVsStartOfDay(t *testing.T) {
|
||||
// This test demonstrates why Truncate(24*time.Hour) can be problematic
|
||||
// and why StartOfDay is more reliable for timezone-aware code
|
||||
|
||||
Init("Asia/Shanghai")
|
||||
|
||||
now := Now()
|
||||
|
||||
// Truncate operates on UTC, not local time
|
||||
truncated := now.Truncate(24 * time.Hour)
|
||||
|
||||
// StartOfDay operates on local time
|
||||
startOfDay := StartOfDay(now)
|
||||
|
||||
// These will likely be different for non-UTC timezones
|
||||
t.Logf("Now: %v", now)
|
||||
t.Logf("Truncate(24h): %v", truncated)
|
||||
t.Logf("StartOfDay: %v", startOfDay)
|
||||
|
||||
// The truncated time may not be at local midnight
|
||||
// StartOfDay is always at local midnight
|
||||
if startOfDay.Hour() != 0 {
|
||||
t.Errorf("StartOfDay should be at hour 0, got %d", startOfDay.Hour())
|
||||
}
|
||||
}
|
||||
|
||||
func TestDSTAwareness(t *testing.T) {
|
||||
// Test with a timezone that has DST (America/New_York)
|
||||
err := Init("America/New_York")
|
||||
if err != nil {
|
||||
t.Skipf("America/New_York timezone not available: %v", err)
|
||||
}
|
||||
|
||||
// Just verify it doesn't crash
|
||||
_ = Today()
|
||||
_ = Now()
|
||||
_ = StartOfDay(Now())
|
||||
}
|
||||
268
backend/internal/repository/account_repo.go
Normal file
268
backend/internal/repository/account_repo.go
Normal file
@@ -0,0 +1,268 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sub2api/internal/model"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type AccountRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewAccountRepository(db *gorm.DB) *AccountRepository {
|
||||
return &AccountRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *AccountRepository) Create(ctx context.Context, account *model.Account) error {
|
||||
return r.db.WithContext(ctx).Create(account).Error
|
||||
}
|
||||
|
||||
func (r *AccountRepository) GetByID(ctx context.Context, id int64) (*model.Account, error) {
|
||||
var account model.Account
|
||||
err := r.db.WithContext(ctx).Preload("Proxy").Preload("AccountGroups").First(&account, id).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 填充 GroupIDs 虚拟字段
|
||||
account.GroupIDs = make([]int64, 0, len(account.AccountGroups))
|
||||
for _, ag := range account.AccountGroups {
|
||||
account.GroupIDs = append(account.GroupIDs, ag.GroupID)
|
||||
}
|
||||
return &account, nil
|
||||
}
|
||||
|
||||
func (r *AccountRepository) Update(ctx context.Context, account *model.Account) error {
|
||||
return r.db.WithContext(ctx).Save(account).Error
|
||||
}
|
||||
|
||||
func (r *AccountRepository) Delete(ctx context.Context, id int64) error {
|
||||
// 先删除账号与分组的绑定关系
|
||||
if err := r.db.WithContext(ctx).Where("account_id = ?", id).Delete(&model.AccountGroup{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
// 再删除账号
|
||||
return r.db.WithContext(ctx).Delete(&model.Account{}, id).Error
|
||||
}
|
||||
|
||||
func (r *AccountRepository) List(ctx context.Context, params PaginationParams) ([]model.Account, *PaginationResult, error) {
|
||||
return r.ListWithFilters(ctx, params, "", "", "", "")
|
||||
}
|
||||
|
||||
// ListWithFilters lists accounts with optional filtering by platform, type, status, and search query
|
||||
func (r *AccountRepository) ListWithFilters(ctx context.Context, params PaginationParams, platform, accountType, status, search string) ([]model.Account, *PaginationResult, error) {
|
||||
var accounts []model.Account
|
||||
var total int64
|
||||
|
||||
db := r.db.WithContext(ctx).Model(&model.Account{})
|
||||
|
||||
// Apply filters
|
||||
if platform != "" {
|
||||
db = db.Where("platform = ?", platform)
|
||||
}
|
||||
if accountType != "" {
|
||||
db = db.Where("type = ?", accountType)
|
||||
}
|
||||
if status != "" {
|
||||
db = db.Where("status = ?", status)
|
||||
}
|
||||
if search != "" {
|
||||
searchPattern := "%" + search + "%"
|
||||
db = db.Where("name ILIKE ?", searchPattern)
|
||||
}
|
||||
|
||||
if err := db.Count(&total).Error; err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if err := db.Preload("Proxy").Preload("AccountGroups").Offset(params.Offset()).Limit(params.Limit()).Order("id DESC").Find(&accounts).Error; err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// 填充每个 Account 的 GroupIDs 虚拟字段
|
||||
for i := range accounts {
|
||||
accounts[i].GroupIDs = make([]int64, 0, len(accounts[i].AccountGroups))
|
||||
for _, ag := range accounts[i].AccountGroups {
|
||||
accounts[i].GroupIDs = append(accounts[i].GroupIDs, ag.GroupID)
|
||||
}
|
||||
}
|
||||
|
||||
pages := int(total) / params.Limit()
|
||||
if int(total)%params.Limit() > 0 {
|
||||
pages++
|
||||
}
|
||||
|
||||
return accounts, &PaginationResult{
|
||||
Total: total,
|
||||
Page: params.Page,
|
||||
PageSize: params.Limit(),
|
||||
Pages: pages,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *AccountRepository) ListByGroup(ctx context.Context, groupID int64) ([]model.Account, error) {
|
||||
var accounts []model.Account
|
||||
err := r.db.WithContext(ctx).
|
||||
Joins("JOIN account_groups ON account_groups.account_id = accounts.id").
|
||||
Where("account_groups.group_id = ? AND accounts.status = ?", groupID, model.StatusActive).
|
||||
Preload("Proxy").
|
||||
Order("account_groups.priority ASC, accounts.priority ASC").
|
||||
Find(&accounts).Error
|
||||
return accounts, err
|
||||
}
|
||||
|
||||
func (r *AccountRepository) ListActive(ctx context.Context) ([]model.Account, error) {
|
||||
var accounts []model.Account
|
||||
err := r.db.WithContext(ctx).
|
||||
Where("status = ?", model.StatusActive).
|
||||
Preload("Proxy").
|
||||
Order("priority ASC").
|
||||
Find(&accounts).Error
|
||||
return accounts, err
|
||||
}
|
||||
|
||||
func (r *AccountRepository) UpdateLastUsed(ctx context.Context, id int64) error {
|
||||
now := time.Now()
|
||||
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).Update("last_used_at", now).Error
|
||||
}
|
||||
|
||||
func (r *AccountRepository) SetError(ctx context.Context, id int64, errorMsg string) error {
|
||||
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
|
||||
Updates(map[string]interface{}{
|
||||
"status": model.StatusError,
|
||||
"error_message": errorMsg,
|
||||
}).Error
|
||||
}
|
||||
|
||||
func (r *AccountRepository) AddToGroup(ctx context.Context, accountID, groupID int64, priority int) error {
|
||||
ag := &model.AccountGroup{
|
||||
AccountID: accountID,
|
||||
GroupID: groupID,
|
||||
Priority: priority,
|
||||
}
|
||||
return r.db.WithContext(ctx).Create(ag).Error
|
||||
}
|
||||
|
||||
func (r *AccountRepository) RemoveFromGroup(ctx context.Context, accountID, groupID int64) error {
|
||||
return r.db.WithContext(ctx).Where("account_id = ? AND group_id = ?", accountID, groupID).
|
||||
Delete(&model.AccountGroup{}).Error
|
||||
}
|
||||
|
||||
func (r *AccountRepository) GetGroups(ctx context.Context, accountID int64) ([]model.Group, error) {
|
||||
var groups []model.Group
|
||||
err := r.db.WithContext(ctx).
|
||||
Joins("JOIN account_groups ON account_groups.group_id = groups.id").
|
||||
Where("account_groups.account_id = ?", accountID).
|
||||
Find(&groups).Error
|
||||
return groups, err
|
||||
}
|
||||
|
||||
func (r *AccountRepository) ListByPlatform(ctx context.Context, platform string) ([]model.Account, error) {
|
||||
var accounts []model.Account
|
||||
err := r.db.WithContext(ctx).
|
||||
Where("platform = ? AND status = ?", platform, model.StatusActive).
|
||||
Preload("Proxy").
|
||||
Order("priority ASC").
|
||||
Find(&accounts).Error
|
||||
return accounts, err
|
||||
}
|
||||
|
||||
func (r *AccountRepository) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error {
|
||||
// 删除现有绑定
|
||||
if err := r.db.WithContext(ctx).Where("account_id = ?", accountID).Delete(&model.AccountGroup{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 添加新绑定
|
||||
if len(groupIDs) > 0 {
|
||||
accountGroups := make([]model.AccountGroup, 0, len(groupIDs))
|
||||
for i, groupID := range groupIDs {
|
||||
accountGroups = append(accountGroups, model.AccountGroup{
|
||||
AccountID: accountID,
|
||||
GroupID: groupID,
|
||||
Priority: i + 1, // 使用索引作为优先级
|
||||
})
|
||||
}
|
||||
return r.db.WithContext(ctx).Create(&accountGroups).Error
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListSchedulable 获取所有可调度的账号
|
||||
func (r *AccountRepository) ListSchedulable(ctx context.Context) ([]model.Account, error) {
|
||||
var accounts []model.Account
|
||||
now := time.Now()
|
||||
err := r.db.WithContext(ctx).
|
||||
Where("status = ? AND schedulable = ?", model.StatusActive, true).
|
||||
Where("(overload_until IS NULL OR overload_until <= ?)", now).
|
||||
Where("(rate_limit_reset_at IS NULL OR rate_limit_reset_at <= ?)", now).
|
||||
Preload("Proxy").
|
||||
Order("priority ASC").
|
||||
Find(&accounts).Error
|
||||
return accounts, err
|
||||
}
|
||||
|
||||
// ListSchedulableByGroupID 按组获取可调度的账号
|
||||
func (r *AccountRepository) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]model.Account, error) {
|
||||
var accounts []model.Account
|
||||
now := time.Now()
|
||||
err := r.db.WithContext(ctx).
|
||||
Joins("JOIN account_groups ON account_groups.account_id = accounts.id").
|
||||
Where("account_groups.group_id = ?", groupID).
|
||||
Where("accounts.status = ? AND accounts.schedulable = ?", model.StatusActive, true).
|
||||
Where("(accounts.overload_until IS NULL OR accounts.overload_until <= ?)", now).
|
||||
Where("(accounts.rate_limit_reset_at IS NULL OR accounts.rate_limit_reset_at <= ?)", now).
|
||||
Preload("Proxy").
|
||||
Order("account_groups.priority ASC, accounts.priority ASC").
|
||||
Find(&accounts).Error
|
||||
return accounts, err
|
||||
}
|
||||
|
||||
// SetRateLimited 标记账号为限流状态(429)
|
||||
func (r *AccountRepository) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
|
||||
now := time.Now()
|
||||
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
|
||||
Updates(map[string]interface{}{
|
||||
"rate_limited_at": now,
|
||||
"rate_limit_reset_at": resetAt,
|
||||
}).Error
|
||||
}
|
||||
|
||||
// SetOverloaded 标记账号为过载状态(529)
|
||||
func (r *AccountRepository) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
|
||||
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
|
||||
Update("overload_until", until).Error
|
||||
}
|
||||
|
||||
// ClearRateLimit 清除账号的限流状态
|
||||
func (r *AccountRepository) ClearRateLimit(ctx context.Context, id int64) error {
|
||||
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
|
||||
Updates(map[string]interface{}{
|
||||
"rate_limited_at": nil,
|
||||
"rate_limit_reset_at": nil,
|
||||
"overload_until": nil,
|
||||
}).Error
|
||||
}
|
||||
|
||||
// UpdateSessionWindow 更新账号的5小时时间窗口信息
|
||||
func (r *AccountRepository) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
|
||||
updates := map[string]interface{}{
|
||||
"session_window_status": status,
|
||||
}
|
||||
if start != nil {
|
||||
updates["session_window_start"] = start
|
||||
}
|
||||
if end != nil {
|
||||
updates["session_window_end"] = end
|
||||
}
|
||||
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).Updates(updates).Error
|
||||
}
|
||||
|
||||
// SetSchedulable 设置账号的调度开关
|
||||
func (r *AccountRepository) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
|
||||
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
|
||||
Update("schedulable", schedulable).Error
|
||||
}
|
||||
149
backend/internal/repository/api_key_repo.go
Normal file
149
backend/internal/repository/api_key_repo.go
Normal file
@@ -0,0 +1,149 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sub2api/internal/model"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type ApiKeyRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewApiKeyRepository(db *gorm.DB) *ApiKeyRepository {
|
||||
return &ApiKeyRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *ApiKeyRepository) Create(ctx context.Context, key *model.ApiKey) error {
|
||||
return r.db.WithContext(ctx).Create(key).Error
|
||||
}
|
||||
|
||||
func (r *ApiKeyRepository) GetByID(ctx context.Context, id int64) (*model.ApiKey, error) {
|
||||
var key model.ApiKey
|
||||
err := r.db.WithContext(ctx).Preload("User").Preload("Group").First(&key, id).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &key, nil
|
||||
}
|
||||
|
||||
func (r *ApiKeyRepository) GetByKey(ctx context.Context, key string) (*model.ApiKey, error) {
|
||||
var apiKey model.ApiKey
|
||||
err := r.db.WithContext(ctx).Preload("User").Preload("Group").Where("key = ?", key).First(&apiKey).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &apiKey, nil
|
||||
}
|
||||
|
||||
func (r *ApiKeyRepository) Update(ctx context.Context, key *model.ApiKey) error {
|
||||
return r.db.WithContext(ctx).Model(key).Select("name", "group_id", "status", "updated_at").Updates(key).Error
|
||||
}
|
||||
|
||||
func (r *ApiKeyRepository) Delete(ctx context.Context, id int64) error {
|
||||
return r.db.WithContext(ctx).Delete(&model.ApiKey{}, id).Error
|
||||
}
|
||||
|
||||
func (r *ApiKeyRepository) ListByUserID(ctx context.Context, userID int64, params PaginationParams) ([]model.ApiKey, *PaginationResult, error) {
|
||||
var keys []model.ApiKey
|
||||
var total int64
|
||||
|
||||
db := r.db.WithContext(ctx).Model(&model.ApiKey{}).Where("user_id = ?", userID)
|
||||
|
||||
if err := db.Count(&total).Error; err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if err := db.Preload("Group").Offset(params.Offset()).Limit(params.Limit()).Order("id DESC").Find(&keys).Error; err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
pages := int(total) / params.Limit()
|
||||
if int(total)%params.Limit() > 0 {
|
||||
pages++
|
||||
}
|
||||
|
||||
return keys, &PaginationResult{
|
||||
Total: total,
|
||||
Page: params.Page,
|
||||
PageSize: params.Limit(),
|
||||
Pages: pages,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *ApiKeyRepository) CountByUserID(ctx context.Context, userID int64) (int64, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&model.ApiKey{}).Where("user_id = ?", userID).Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
func (r *ApiKeyRepository) ExistsByKey(ctx context.Context, key string) (bool, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&model.ApiKey{}).Where("key = ?", key).Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
func (r *ApiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params PaginationParams) ([]model.ApiKey, *PaginationResult, error) {
|
||||
var keys []model.ApiKey
|
||||
var total int64
|
||||
|
||||
db := r.db.WithContext(ctx).Model(&model.ApiKey{}).Where("group_id = ?", groupID)
|
||||
|
||||
if err := db.Count(&total).Error; err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if err := db.Preload("User").Offset(params.Offset()).Limit(params.Limit()).Order("id DESC").Find(&keys).Error; err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
pages := int(total) / params.Limit()
|
||||
if int(total)%params.Limit() > 0 {
|
||||
pages++
|
||||
}
|
||||
|
||||
return keys, &PaginationResult{
|
||||
Total: total,
|
||||
Page: params.Page,
|
||||
PageSize: params.Limit(),
|
||||
Pages: pages,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SearchApiKeys searches API keys by user ID and/or keyword (name)
|
||||
func (r *ApiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]model.ApiKey, error) {
|
||||
var keys []model.ApiKey
|
||||
|
||||
db := r.db.WithContext(ctx).Model(&model.ApiKey{})
|
||||
|
||||
if userID > 0 {
|
||||
db = db.Where("user_id = ?", userID)
|
||||
}
|
||||
|
||||
if keyword != "" {
|
||||
searchPattern := "%" + keyword + "%"
|
||||
db = db.Where("name ILIKE ?", searchPattern)
|
||||
}
|
||||
|
||||
if err := db.Limit(limit).Order("id DESC").Find(&keys).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
// ClearGroupIDByGroupID 将指定分组的所有 API Key 的 group_id 设为 nil
|
||||
func (r *ApiKeyRepository) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
result := r.db.WithContext(ctx).Model(&model.ApiKey{}).
|
||||
Where("group_id = ?", groupID).
|
||||
Update("group_id", nil)
|
||||
return result.RowsAffected, result.Error
|
||||
}
|
||||
|
||||
// CountByGroupID 获取分组的 API Key 数量
|
||||
func (r *ApiKeyRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&model.ApiKey{}).Where("group_id = ?", groupID).Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
137
backend/internal/repository/group_repo.go
Normal file
137
backend/internal/repository/group_repo.go
Normal file
@@ -0,0 +1,137 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sub2api/internal/model"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type GroupRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewGroupRepository(db *gorm.DB) *GroupRepository {
|
||||
return &GroupRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *GroupRepository) Create(ctx context.Context, group *model.Group) error {
|
||||
return r.db.WithContext(ctx).Create(group).Error
|
||||
}
|
||||
|
||||
func (r *GroupRepository) GetByID(ctx context.Context, id int64) (*model.Group, error) {
|
||||
var group model.Group
|
||||
err := r.db.WithContext(ctx).First(&group, id).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &group, nil
|
||||
}
|
||||
|
||||
func (r *GroupRepository) Update(ctx context.Context, group *model.Group) error {
|
||||
return r.db.WithContext(ctx).Save(group).Error
|
||||
}
|
||||
|
||||
func (r *GroupRepository) Delete(ctx context.Context, id int64) error {
|
||||
return r.db.WithContext(ctx).Delete(&model.Group{}, id).Error
|
||||
}
|
||||
|
||||
func (r *GroupRepository) List(ctx context.Context, params PaginationParams) ([]model.Group, *PaginationResult, error) {
|
||||
return r.ListWithFilters(ctx, params, "", "", nil)
|
||||
}
|
||||
|
||||
// ListWithFilters lists groups with optional filtering by platform, status, and is_exclusive
|
||||
func (r *GroupRepository) ListWithFilters(ctx context.Context, params PaginationParams, platform, status string, isExclusive *bool) ([]model.Group, *PaginationResult, error) {
|
||||
var groups []model.Group
|
||||
var total int64
|
||||
|
||||
db := r.db.WithContext(ctx).Model(&model.Group{})
|
||||
|
||||
// Apply filters
|
||||
if platform != "" {
|
||||
db = db.Where("platform = ?", platform)
|
||||
}
|
||||
if status != "" {
|
||||
db = db.Where("status = ?", status)
|
||||
}
|
||||
if isExclusive != nil {
|
||||
db = db.Where("is_exclusive = ?", *isExclusive)
|
||||
}
|
||||
|
||||
if err := db.Count(&total).Error; err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if err := db.Offset(params.Offset()).Limit(params.Limit()).Order("id ASC").Find(&groups).Error; err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// 获取每个分组的账号数量
|
||||
for i := range groups {
|
||||
count, _ := r.GetAccountCount(ctx, groups[i].ID)
|
||||
groups[i].AccountCount = count
|
||||
}
|
||||
|
||||
pages := int(total) / params.Limit()
|
||||
if int(total)%params.Limit() > 0 {
|
||||
pages++
|
||||
}
|
||||
|
||||
return groups, &PaginationResult{
|
||||
Total: total,
|
||||
Page: params.Page,
|
||||
PageSize: params.Limit(),
|
||||
Pages: pages,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *GroupRepository) ListActive(ctx context.Context) ([]model.Group, error) {
|
||||
var groups []model.Group
|
||||
err := r.db.WithContext(ctx).Where("status = ?", model.StatusActive).Order("id ASC").Find(&groups).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 获取每个分组的账号数量
|
||||
for i := range groups {
|
||||
count, _ := r.GetAccountCount(ctx, groups[i].ID)
|
||||
groups[i].AccountCount = count
|
||||
}
|
||||
return groups, nil
|
||||
}
|
||||
|
||||
func (r *GroupRepository) ListActiveByPlatform(ctx context.Context, platform string) ([]model.Group, error) {
|
||||
var groups []model.Group
|
||||
err := r.db.WithContext(ctx).Where("status = ? AND platform = ?", model.StatusActive, platform).Order("id ASC").Find(&groups).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 获取每个分组的账号数量
|
||||
for i := range groups {
|
||||
count, _ := r.GetAccountCount(ctx, groups[i].ID)
|
||||
groups[i].AccountCount = count
|
||||
}
|
||||
return groups, nil
|
||||
}
|
||||
|
||||
func (r *GroupRepository) ExistsByName(ctx context.Context, name string) (bool, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&model.Group{}).Where("name = ?", name).Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
func (r *GroupRepository) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&model.AccountGroup{}).Where("group_id = ?", groupID).Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// DeleteAccountGroupsByGroupID 删除分组与账号的关联关系
|
||||
func (r *GroupRepository) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
result := r.db.WithContext(ctx).Where("group_id = ?", groupID).Delete(&model.AccountGroup{})
|
||||
return result.RowsAffected, result.Error
|
||||
}
|
||||
|
||||
// DB 返回底层数据库连接,用于事务处理
|
||||
func (r *GroupRepository) DB() *gorm.DB {
|
||||
return r.db
|
||||
}
|
||||
161
backend/internal/repository/proxy_repo.go
Normal file
161
backend/internal/repository/proxy_repo.go
Normal file
@@ -0,0 +1,161 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sub2api/internal/model"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type ProxyRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewProxyRepository(db *gorm.DB) *ProxyRepository {
|
||||
return &ProxyRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *ProxyRepository) Create(ctx context.Context, proxy *model.Proxy) error {
|
||||
return r.db.WithContext(ctx).Create(proxy).Error
|
||||
}
|
||||
|
||||
func (r *ProxyRepository) GetByID(ctx context.Context, id int64) (*model.Proxy, error) {
|
||||
var proxy model.Proxy
|
||||
err := r.db.WithContext(ctx).First(&proxy, id).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &proxy, nil
|
||||
}
|
||||
|
||||
func (r *ProxyRepository) Update(ctx context.Context, proxy *model.Proxy) error {
|
||||
return r.db.WithContext(ctx).Save(proxy).Error
|
||||
}
|
||||
|
||||
func (r *ProxyRepository) Delete(ctx context.Context, id int64) error {
|
||||
return r.db.WithContext(ctx).Delete(&model.Proxy{}, id).Error
|
||||
}
|
||||
|
||||
func (r *ProxyRepository) List(ctx context.Context, params PaginationParams) ([]model.Proxy, *PaginationResult, error) {
|
||||
return r.ListWithFilters(ctx, params, "", "", "")
|
||||
}
|
||||
|
||||
// ListWithFilters lists proxies with optional filtering by protocol, status, and search query
|
||||
func (r *ProxyRepository) ListWithFilters(ctx context.Context, params PaginationParams, protocol, status, search string) ([]model.Proxy, *PaginationResult, error) {
|
||||
var proxies []model.Proxy
|
||||
var total int64
|
||||
|
||||
db := r.db.WithContext(ctx).Model(&model.Proxy{})
|
||||
|
||||
// Apply filters
|
||||
if protocol != "" {
|
||||
db = db.Where("protocol = ?", protocol)
|
||||
}
|
||||
if status != "" {
|
||||
db = db.Where("status = ?", status)
|
||||
}
|
||||
if search != "" {
|
||||
searchPattern := "%" + search + "%"
|
||||
db = db.Where("name ILIKE ?", searchPattern)
|
||||
}
|
||||
|
||||
if err := db.Count(&total).Error; err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if err := db.Offset(params.Offset()).Limit(params.Limit()).Order("id DESC").Find(&proxies).Error; err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
pages := int(total) / params.Limit()
|
||||
if int(total)%params.Limit() > 0 {
|
||||
pages++
|
||||
}
|
||||
|
||||
return proxies, &PaginationResult{
|
||||
Total: total,
|
||||
Page: params.Page,
|
||||
PageSize: params.Limit(),
|
||||
Pages: pages,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *ProxyRepository) ListActive(ctx context.Context) ([]model.Proxy, error) {
|
||||
var proxies []model.Proxy
|
||||
err := r.db.WithContext(ctx).Where("status = ?", model.StatusActive).Find(&proxies).Error
|
||||
return proxies, err
|
||||
}
|
||||
|
||||
// ExistsByHostPortAuth checks if a proxy with the same host, port, username, and password exists
|
||||
func (r *ProxyRepository) ExistsByHostPortAuth(ctx context.Context, host string, port int, username, password string) (bool, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&model.Proxy{}).
|
||||
Where("host = ? AND port = ? AND username = ? AND password = ?", host, port, username, password).
|
||||
Count(&count).Error
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
// CountAccountsByProxyID returns the number of accounts using a specific proxy
|
||||
func (r *ProxyRepository) CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&model.Account{}).
|
||||
Where("proxy_id = ?", proxyID).
|
||||
Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// GetAccountCountsForProxies returns a map of proxy ID to account count for all proxies
|
||||
func (r *ProxyRepository) GetAccountCountsForProxies(ctx context.Context) (map[int64]int64, error) {
|
||||
type result struct {
|
||||
ProxyID int64 `gorm:"column:proxy_id"`
|
||||
Count int64 `gorm:"column:count"`
|
||||
}
|
||||
var results []result
|
||||
err := r.db.WithContext(ctx).
|
||||
Model(&model.Account{}).
|
||||
Select("proxy_id, COUNT(*) as count").
|
||||
Where("proxy_id IS NOT NULL").
|
||||
Group("proxy_id").
|
||||
Scan(&results).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
counts := make(map[int64]int64)
|
||||
for _, r := range results {
|
||||
counts[r.ProxyID] = r.Count
|
||||
}
|
||||
return counts, nil
|
||||
}
|
||||
|
||||
// ListActiveWithAccountCount returns all active proxies with account count, sorted by creation time descending
|
||||
func (r *ProxyRepository) ListActiveWithAccountCount(ctx context.Context) ([]model.ProxyWithAccountCount, error) {
|
||||
var proxies []model.Proxy
|
||||
err := r.db.WithContext(ctx).
|
||||
Where("status = ?", model.StatusActive).
|
||||
Order("created_at DESC").
|
||||
Find(&proxies).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get account counts
|
||||
counts, err := r.GetAccountCountsForProxies(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Build result with account counts
|
||||
result := make([]model.ProxyWithAccountCount, len(proxies))
|
||||
for i, proxy := range proxies {
|
||||
result[i] = model.ProxyWithAccountCount{
|
||||
Proxy: proxy,
|
||||
AccountCount: counts[proxy.ID],
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
133
backend/internal/repository/redeem_code_repo.go
Normal file
133
backend/internal/repository/redeem_code_repo.go
Normal file
@@ -0,0 +1,133 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sub2api/internal/model"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type RedeemCodeRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewRedeemCodeRepository(db *gorm.DB) *RedeemCodeRepository {
|
||||
return &RedeemCodeRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *RedeemCodeRepository) Create(ctx context.Context, code *model.RedeemCode) error {
|
||||
return r.db.WithContext(ctx).Create(code).Error
|
||||
}
|
||||
|
||||
func (r *RedeemCodeRepository) CreateBatch(ctx context.Context, codes []model.RedeemCode) error {
|
||||
return r.db.WithContext(ctx).Create(&codes).Error
|
||||
}
|
||||
|
||||
func (r *RedeemCodeRepository) GetByID(ctx context.Context, id int64) (*model.RedeemCode, error) {
|
||||
var code model.RedeemCode
|
||||
err := r.db.WithContext(ctx).First(&code, id).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &code, nil
|
||||
}
|
||||
|
||||
func (r *RedeemCodeRepository) GetByCode(ctx context.Context, code string) (*model.RedeemCode, error) {
|
||||
var redeemCode model.RedeemCode
|
||||
err := r.db.WithContext(ctx).Where("code = ?", code).First(&redeemCode).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &redeemCode, nil
|
||||
}
|
||||
|
||||
func (r *RedeemCodeRepository) Delete(ctx context.Context, id int64) error {
|
||||
return r.db.WithContext(ctx).Delete(&model.RedeemCode{}, id).Error
|
||||
}
|
||||
|
||||
func (r *RedeemCodeRepository) List(ctx context.Context, params PaginationParams) ([]model.RedeemCode, *PaginationResult, error) {
|
||||
return r.ListWithFilters(ctx, params, "", "", "")
|
||||
}
|
||||
|
||||
// ListWithFilters lists redeem codes with optional filtering by type, status, and search query
|
||||
func (r *RedeemCodeRepository) ListWithFilters(ctx context.Context, params PaginationParams, codeType, status, search string) ([]model.RedeemCode, *PaginationResult, error) {
|
||||
var codes []model.RedeemCode
|
||||
var total int64
|
||||
|
||||
db := r.db.WithContext(ctx).Model(&model.RedeemCode{})
|
||||
|
||||
// Apply filters
|
||||
if codeType != "" {
|
||||
db = db.Where("type = ?", codeType)
|
||||
}
|
||||
if status != "" {
|
||||
db = db.Where("status = ?", status)
|
||||
}
|
||||
if search != "" {
|
||||
searchPattern := "%" + search + "%"
|
||||
db = db.Where("code ILIKE ?", searchPattern)
|
||||
}
|
||||
|
||||
if err := db.Count(&total).Error; err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if err := db.Preload("User").Preload("Group").Offset(params.Offset()).Limit(params.Limit()).Order("id DESC").Find(&codes).Error; err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
pages := int(total) / params.Limit()
|
||||
if int(total)%params.Limit() > 0 {
|
||||
pages++
|
||||
}
|
||||
|
||||
return codes, &PaginationResult{
|
||||
Total: total,
|
||||
Page: params.Page,
|
||||
PageSize: params.Limit(),
|
||||
Pages: pages,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *RedeemCodeRepository) Update(ctx context.Context, code *model.RedeemCode) error {
|
||||
return r.db.WithContext(ctx).Save(code).Error
|
||||
}
|
||||
|
||||
func (r *RedeemCodeRepository) Use(ctx context.Context, id, userID int64) error {
|
||||
now := time.Now()
|
||||
result := r.db.WithContext(ctx).Model(&model.RedeemCode{}).
|
||||
Where("id = ? AND status = ?", id, model.StatusUnused).
|
||||
Updates(map[string]interface{}{
|
||||
"status": model.StatusUsed,
|
||||
"used_by": userID,
|
||||
"used_at": now,
|
||||
})
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
if result.RowsAffected == 0 {
|
||||
return gorm.ErrRecordNotFound // 兑换码不存在或已被使用
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListByUser returns all redeem codes used by a specific user
|
||||
func (r *RedeemCodeRepository) ListByUser(ctx context.Context, userID int64, limit int) ([]model.RedeemCode, error) {
|
||||
var codes []model.RedeemCode
|
||||
if limit <= 0 {
|
||||
limit = 10
|
||||
}
|
||||
|
||||
err := r.db.WithContext(ctx).
|
||||
Preload("Group").
|
||||
Where("used_by = ?", userID).
|
||||
Order("used_at DESC").
|
||||
Limit(limit).
|
||||
Find(&codes).Error
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return codes, nil
|
||||
}
|
||||
74
backend/internal/repository/repository.go
Normal file
74
backend/internal/repository/repository.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// Repositories 所有仓库的集合
|
||||
type Repositories struct {
|
||||
User *UserRepository
|
||||
ApiKey *ApiKeyRepository
|
||||
Group *GroupRepository
|
||||
Account *AccountRepository
|
||||
Proxy *ProxyRepository
|
||||
RedeemCode *RedeemCodeRepository
|
||||
UsageLog *UsageLogRepository
|
||||
Setting *SettingRepository
|
||||
UserSubscription *UserSubscriptionRepository
|
||||
}
|
||||
|
||||
// NewRepositories 创建所有仓库
|
||||
func NewRepositories(db *gorm.DB) *Repositories {
|
||||
return &Repositories{
|
||||
User: NewUserRepository(db),
|
||||
ApiKey: NewApiKeyRepository(db),
|
||||
Group: NewGroupRepository(db),
|
||||
Account: NewAccountRepository(db),
|
||||
Proxy: NewProxyRepository(db),
|
||||
RedeemCode: NewRedeemCodeRepository(db),
|
||||
UsageLog: NewUsageLogRepository(db),
|
||||
Setting: NewSettingRepository(db),
|
||||
UserSubscription: NewUserSubscriptionRepository(db),
|
||||
}
|
||||
}
|
||||
|
||||
// PaginationParams 分页参数
|
||||
type PaginationParams struct {
|
||||
Page int
|
||||
PageSize int
|
||||
}
|
||||
|
||||
// PaginationResult 分页结果
|
||||
type PaginationResult struct {
|
||||
Total int64
|
||||
Page int
|
||||
PageSize int
|
||||
Pages int
|
||||
}
|
||||
|
||||
// DefaultPagination 默认分页参数
|
||||
func DefaultPagination() PaginationParams {
|
||||
return PaginationParams{
|
||||
Page: 1,
|
||||
PageSize: 20,
|
||||
}
|
||||
}
|
||||
|
||||
// Offset 计算偏移量
|
||||
func (p PaginationParams) Offset() int {
|
||||
if p.Page < 1 {
|
||||
p.Page = 1
|
||||
}
|
||||
return (p.Page - 1) * p.PageSize
|
||||
}
|
||||
|
||||
// Limit 获取限制数
|
||||
func (p PaginationParams) Limit() int {
|
||||
if p.PageSize < 1 {
|
||||
return 20
|
||||
}
|
||||
if p.PageSize > 100 {
|
||||
return 100
|
||||
}
|
||||
return p.PageSize
|
||||
}
|
||||
108
backend/internal/repository/setting_repo.go
Normal file
108
backend/internal/repository/setting_repo.go
Normal file
@@ -0,0 +1,108 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sub2api/internal/model"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
// SettingRepository 系统设置数据访问层
|
||||
type SettingRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewSettingRepository 创建系统设置仓库实例
|
||||
func NewSettingRepository(db *gorm.DB) *SettingRepository {
|
||||
return &SettingRepository{db: db}
|
||||
}
|
||||
|
||||
// Get 根据Key获取设置值
|
||||
func (r *SettingRepository) Get(ctx context.Context, key string) (*model.Setting, error) {
|
||||
var setting model.Setting
|
||||
err := r.db.WithContext(ctx).Where("key = ?", key).First(&setting).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &setting, nil
|
||||
}
|
||||
|
||||
// GetValue 获取设置值字符串
|
||||
func (r *SettingRepository) GetValue(ctx context.Context, key string) (string, error) {
|
||||
setting, err := r.Get(ctx, key)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return setting.Value, nil
|
||||
}
|
||||
|
||||
// Set 设置值(存在则更新,不存在则创建)
|
||||
func (r *SettingRepository) Set(ctx context.Context, key, value string) error {
|
||||
setting := &model.Setting{
|
||||
Key: key,
|
||||
Value: value,
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
|
||||
return r.db.WithContext(ctx).Clauses(clause.OnConflict{
|
||||
Columns: []clause.Column{{Name: "key"}},
|
||||
DoUpdates: clause.AssignmentColumns([]string{"value", "updated_at"}),
|
||||
}).Create(setting).Error
|
||||
}
|
||||
|
||||
// GetMultiple 批量获取设置
|
||||
func (r *SettingRepository) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
|
||||
var settings []model.Setting
|
||||
err := r.db.WithContext(ctx).Where("key IN ?", keys).Find(&settings).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := make(map[string]string)
|
||||
for _, s := range settings {
|
||||
result[s.Key] = s.Value
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// SetMultiple 批量设置值
|
||||
func (r *SettingRepository) SetMultiple(ctx context.Context, settings map[string]string) error {
|
||||
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
for key, value := range settings {
|
||||
setting := &model.Setting{
|
||||
Key: key,
|
||||
Value: value,
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
if err := tx.Clauses(clause.OnConflict{
|
||||
Columns: []clause.Column{{Name: "key"}},
|
||||
DoUpdates: clause.AssignmentColumns([]string{"value", "updated_at"}),
|
||||
}).Create(setting).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// GetAll 获取所有设置
|
||||
func (r *SettingRepository) GetAll(ctx context.Context) (map[string]string, error) {
|
||||
var settings []model.Setting
|
||||
err := r.db.WithContext(ctx).Find(&settings).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := make(map[string]string)
|
||||
for _, s := range settings {
|
||||
result[s.Key] = s.Value
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Delete 删除设置
|
||||
func (r *SettingRepository) Delete(ctx context.Context, key string) error {
|
||||
return r.db.WithContext(ctx).Where("key = ?", key).Delete(&model.Setting{}).Error
|
||||
}
|
||||
1006
backend/internal/repository/usage_log_repo.go
Normal file
1006
backend/internal/repository/usage_log_repo.go
Normal file
File diff suppressed because it is too large
Load Diff
130
backend/internal/repository/user_repo.go
Normal file
130
backend/internal/repository/user_repo.go
Normal file
@@ -0,0 +1,130 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sub2api/internal/model"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type UserRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewUserRepository(db *gorm.DB) *UserRepository {
|
||||
return &UserRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *UserRepository) Create(ctx context.Context, user *model.User) error {
|
||||
return r.db.WithContext(ctx).Create(user).Error
|
||||
}
|
||||
|
||||
func (r *UserRepository) GetByID(ctx context.Context, id int64) (*model.User, error) {
|
||||
var user model.User
|
||||
err := r.db.WithContext(ctx).First(&user, id).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (r *UserRepository) GetByEmail(ctx context.Context, email string) (*model.User, error) {
|
||||
var user model.User
|
||||
err := r.db.WithContext(ctx).Where("email = ?", email).First(&user).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (r *UserRepository) Update(ctx context.Context, user *model.User) error {
|
||||
return r.db.WithContext(ctx).Save(user).Error
|
||||
}
|
||||
|
||||
func (r *UserRepository) Delete(ctx context.Context, id int64) error {
|
||||
return r.db.WithContext(ctx).Delete(&model.User{}, id).Error
|
||||
}
|
||||
|
||||
func (r *UserRepository) List(ctx context.Context, params PaginationParams) ([]model.User, *PaginationResult, error) {
|
||||
return r.ListWithFilters(ctx, params, "", "", "")
|
||||
}
|
||||
|
||||
// ListWithFilters lists users with optional filtering by status, role, and search query
|
||||
func (r *UserRepository) ListWithFilters(ctx context.Context, params PaginationParams, status, role, search string) ([]model.User, *PaginationResult, error) {
|
||||
var users []model.User
|
||||
var total int64
|
||||
|
||||
db := r.db.WithContext(ctx).Model(&model.User{})
|
||||
|
||||
// Apply filters
|
||||
if status != "" {
|
||||
db = db.Where("status = ?", status)
|
||||
}
|
||||
if role != "" {
|
||||
db = db.Where("role = ?", role)
|
||||
}
|
||||
if search != "" {
|
||||
searchPattern := "%" + search + "%"
|
||||
db = db.Where("email ILIKE ?", searchPattern)
|
||||
}
|
||||
|
||||
if err := db.Count(&total).Error; err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if err := db.Offset(params.Offset()).Limit(params.Limit()).Order("id DESC").Find(&users).Error; err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
pages := int(total) / params.Limit()
|
||||
if int(total)%params.Limit() > 0 {
|
||||
pages++
|
||||
}
|
||||
|
||||
return users, &PaginationResult{
|
||||
Total: total,
|
||||
Page: params.Page,
|
||||
PageSize: params.Limit(),
|
||||
Pages: pages,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *UserRepository) UpdateBalance(ctx context.Context, id int64, amount float64) error {
|
||||
return r.db.WithContext(ctx).Model(&model.User{}).Where("id = ?", id).
|
||||
Update("balance", gorm.Expr("balance + ?", amount)).Error
|
||||
}
|
||||
|
||||
// DeductBalance 扣减用户余额,仅当余额充足时执行
|
||||
func (r *UserRepository) DeductBalance(ctx context.Context, id int64, amount float64) error {
|
||||
result := r.db.WithContext(ctx).Model(&model.User{}).
|
||||
Where("id = ? AND balance >= ?", id, amount).
|
||||
Update("balance", gorm.Expr("balance - ?", amount))
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
if result.RowsAffected == 0 {
|
||||
return gorm.ErrRecordNotFound // 余额不足或用户不存在
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *UserRepository) UpdateConcurrency(ctx context.Context, id int64, amount int) error {
|
||||
return r.db.WithContext(ctx).Model(&model.User{}).Where("id = ?", id).
|
||||
Update("concurrency", gorm.Expr("concurrency + ?", amount)).Error
|
||||
}
|
||||
|
||||
func (r *UserRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&model.User{}).Where("email = ?", email).Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
// RemoveGroupFromAllowedGroups 从所有用户的 allowed_groups 数组中移除指定的分组ID
|
||||
// 使用 PostgreSQL 的 array_remove 函数
|
||||
func (r *UserRepository) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) {
|
||||
result := r.db.WithContext(ctx).Model(&model.User{}).
|
||||
Where("? = ANY(allowed_groups)", groupID).
|
||||
Update("allowed_groups", gorm.Expr("array_remove(allowed_groups, ?)", groupID))
|
||||
return result.RowsAffected, result.Error
|
||||
}
|
||||
|
||||
322
backend/internal/repository/user_subscription_repo.go
Normal file
322
backend/internal/repository/user_subscription_repo.go
Normal file
@@ -0,0 +1,322 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/model"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// UserSubscriptionRepository 用户订阅仓库
|
||||
type UserSubscriptionRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewUserSubscriptionRepository 创建用户订阅仓库
|
||||
func NewUserSubscriptionRepository(db *gorm.DB) *UserSubscriptionRepository {
|
||||
return &UserSubscriptionRepository{db: db}
|
||||
}
|
||||
|
||||
// Create 创建订阅
|
||||
func (r *UserSubscriptionRepository) Create(ctx context.Context, sub *model.UserSubscription) error {
|
||||
return r.db.WithContext(ctx).Create(sub).Error
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取订阅
|
||||
func (r *UserSubscriptionRepository) GetByID(ctx context.Context, id int64) (*model.UserSubscription, error) {
|
||||
var sub model.UserSubscription
|
||||
err := r.db.WithContext(ctx).
|
||||
Preload("User").
|
||||
Preload("Group").
|
||||
Preload("AssignedByUser").
|
||||
First(&sub, id).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &sub, nil
|
||||
}
|
||||
|
||||
// GetByUserIDAndGroupID 根据用户ID和分组ID获取订阅
|
||||
func (r *UserSubscriptionRepository) GetByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error) {
|
||||
var sub model.UserSubscription
|
||||
err := r.db.WithContext(ctx).
|
||||
Preload("Group").
|
||||
Where("user_id = ? AND group_id = ?", userID, groupID).
|
||||
First(&sub).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &sub, nil
|
||||
}
|
||||
|
||||
// GetActiveByUserIDAndGroupID 获取用户对特定分组的有效订阅
|
||||
func (r *UserSubscriptionRepository) GetActiveByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error) {
|
||||
var sub model.UserSubscription
|
||||
err := r.db.WithContext(ctx).
|
||||
Preload("Group").
|
||||
Where("user_id = ? AND group_id = ? AND status = ? AND expires_at > ?",
|
||||
userID, groupID, model.SubscriptionStatusActive, time.Now()).
|
||||
First(&sub).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &sub, nil
|
||||
}
|
||||
|
||||
// Update 更新订阅
|
||||
func (r *UserSubscriptionRepository) Update(ctx context.Context, sub *model.UserSubscription) error {
|
||||
sub.UpdatedAt = time.Now()
|
||||
return r.db.WithContext(ctx).Save(sub).Error
|
||||
}
|
||||
|
||||
// Delete 删除订阅
|
||||
func (r *UserSubscriptionRepository) Delete(ctx context.Context, id int64) error {
|
||||
return r.db.WithContext(ctx).Delete(&model.UserSubscription{}, id).Error
|
||||
}
|
||||
|
||||
// ListByUserID 获取用户的所有订阅
|
||||
func (r *UserSubscriptionRepository) ListByUserID(ctx context.Context, userID int64) ([]model.UserSubscription, error) {
|
||||
var subs []model.UserSubscription
|
||||
err := r.db.WithContext(ctx).
|
||||
Preload("Group").
|
||||
Where("user_id = ?", userID).
|
||||
Order("created_at DESC").
|
||||
Find(&subs).Error
|
||||
return subs, err
|
||||
}
|
||||
|
||||
// ListActiveByUserID 获取用户的所有有效订阅
|
||||
func (r *UserSubscriptionRepository) ListActiveByUserID(ctx context.Context, userID int64) ([]model.UserSubscription, error) {
|
||||
var subs []model.UserSubscription
|
||||
err := r.db.WithContext(ctx).
|
||||
Preload("Group").
|
||||
Where("user_id = ? AND status = ? AND expires_at > ?",
|
||||
userID, model.SubscriptionStatusActive, time.Now()).
|
||||
Order("created_at DESC").
|
||||
Find(&subs).Error
|
||||
return subs, err
|
||||
}
|
||||
|
||||
// ListByGroupID 获取分组的所有订阅(分页)
|
||||
func (r *UserSubscriptionRepository) ListByGroupID(ctx context.Context, groupID int64, params PaginationParams) ([]model.UserSubscription, *PaginationResult, error) {
|
||||
var subs []model.UserSubscription
|
||||
var total int64
|
||||
|
||||
query := r.db.WithContext(ctx).Model(&model.UserSubscription{}).Where("group_id = ?", groupID)
|
||||
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
err := query.
|
||||
Preload("User").
|
||||
Preload("Group").
|
||||
Order("created_at DESC").
|
||||
Offset(params.Offset()).
|
||||
Limit(params.Limit()).
|
||||
Find(&subs).Error
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
pages := int(total) / params.Limit()
|
||||
if int(total)%params.Limit() > 0 {
|
||||
pages++
|
||||
}
|
||||
|
||||
return subs, &PaginationResult{
|
||||
Total: total,
|
||||
Page: params.Page,
|
||||
PageSize: params.Limit(),
|
||||
Pages: pages,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// List 获取所有订阅(分页,支持筛选)
|
||||
func (r *UserSubscriptionRepository) List(ctx context.Context, params PaginationParams, userID, groupID *int64, status string) ([]model.UserSubscription, *PaginationResult, error) {
|
||||
var subs []model.UserSubscription
|
||||
var total int64
|
||||
|
||||
query := r.db.WithContext(ctx).Model(&model.UserSubscription{})
|
||||
|
||||
if userID != nil {
|
||||
query = query.Where("user_id = ?", *userID)
|
||||
}
|
||||
if groupID != nil {
|
||||
query = query.Where("group_id = ?", *groupID)
|
||||
}
|
||||
if status != "" {
|
||||
query = query.Where("status = ?", status)
|
||||
}
|
||||
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
err := query.
|
||||
Preload("User").
|
||||
Preload("Group").
|
||||
Preload("AssignedByUser").
|
||||
Order("created_at DESC").
|
||||
Offset(params.Offset()).
|
||||
Limit(params.Limit()).
|
||||
Find(&subs).Error
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
pages := int(total) / params.Limit()
|
||||
if int(total)%params.Limit() > 0 {
|
||||
pages++
|
||||
}
|
||||
|
||||
return subs, &PaginationResult{
|
||||
Total: total,
|
||||
Page: params.Page,
|
||||
PageSize: params.Limit(),
|
||||
Pages: pages,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// IncrementUsage 增加使用量
|
||||
func (r *UserSubscriptionRepository) IncrementUsage(ctx context.Context, id int64, costUSD float64) error {
|
||||
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
||||
Where("id = ?", id).
|
||||
Updates(map[string]interface{}{
|
||||
"daily_usage_usd": gorm.Expr("daily_usage_usd + ?", costUSD),
|
||||
"weekly_usage_usd": gorm.Expr("weekly_usage_usd + ?", costUSD),
|
||||
"monthly_usage_usd": gorm.Expr("monthly_usage_usd + ?", costUSD),
|
||||
"updated_at": time.Now(),
|
||||
}).Error
|
||||
}
|
||||
|
||||
// ResetDailyUsage 重置日使用量
|
||||
func (r *UserSubscriptionRepository) ResetDailyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
|
||||
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
||||
Where("id = ?", id).
|
||||
Updates(map[string]interface{}{
|
||||
"daily_usage_usd": 0,
|
||||
"daily_window_start": newWindowStart,
|
||||
"updated_at": time.Now(),
|
||||
}).Error
|
||||
}
|
||||
|
||||
// ResetWeeklyUsage 重置周使用量
|
||||
func (r *UserSubscriptionRepository) ResetWeeklyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
|
||||
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
||||
Where("id = ?", id).
|
||||
Updates(map[string]interface{}{
|
||||
"weekly_usage_usd": 0,
|
||||
"weekly_window_start": newWindowStart,
|
||||
"updated_at": time.Now(),
|
||||
}).Error
|
||||
}
|
||||
|
||||
// ResetMonthlyUsage 重置月使用量
|
||||
func (r *UserSubscriptionRepository) ResetMonthlyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
|
||||
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
||||
Where("id = ?", id).
|
||||
Updates(map[string]interface{}{
|
||||
"monthly_usage_usd": 0,
|
||||
"monthly_window_start": newWindowStart,
|
||||
"updated_at": time.Now(),
|
||||
}).Error
|
||||
}
|
||||
|
||||
// ActivateWindows 激活所有窗口(首次使用时)
|
||||
func (r *UserSubscriptionRepository) ActivateWindows(ctx context.Context, id int64, activateTime time.Time) error {
|
||||
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
||||
Where("id = ?", id).
|
||||
Updates(map[string]interface{}{
|
||||
"daily_window_start": activateTime,
|
||||
"weekly_window_start": activateTime,
|
||||
"monthly_window_start": activateTime,
|
||||
"updated_at": time.Now(),
|
||||
}).Error
|
||||
}
|
||||
|
||||
// UpdateStatus 更新订阅状态
|
||||
func (r *UserSubscriptionRepository) UpdateStatus(ctx context.Context, id int64, status string) error {
|
||||
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
||||
Where("id = ?", id).
|
||||
Updates(map[string]interface{}{
|
||||
"status": status,
|
||||
"updated_at": time.Now(),
|
||||
}).Error
|
||||
}
|
||||
|
||||
// ExtendExpiry 延长订阅过期时间
|
||||
func (r *UserSubscriptionRepository) ExtendExpiry(ctx context.Context, id int64, newExpiresAt time.Time) error {
|
||||
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
||||
Where("id = ?", id).
|
||||
Updates(map[string]interface{}{
|
||||
"expires_at": newExpiresAt,
|
||||
"updated_at": time.Now(),
|
||||
}).Error
|
||||
}
|
||||
|
||||
// UpdateNotes 更新订阅备注
|
||||
func (r *UserSubscriptionRepository) UpdateNotes(ctx context.Context, id int64, notes string) error {
|
||||
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
||||
Where("id = ?", id).
|
||||
Updates(map[string]interface{}{
|
||||
"notes": notes,
|
||||
"updated_at": time.Now(),
|
||||
}).Error
|
||||
}
|
||||
|
||||
// ListExpired 获取所有已过期但状态仍为active的订阅
|
||||
func (r *UserSubscriptionRepository) ListExpired(ctx context.Context) ([]model.UserSubscription, error) {
|
||||
var subs []model.UserSubscription
|
||||
err := r.db.WithContext(ctx).
|
||||
Where("status = ? AND expires_at <= ?", model.SubscriptionStatusActive, time.Now()).
|
||||
Find(&subs).Error
|
||||
return subs, err
|
||||
}
|
||||
|
||||
// BatchUpdateExpiredStatus 批量更新过期订阅状态
|
||||
func (r *UserSubscriptionRepository) BatchUpdateExpiredStatus(ctx context.Context) (int64, error) {
|
||||
result := r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
||||
Where("status = ? AND expires_at <= ?", model.SubscriptionStatusActive, time.Now()).
|
||||
Updates(map[string]interface{}{
|
||||
"status": model.SubscriptionStatusExpired,
|
||||
"updated_at": time.Now(),
|
||||
})
|
||||
return result.RowsAffected, result.Error
|
||||
}
|
||||
|
||||
// ExistsByUserIDAndGroupID 检查用户是否已有该分组的订阅
|
||||
func (r *UserSubscriptionRepository) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
||||
Where("user_id = ? AND group_id = ?", userID, groupID).
|
||||
Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
// CountByGroupID 获取分组的订阅数量
|
||||
func (r *UserSubscriptionRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
||||
Where("group_id = ?", groupID).
|
||||
Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// CountActiveByGroupID 获取分组的有效订阅数量
|
||||
func (r *UserSubscriptionRepository) CountActiveByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
||||
Where("group_id = ? AND status = ? AND expires_at > ?",
|
||||
groupID, model.SubscriptionStatusActive, time.Now()).
|
||||
Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// DeleteByGroupID 删除分组相关的所有订阅记录
|
||||
func (r *UserSubscriptionRepository) DeleteByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
result := r.db.WithContext(ctx).Where("group_id = ?", groupID).Delete(&model.UserSubscription{})
|
||||
return result.RowsAffected, result.Error
|
||||
}
|
||||
284
backend/internal/service/account_service.go
Normal file
284
backend/internal/service/account_service.go
Normal file
@@ -0,0 +1,284 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/repository"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrAccountNotFound = errors.New("account not found")
|
||||
)
|
||||
|
||||
// CreateAccountRequest 创建账号请求
|
||||
type CreateAccountRequest struct {
|
||||
Name string `json:"name"`
|
||||
Platform string `json:"platform"`
|
||||
Type string `json:"type"`
|
||||
Credentials map[string]interface{} `json:"credentials"`
|
||||
Extra map[string]interface{} `json:"extra"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
Concurrency int `json:"concurrency"`
|
||||
Priority int `json:"priority"`
|
||||
GroupIDs []int64 `json:"group_ids"`
|
||||
}
|
||||
|
||||
// UpdateAccountRequest 更新账号请求
|
||||
type UpdateAccountRequest struct {
|
||||
Name *string `json:"name"`
|
||||
Credentials *map[string]interface{} `json:"credentials"`
|
||||
Extra *map[string]interface{} `json:"extra"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
Concurrency *int `json:"concurrency"`
|
||||
Priority *int `json:"priority"`
|
||||
Status *string `json:"status"`
|
||||
GroupIDs *[]int64 `json:"group_ids"`
|
||||
}
|
||||
|
||||
// AccountService 账号管理服务
|
||||
type AccountService struct {
|
||||
accountRepo *repository.AccountRepository
|
||||
groupRepo *repository.GroupRepository
|
||||
}
|
||||
|
||||
// NewAccountService 创建账号服务实例
|
||||
func NewAccountService(accountRepo *repository.AccountRepository, groupRepo *repository.GroupRepository) *AccountService {
|
||||
return &AccountService{
|
||||
accountRepo: accountRepo,
|
||||
groupRepo: groupRepo,
|
||||
}
|
||||
}
|
||||
|
||||
// Create 创建账号
|
||||
func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) (*model.Account, error) {
|
||||
// 验证分组是否存在(如果指定了分组)
|
||||
if len(req.GroupIDs) > 0 {
|
||||
for _, groupID := range req.GroupIDs {
|
||||
_, err := s.groupRepo.GetByID(ctx, groupID)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, fmt.Errorf("group %d not found", groupID)
|
||||
}
|
||||
return nil, fmt.Errorf("get group: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 创建账号
|
||||
account := &model.Account{
|
||||
Name: req.Name,
|
||||
Platform: req.Platform,
|
||||
Type: req.Type,
|
||||
Credentials: req.Credentials,
|
||||
Extra: req.Extra,
|
||||
ProxyID: req.ProxyID,
|
||||
Concurrency: req.Concurrency,
|
||||
Priority: req.Priority,
|
||||
Status: model.StatusActive,
|
||||
}
|
||||
|
||||
if err := s.accountRepo.Create(ctx, account); err != nil {
|
||||
return nil, fmt.Errorf("create account: %w", err)
|
||||
}
|
||||
|
||||
// 绑定分组
|
||||
if len(req.GroupIDs) > 0 {
|
||||
if err := s.accountRepo.BindGroups(ctx, account.ID, req.GroupIDs); err != nil {
|
||||
return nil, fmt.Errorf("bind groups: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return account, nil
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取账号
|
||||
func (s *AccountService) GetByID(ctx context.Context, id int64) (*model.Account, error) {
|
||||
account, err := s.accountRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrAccountNotFound
|
||||
}
|
||||
return nil, fmt.Errorf("get account: %w", err)
|
||||
}
|
||||
return account, nil
|
||||
}
|
||||
|
||||
// List 获取账号列表
|
||||
func (s *AccountService) List(ctx context.Context, params repository.PaginationParams) ([]model.Account, *repository.PaginationResult, error) {
|
||||
accounts, pagination, err := s.accountRepo.List(ctx, params)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("list accounts: %w", err)
|
||||
}
|
||||
return accounts, pagination, nil
|
||||
}
|
||||
|
||||
// ListByPlatform 根据平台获取账号列表
|
||||
func (s *AccountService) ListByPlatform(ctx context.Context, platform string) ([]model.Account, error) {
|
||||
accounts, err := s.accountRepo.ListByPlatform(ctx, platform)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list accounts by platform: %w", err)
|
||||
}
|
||||
return accounts, nil
|
||||
}
|
||||
|
||||
// ListByGroup 根据分组获取账号列表
|
||||
func (s *AccountService) ListByGroup(ctx context.Context, groupID int64) ([]model.Account, error) {
|
||||
accounts, err := s.accountRepo.ListByGroup(ctx, groupID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list accounts by group: %w", err)
|
||||
}
|
||||
return accounts, nil
|
||||
}
|
||||
|
||||
// Update 更新账号
|
||||
func (s *AccountService) Update(ctx context.Context, id int64, req UpdateAccountRequest) (*model.Account, error) {
|
||||
account, err := s.accountRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrAccountNotFound
|
||||
}
|
||||
return nil, fmt.Errorf("get account: %w", err)
|
||||
}
|
||||
|
||||
// 更新字段
|
||||
if req.Name != nil {
|
||||
account.Name = *req.Name
|
||||
}
|
||||
|
||||
if req.Credentials != nil {
|
||||
account.Credentials = *req.Credentials
|
||||
}
|
||||
|
||||
if req.Extra != nil {
|
||||
account.Extra = *req.Extra
|
||||
}
|
||||
|
||||
if req.ProxyID != nil {
|
||||
account.ProxyID = req.ProxyID
|
||||
}
|
||||
|
||||
if req.Concurrency != nil {
|
||||
account.Concurrency = *req.Concurrency
|
||||
}
|
||||
|
||||
if req.Priority != nil {
|
||||
account.Priority = *req.Priority
|
||||
}
|
||||
|
||||
if req.Status != nil {
|
||||
account.Status = *req.Status
|
||||
}
|
||||
|
||||
if err := s.accountRepo.Update(ctx, account); err != nil {
|
||||
return nil, fmt.Errorf("update account: %w", err)
|
||||
}
|
||||
|
||||
// 更新分组绑定
|
||||
if req.GroupIDs != nil {
|
||||
// 验证分组是否存在
|
||||
for _, groupID := range *req.GroupIDs {
|
||||
_, err := s.groupRepo.GetByID(ctx, groupID)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, fmt.Errorf("group %d not found", groupID)
|
||||
}
|
||||
return nil, fmt.Errorf("get group: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.accountRepo.BindGroups(ctx, account.ID, *req.GroupIDs); err != nil {
|
||||
return nil, fmt.Errorf("bind groups: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return account, nil
|
||||
}
|
||||
|
||||
// Delete 删除账号
|
||||
func (s *AccountService) Delete(ctx context.Context, id int64) error {
|
||||
// 检查账号是否存在
|
||||
_, err := s.accountRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return ErrAccountNotFound
|
||||
}
|
||||
return fmt.Errorf("get account: %w", err)
|
||||
}
|
||||
|
||||
if err := s.accountRepo.Delete(ctx, id); err != nil {
|
||||
return fmt.Errorf("delete account: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateStatus 更新账号状态
|
||||
func (s *AccountService) UpdateStatus(ctx context.Context, id int64, status string, errorMessage string) error {
|
||||
account, err := s.accountRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return ErrAccountNotFound
|
||||
}
|
||||
return fmt.Errorf("get account: %w", err)
|
||||
}
|
||||
|
||||
account.Status = status
|
||||
account.ErrorMessage = errorMessage
|
||||
|
||||
if err := s.accountRepo.Update(ctx, account); err != nil {
|
||||
return fmt.Errorf("update account: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateLastUsed 更新最后使用时间
|
||||
func (s *AccountService) UpdateLastUsed(ctx context.Context, id int64) error {
|
||||
if err := s.accountRepo.UpdateLastUsed(ctx, id); err != nil {
|
||||
return fmt.Errorf("update last used: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetCredential 获取账号凭证(安全访问)
|
||||
func (s *AccountService) GetCredential(ctx context.Context, id int64, key string) (string, error) {
|
||||
account, err := s.accountRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return "", ErrAccountNotFound
|
||||
}
|
||||
return "", fmt.Errorf("get account: %w", err)
|
||||
}
|
||||
|
||||
return account.GetCredential(key), nil
|
||||
}
|
||||
|
||||
// TestCredentials 测试账号凭证是否有效(需要实现具体平台的测试逻辑)
|
||||
func (s *AccountService) TestCredentials(ctx context.Context, id int64) error {
|
||||
account, err := s.accountRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return ErrAccountNotFound
|
||||
}
|
||||
return fmt.Errorf("get account: %w", err)
|
||||
}
|
||||
|
||||
// 根据平台执行不同的测试逻辑
|
||||
switch account.Platform {
|
||||
case model.PlatformAnthropic:
|
||||
// TODO: 测试Anthropic API凭证
|
||||
return nil
|
||||
case model.PlatformOpenAI:
|
||||
// TODO: 测试OpenAI API凭证
|
||||
return nil
|
||||
case model.PlatformGemini:
|
||||
// TODO: 测试Gemini API凭证
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("unsupported platform: %s", account.Platform)
|
||||
}
|
||||
}
|
||||
314
backend/internal/service/account_test_service.go
Normal file
314
backend/internal/service/account_test_service.go
Normal file
@@ -0,0 +1,314 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/repository"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const (
|
||||
testClaudeAPIURL = "https://api.anthropic.com/v1/messages"
|
||||
testModel = "claude-sonnet-4-5-20250929"
|
||||
)
|
||||
|
||||
// TestEvent represents a SSE event for account testing
|
||||
type TestEvent struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Success bool `json:"success,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// AccountTestService handles account testing operations
|
||||
type AccountTestService struct {
|
||||
repos *repository.Repositories
|
||||
oauthService *OAuthService
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// NewAccountTestService creates a new AccountTestService
|
||||
func NewAccountTestService(repos *repository.Repositories, oauthService *OAuthService) *AccountTestService {
|
||||
return &AccountTestService{
|
||||
repos: repos,
|
||||
oauthService: oauthService,
|
||||
httpClient: &http.Client{
|
||||
Timeout: 60 * time.Second,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// generateSessionString generates a Claude Code style session string
|
||||
func generateSessionString() string {
|
||||
bytes := make([]byte, 32)
|
||||
rand.Read(bytes)
|
||||
hex64 := hex.EncodeToString(bytes)
|
||||
sessionUUID := uuid.New().String()
|
||||
return fmt.Sprintf("user_%s_account__session_%s", hex64, sessionUUID)
|
||||
}
|
||||
|
||||
// createTestPayload creates a minimal test request payload for OAuth/Setup Token accounts
|
||||
func createTestPayload() map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
"model": testModel,
|
||||
"messages": []map[string]interface{}{
|
||||
{
|
||||
"role": "user",
|
||||
"content": []map[string]interface{}{
|
||||
{
|
||||
"type": "text",
|
||||
"text": "hi",
|
||||
"cache_control": map[string]string{
|
||||
"type": "ephemeral",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"system": []map[string]interface{}{
|
||||
{
|
||||
"type": "text",
|
||||
"text": "You are Claude Code, Anthropic's official CLI for Claude.",
|
||||
"cache_control": map[string]string{
|
||||
"type": "ephemeral",
|
||||
},
|
||||
},
|
||||
},
|
||||
"metadata": map[string]string{
|
||||
"user_id": generateSessionString(),
|
||||
},
|
||||
"max_tokens": 1024,
|
||||
"temperature": 1,
|
||||
"stream": true,
|
||||
}
|
||||
}
|
||||
|
||||
// createApiKeyTestPayload creates a simpler test request payload for API Key accounts
|
||||
func createApiKeyTestPayload(model string) map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
"model": model,
|
||||
"messages": []map[string]interface{}{
|
||||
{
|
||||
"role": "user",
|
||||
"content": "hi",
|
||||
},
|
||||
},
|
||||
"max_tokens": 1024,
|
||||
"stream": true,
|
||||
}
|
||||
}
|
||||
|
||||
// TestAccountConnection tests an account's connection by sending a test request
|
||||
func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int64) error {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// Get account
|
||||
account, err := s.repos.Account.GetByID(ctx, accountID)
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, "Account not found")
|
||||
}
|
||||
|
||||
// Determine authentication method based on account type
|
||||
var authToken string
|
||||
var authType string // "bearer" for OAuth, "apikey" for API Key
|
||||
var apiURL string
|
||||
|
||||
if account.IsOAuth() {
|
||||
// OAuth or Setup Token account
|
||||
authType = "bearer"
|
||||
apiURL = testClaudeAPIURL
|
||||
authToken = account.GetCredential("access_token")
|
||||
if authToken == "" {
|
||||
return s.sendErrorAndEnd(c, "No access token available")
|
||||
}
|
||||
|
||||
// Check if token needs refresh
|
||||
needRefresh := false
|
||||
if expiresAtStr := account.GetCredential("expires_at"); expiresAtStr != "" {
|
||||
expiresAt, err := strconv.ParseInt(expiresAtStr, 10, 64)
|
||||
if err == nil && time.Now().Unix()+300 > expiresAt { // 5 minute buffer
|
||||
needRefresh = true
|
||||
}
|
||||
}
|
||||
|
||||
if needRefresh && s.oauthService != nil {
|
||||
tokenInfo, err := s.oauthService.RefreshAccountToken(ctx, account)
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to refresh token: %s", err.Error()))
|
||||
}
|
||||
authToken = tokenInfo.AccessToken
|
||||
}
|
||||
} else if account.Type == "apikey" {
|
||||
// API Key account
|
||||
authType = "apikey"
|
||||
authToken = account.GetCredential("api_key")
|
||||
if authToken == "" {
|
||||
return s.sendErrorAndEnd(c, "No API key available")
|
||||
}
|
||||
|
||||
// Get base URL (use default if not set)
|
||||
apiURL = account.GetBaseURL()
|
||||
if apiURL == "" {
|
||||
apiURL = "https://api.anthropic.com"
|
||||
}
|
||||
// Append /v1/messages endpoint
|
||||
apiURL = strings.TrimSuffix(apiURL, "/") + "/v1/messages"
|
||||
} else {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type))
|
||||
}
|
||||
|
||||
// Set SSE headers
|
||||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||
c.Writer.Header().Set("Connection", "keep-alive")
|
||||
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
||||
c.Writer.Flush()
|
||||
|
||||
// Create test request payload
|
||||
var payload map[string]interface{}
|
||||
var actualModel string
|
||||
if authType == "apikey" {
|
||||
// Use simpler payload for API Key (without Claude Code specific fields)
|
||||
// Apply model mapping if configured
|
||||
actualModel = account.GetMappedModel(testModel)
|
||||
payload = createApiKeyTestPayload(actualModel)
|
||||
} else {
|
||||
actualModel = testModel
|
||||
payload = createTestPayload()
|
||||
}
|
||||
payloadBytes, _ := json.Marshal(payload)
|
||||
|
||||
// Send test_start event with model info
|
||||
s.sendEvent(c, TestEvent{Type: "test_start", Model: actualModel})
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewReader(payloadBytes))
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, "Failed to create request")
|
||||
}
|
||||
|
||||
// Set headers based on auth type
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("anthropic-version", "2023-06-01")
|
||||
|
||||
if authType == "bearer" {
|
||||
req.Header.Set("Authorization", "Bearer "+authToken)
|
||||
req.Header.Set("anthropic-beta", "prompt-caching-2024-07-31,interleaved-thinking-2025-05-14,output-128k-2025-02-19")
|
||||
} else {
|
||||
// API Key uses x-api-key header
|
||||
req.Header.Set("x-api-key", authToken)
|
||||
}
|
||||
|
||||
// Configure proxy if account has one
|
||||
transport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
proxyURL := account.Proxy.URL()
|
||||
if proxyURL != "" {
|
||||
if parsedURL, err := url.Parse(proxyURL); err == nil {
|
||||
transport.Proxy = http.ProxyURL(parsedURL)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
client := &http.Client{
|
||||
Transport: transport,
|
||||
Timeout: 60 * time.Second,
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("API returned %d: %s", resp.StatusCode, string(body)))
|
||||
}
|
||||
|
||||
// Process SSE stream
|
||||
return s.processStream(c, resp.Body)
|
||||
}
|
||||
|
||||
// processStream processes the SSE stream from Claude API
|
||||
func (s *AccountTestService) processStream(c *gin.Context, body io.Reader) error {
|
||||
reader := bufio.NewReader(body)
|
||||
|
||||
for {
|
||||
line, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
// Stream ended, send complete event
|
||||
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
|
||||
return nil
|
||||
}
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Stream read error: %s", err.Error()))
|
||||
}
|
||||
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" || !strings.HasPrefix(line, "data: ") {
|
||||
continue
|
||||
}
|
||||
|
||||
jsonStr := strings.TrimPrefix(line, "data: ")
|
||||
if jsonStr == "[DONE]" {
|
||||
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
|
||||
return nil
|
||||
}
|
||||
|
||||
var data map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(jsonStr), &data); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
eventType, _ := data["type"].(string)
|
||||
|
||||
switch eventType {
|
||||
case "content_block_delta":
|
||||
if delta, ok := data["delta"].(map[string]interface{}); ok {
|
||||
if text, ok := delta["text"].(string); ok {
|
||||
s.sendEvent(c, TestEvent{Type: "content", Text: text})
|
||||
}
|
||||
}
|
||||
case "message_stop":
|
||||
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
|
||||
return nil
|
||||
case "error":
|
||||
errorMsg := "Unknown error"
|
||||
if errData, ok := data["error"].(map[string]interface{}); ok {
|
||||
if msg, ok := errData["message"].(string); ok {
|
||||
errorMsg = msg
|
||||
}
|
||||
}
|
||||
return s.sendErrorAndEnd(c, errorMsg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// sendEvent sends a SSE event to the client
|
||||
func (s *AccountTestService) sendEvent(c *gin.Context, event TestEvent) {
|
||||
eventJSON, _ := json.Marshal(event)
|
||||
fmt.Fprintf(c.Writer, "data: %s\n\n", eventJSON)
|
||||
c.Writer.Flush()
|
||||
}
|
||||
|
||||
// sendErrorAndEnd sends an error event and ends the stream
|
||||
func (s *AccountTestService) sendErrorAndEnd(c *gin.Context, errorMsg string) error {
|
||||
log.Printf("Account test error: %s", errorMsg)
|
||||
s.sendEvent(c, TestEvent{Type: "error", Error: errorMsg})
|
||||
return fmt.Errorf(errorMsg)
|
||||
}
|
||||
345
backend/internal/service/account_usage_service.go
Normal file
345
backend/internal/service/account_usage_service.go
Normal file
@@ -0,0 +1,345 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/repository"
|
||||
)
|
||||
|
||||
// usageCache 用于缓存usage数据
|
||||
type usageCache struct {
|
||||
data *UsageInfo
|
||||
timestamp time.Time
|
||||
}
|
||||
|
||||
var (
|
||||
usageCacheMap = sync.Map{}
|
||||
cacheTTL = 10 * time.Minute
|
||||
)
|
||||
|
||||
// WindowStats 窗口期统计
|
||||
type WindowStats struct {
|
||||
Requests int64 `json:"requests"`
|
||||
Tokens int64 `json:"tokens"`
|
||||
Cost float64 `json:"cost"`
|
||||
}
|
||||
|
||||
// UsageProgress 使用量进度
|
||||
type UsageProgress struct {
|
||||
Utilization float64 `json:"utilization"` // 使用率百分比 (0-100+,100表示100%)
|
||||
ResetsAt *time.Time `json:"resets_at"` // 重置时间
|
||||
RemainingSeconds int `json:"remaining_seconds"` // 距重置剩余秒数
|
||||
WindowStats *WindowStats `json:"window_stats,omitempty"` // 窗口期统计(从窗口开始到当前的使用量)
|
||||
}
|
||||
|
||||
// UsageInfo 账号使用量信息
|
||||
type UsageInfo struct {
|
||||
UpdatedAt *time.Time `json:"updated_at,omitempty"` // 更新时间
|
||||
FiveHour *UsageProgress `json:"five_hour"` // 5小时窗口
|
||||
SevenDay *UsageProgress `json:"seven_day,omitempty"` // 7天窗口
|
||||
SevenDaySonnet *UsageProgress `json:"seven_day_sonnet,omitempty"` // 7天Sonnet窗口
|
||||
}
|
||||
|
||||
// ClaudeUsageResponse Anthropic API返回的usage结构
|
||||
type ClaudeUsageResponse struct {
|
||||
FiveHour struct {
|
||||
Utilization float64 `json:"utilization"`
|
||||
ResetsAt string `json:"resets_at"`
|
||||
} `json:"five_hour"`
|
||||
SevenDay struct {
|
||||
Utilization float64 `json:"utilization"`
|
||||
ResetsAt string `json:"resets_at"`
|
||||
} `json:"seven_day"`
|
||||
SevenDaySonnet struct {
|
||||
Utilization float64 `json:"utilization"`
|
||||
ResetsAt string `json:"resets_at"`
|
||||
} `json:"seven_day_sonnet"`
|
||||
}
|
||||
|
||||
// AccountUsageService 账号使用量查询服务
|
||||
type AccountUsageService struct {
|
||||
repos *repository.Repositories
|
||||
oauthService *OAuthService
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// NewAccountUsageService 创建AccountUsageService实例
|
||||
func NewAccountUsageService(repos *repository.Repositories, oauthService *OAuthService) *AccountUsageService {
|
||||
return &AccountUsageService{
|
||||
repos: repos,
|
||||
oauthService: oauthService,
|
||||
httpClient: &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// GetUsage 获取账号使用量
|
||||
// OAuth账号: 调用Anthropic API获取真实数据(需要profile scope),缓存10分钟
|
||||
// Setup Token账号: 根据session_window推算5h窗口,7d数据不可用(没有profile scope)
|
||||
// API Key账号: 不支持usage查询
|
||||
func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*UsageInfo, error) {
|
||||
account, err := s.repos.Account.GetByID(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get account failed: %w", err)
|
||||
}
|
||||
|
||||
// 只有oauth类型账号可以通过API获取usage(有profile scope)
|
||||
if account.CanGetUsage() {
|
||||
// 检查缓存
|
||||
if cached, ok := usageCacheMap.Load(accountID); ok {
|
||||
cache := cached.(*usageCache)
|
||||
if time.Since(cache.timestamp) < cacheTTL {
|
||||
return cache.data, nil
|
||||
}
|
||||
}
|
||||
|
||||
// 从API获取数据
|
||||
usage, err := s.fetchOAuthUsage(ctx, account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 添加5h窗口统计数据
|
||||
s.addWindowStats(ctx, account, usage)
|
||||
|
||||
// 缓存结果
|
||||
usageCacheMap.Store(accountID, &usageCache{
|
||||
data: usage,
|
||||
timestamp: time.Now(),
|
||||
})
|
||||
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
// Setup Token账号:根据session_window推算(没有profile scope,无法调用usage API)
|
||||
if account.Type == model.AccountTypeSetupToken {
|
||||
usage := s.estimateSetupTokenUsage(account)
|
||||
// 添加窗口统计
|
||||
s.addWindowStats(ctx, account, usage)
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
// API Key账号不支持usage查询
|
||||
return nil, fmt.Errorf("account type %s does not support usage query", account.Type)
|
||||
}
|
||||
|
||||
// addWindowStats 为usage数据添加窗口期统计
|
||||
func (s *AccountUsageService) addWindowStats(ctx context.Context, account *model.Account, usage *UsageInfo) {
|
||||
if usage.FiveHour == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 使用session_window_start作为统计起始时间
|
||||
var startTime time.Time
|
||||
if account.SessionWindowStart != nil {
|
||||
startTime = *account.SessionWindowStart
|
||||
} else {
|
||||
// 如果没有窗口信息,使用5小时前作为默认
|
||||
startTime = time.Now().Add(-5 * time.Hour)
|
||||
}
|
||||
|
||||
stats, err := s.repos.UsageLog.GetAccountWindowStats(ctx, account.ID, startTime)
|
||||
if err != nil {
|
||||
log.Printf("Failed to get window stats for account %d: %v", account.ID, err)
|
||||
return
|
||||
}
|
||||
|
||||
usage.FiveHour.WindowStats = &WindowStats{
|
||||
Requests: stats.Requests,
|
||||
Tokens: stats.Tokens,
|
||||
Cost: stats.Cost,
|
||||
}
|
||||
}
|
||||
|
||||
// GetTodayStats 获取账号今日统计
|
||||
func (s *AccountUsageService) GetTodayStats(ctx context.Context, accountID int64) (*WindowStats, error) {
|
||||
stats, err := s.repos.UsageLog.GetAccountTodayStats(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get today stats failed: %w", err)
|
||||
}
|
||||
|
||||
return &WindowStats{
|
||||
Requests: stats.Requests,
|
||||
Tokens: stats.Tokens,
|
||||
Cost: stats.Cost,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// fetchOAuthUsage 从Anthropic API获取OAuth账号的使用量
|
||||
func (s *AccountUsageService) fetchOAuthUsage(ctx context.Context, account *model.Account) (*UsageInfo, error) {
|
||||
// 获取access token(从credentials中获取)
|
||||
accessToken := account.GetCredential("access_token")
|
||||
if accessToken == "" {
|
||||
return nil, fmt.Errorf("no access token available")
|
||||
}
|
||||
|
||||
// 获取代理配置
|
||||
transport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
proxyURL := account.Proxy.URL()
|
||||
if proxyURL != "" {
|
||||
if parsedURL, err := url.Parse(proxyURL); err == nil {
|
||||
transport.Proxy = http.ProxyURL(parsedURL)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
client := &http.Client{
|
||||
Transport: transport,
|
||||
Timeout: 30 * time.Second,
|
||||
}
|
||||
|
||||
// 构建请求
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", "https://api.anthropic.com/api/oauth/usage", nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create request failed: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Set("anthropic-beta", "oauth-2025-04-20")
|
||||
|
||||
// 发送请求
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("API returned status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
// 解析响应
|
||||
var usageResp ClaudeUsageResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&usageResp); err != nil {
|
||||
return nil, fmt.Errorf("decode response failed: %w", err)
|
||||
}
|
||||
|
||||
// 转换为UsageInfo
|
||||
now := time.Now()
|
||||
return s.buildUsageInfo(&usageResp, &now), nil
|
||||
}
|
||||
|
||||
// parseTime 尝试多种格式解析时间
|
||||
func parseTime(s string) (time.Time, error) {
|
||||
formats := []string{
|
||||
time.RFC3339,
|
||||
time.RFC3339Nano,
|
||||
"2006-01-02T15:04:05Z",
|
||||
"2006-01-02T15:04:05.000Z",
|
||||
}
|
||||
for _, format := range formats {
|
||||
if t, err := time.Parse(format, s); err == nil {
|
||||
return t, nil
|
||||
}
|
||||
}
|
||||
return time.Time{}, fmt.Errorf("unable to parse time: %s", s)
|
||||
}
|
||||
|
||||
// buildUsageInfo 构建UsageInfo
|
||||
func (s *AccountUsageService) buildUsageInfo(resp *ClaudeUsageResponse, updatedAt *time.Time) *UsageInfo {
|
||||
info := &UsageInfo{
|
||||
UpdatedAt: updatedAt,
|
||||
}
|
||||
|
||||
// 5小时窗口
|
||||
if resp.FiveHour.ResetsAt != "" {
|
||||
if fiveHourReset, err := parseTime(resp.FiveHour.ResetsAt); err == nil {
|
||||
info.FiveHour = &UsageProgress{
|
||||
Utilization: resp.FiveHour.Utilization,
|
||||
ResetsAt: &fiveHourReset,
|
||||
RemainingSeconds: int(time.Until(fiveHourReset).Seconds()),
|
||||
}
|
||||
} else {
|
||||
log.Printf("Failed to parse FiveHour.ResetsAt: %s, error: %v", resp.FiveHour.ResetsAt, err)
|
||||
// 即使解析失败也返回utilization
|
||||
info.FiveHour = &UsageProgress{
|
||||
Utilization: resp.FiveHour.Utilization,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 7天窗口
|
||||
if resp.SevenDay.ResetsAt != "" {
|
||||
if sevenDayReset, err := parseTime(resp.SevenDay.ResetsAt); err == nil {
|
||||
info.SevenDay = &UsageProgress{
|
||||
Utilization: resp.SevenDay.Utilization,
|
||||
ResetsAt: &sevenDayReset,
|
||||
RemainingSeconds: int(time.Until(sevenDayReset).Seconds()),
|
||||
}
|
||||
} else {
|
||||
log.Printf("Failed to parse SevenDay.ResetsAt: %s, error: %v", resp.SevenDay.ResetsAt, err)
|
||||
info.SevenDay = &UsageProgress{
|
||||
Utilization: resp.SevenDay.Utilization,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 7天Sonnet窗口
|
||||
if resp.SevenDaySonnet.ResetsAt != "" {
|
||||
if sonnetReset, err := parseTime(resp.SevenDaySonnet.ResetsAt); err == nil {
|
||||
info.SevenDaySonnet = &UsageProgress{
|
||||
Utilization: resp.SevenDaySonnet.Utilization,
|
||||
ResetsAt: &sonnetReset,
|
||||
RemainingSeconds: int(time.Until(sonnetReset).Seconds()),
|
||||
}
|
||||
} else {
|
||||
log.Printf("Failed to parse SevenDaySonnet.ResetsAt: %s, error: %v", resp.SevenDaySonnet.ResetsAt, err)
|
||||
info.SevenDaySonnet = &UsageProgress{
|
||||
Utilization: resp.SevenDaySonnet.Utilization,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return info
|
||||
}
|
||||
|
||||
// estimateSetupTokenUsage 根据session_window推算Setup Token账号的使用量
|
||||
func (s *AccountUsageService) estimateSetupTokenUsage(account *model.Account) *UsageInfo {
|
||||
info := &UsageInfo{}
|
||||
|
||||
// 如果有session_window信息
|
||||
if account.SessionWindowEnd != nil {
|
||||
remaining := int(time.Until(*account.SessionWindowEnd).Seconds())
|
||||
if remaining < 0 {
|
||||
remaining = 0
|
||||
}
|
||||
|
||||
// 根据状态估算使用率 (百分比形式,100 = 100%)
|
||||
var utilization float64
|
||||
switch account.SessionWindowStatus {
|
||||
case "rejected":
|
||||
utilization = 100.0
|
||||
case "allowed_warning":
|
||||
utilization = 80.0
|
||||
default:
|
||||
utilization = 0.0
|
||||
}
|
||||
|
||||
info.FiveHour = &UsageProgress{
|
||||
Utilization: utilization,
|
||||
ResetsAt: account.SessionWindowEnd,
|
||||
RemainingSeconds: remaining,
|
||||
}
|
||||
} else {
|
||||
// 没有窗口信息,返回空数据
|
||||
info.FiveHour = &UsageProgress{
|
||||
Utilization: 0,
|
||||
RemainingSeconds: 0,
|
||||
}
|
||||
}
|
||||
|
||||
// Setup Token无法获取7d数据
|
||||
return info
|
||||
}
|
||||
989
backend/internal/service/admin_service.go
Normal file
989
backend/internal/service/admin_service.go
Normal file
@@ -0,0 +1,989 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/repository"
|
||||
|
||||
"golang.org/x/net/proxy"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// AdminService interface defines admin management operations
|
||||
type AdminService interface {
|
||||
// User management
|
||||
ListUsers(ctx context.Context, page, pageSize int, status, role, search string) ([]model.User, int64, error)
|
||||
GetUser(ctx context.Context, id int64) (*model.User, error)
|
||||
CreateUser(ctx context.Context, input *CreateUserInput) (*model.User, error)
|
||||
UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*model.User, error)
|
||||
DeleteUser(ctx context.Context, id int64) error
|
||||
UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string) (*model.User, error)
|
||||
GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]model.ApiKey, int64, error)
|
||||
GetUserUsageStats(ctx context.Context, userID int64, period string) (interface{}, error)
|
||||
|
||||
// Group management
|
||||
ListGroups(ctx context.Context, page, pageSize int, platform, status string, isExclusive *bool) ([]model.Group, int64, error)
|
||||
GetAllGroups(ctx context.Context) ([]model.Group, error)
|
||||
GetAllGroupsByPlatform(ctx context.Context, platform string) ([]model.Group, error)
|
||||
GetGroup(ctx context.Context, id int64) (*model.Group, error)
|
||||
CreateGroup(ctx context.Context, input *CreateGroupInput) (*model.Group, error)
|
||||
UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*model.Group, error)
|
||||
DeleteGroup(ctx context.Context, id int64) error
|
||||
GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]model.ApiKey, int64, error)
|
||||
|
||||
// Account management
|
||||
ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]model.Account, int64, error)
|
||||
GetAccount(ctx context.Context, id int64) (*model.Account, error)
|
||||
CreateAccount(ctx context.Context, input *CreateAccountInput) (*model.Account, error)
|
||||
UpdateAccount(ctx context.Context, id int64, input *UpdateAccountInput) (*model.Account, error)
|
||||
DeleteAccount(ctx context.Context, id int64) error
|
||||
RefreshAccountCredentials(ctx context.Context, id int64) (*model.Account, error)
|
||||
ClearAccountError(ctx context.Context, id int64) (*model.Account, error)
|
||||
SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*model.Account, error)
|
||||
|
||||
// Proxy management
|
||||
ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]model.Proxy, int64, error)
|
||||
GetAllProxies(ctx context.Context) ([]model.Proxy, error)
|
||||
GetAllProxiesWithAccountCount(ctx context.Context) ([]model.ProxyWithAccountCount, error)
|
||||
GetProxy(ctx context.Context, id int64) (*model.Proxy, error)
|
||||
CreateProxy(ctx context.Context, input *CreateProxyInput) (*model.Proxy, error)
|
||||
UpdateProxy(ctx context.Context, id int64, input *UpdateProxyInput) (*model.Proxy, error)
|
||||
DeleteProxy(ctx context.Context, id int64) error
|
||||
GetProxyAccounts(ctx context.Context, proxyID int64, page, pageSize int) ([]model.Account, int64, error)
|
||||
CheckProxyExists(ctx context.Context, host string, port int, username, password string) (bool, error)
|
||||
TestProxy(ctx context.Context, id int64) (*ProxyTestResult, error)
|
||||
|
||||
// Redeem code management
|
||||
ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]model.RedeemCode, int64, error)
|
||||
GetRedeemCode(ctx context.Context, id int64) (*model.RedeemCode, error)
|
||||
GenerateRedeemCodes(ctx context.Context, input *GenerateRedeemCodesInput) ([]model.RedeemCode, error)
|
||||
DeleteRedeemCode(ctx context.Context, id int64) error
|
||||
BatchDeleteRedeemCodes(ctx context.Context, ids []int64) (int64, error)
|
||||
ExpireRedeemCode(ctx context.Context, id int64) (*model.RedeemCode, error)
|
||||
}
|
||||
|
||||
// Input types for admin operations
|
||||
type CreateUserInput struct {
|
||||
Email string
|
||||
Password string
|
||||
Balance float64
|
||||
Concurrency int
|
||||
AllowedGroups []int64
|
||||
}
|
||||
|
||||
type UpdateUserInput struct {
|
||||
Email string
|
||||
Password string
|
||||
Balance *float64 // 使用指针区分"未提供"和"设置为0"
|
||||
Concurrency *int // 使用指针区分"未提供"和"设置为0"
|
||||
Status string
|
||||
AllowedGroups *[]int64 // 使用指针区分"未提供"和"设置为空数组"
|
||||
}
|
||||
|
||||
type CreateGroupInput struct {
|
||||
Name string
|
||||
Description string
|
||||
Platform string
|
||||
RateMultiplier float64
|
||||
IsExclusive bool
|
||||
SubscriptionType string // standard/subscription
|
||||
DailyLimitUSD *float64 // 日限额 (USD)
|
||||
WeeklyLimitUSD *float64 // 周限额 (USD)
|
||||
MonthlyLimitUSD *float64 // 月限额 (USD)
|
||||
}
|
||||
|
||||
type UpdateGroupInput struct {
|
||||
Name string
|
||||
Description string
|
||||
Platform string
|
||||
RateMultiplier *float64 // 使用指针以支持设置为0
|
||||
IsExclusive *bool
|
||||
Status string
|
||||
SubscriptionType string // standard/subscription
|
||||
DailyLimitUSD *float64 // 日限额 (USD)
|
||||
WeeklyLimitUSD *float64 // 周限额 (USD)
|
||||
MonthlyLimitUSD *float64 // 月限额 (USD)
|
||||
}
|
||||
|
||||
type CreateAccountInput struct {
|
||||
Name string
|
||||
Platform string
|
||||
Type string
|
||||
Credentials map[string]interface{}
|
||||
Extra map[string]interface{}
|
||||
ProxyID *int64
|
||||
Concurrency int
|
||||
Priority int
|
||||
GroupIDs []int64
|
||||
}
|
||||
|
||||
type UpdateAccountInput struct {
|
||||
Name string
|
||||
Type string // Account type: oauth, setup-token, apikey
|
||||
Credentials map[string]interface{}
|
||||
Extra map[string]interface{}
|
||||
ProxyID *int64
|
||||
Concurrency *int // 使用指针区分"未提供"和"设置为0"
|
||||
Priority *int // 使用指针区分"未提供"和"设置为0"
|
||||
Status string
|
||||
GroupIDs *[]int64
|
||||
}
|
||||
|
||||
type CreateProxyInput struct {
|
||||
Name string
|
||||
Protocol string
|
||||
Host string
|
||||
Port int
|
||||
Username string
|
||||
Password string
|
||||
}
|
||||
|
||||
type UpdateProxyInput struct {
|
||||
Name string
|
||||
Protocol string
|
||||
Host string
|
||||
Port int
|
||||
Username string
|
||||
Password string
|
||||
Status string
|
||||
}
|
||||
|
||||
type GenerateRedeemCodesInput struct {
|
||||
Count int
|
||||
Type string
|
||||
Value float64
|
||||
GroupID *int64 // 订阅类型专用:关联的分组ID
|
||||
ValidityDays int // 订阅类型专用:有效天数
|
||||
}
|
||||
|
||||
// ProxyTestResult represents the result of testing a proxy
|
||||
type ProxyTestResult struct {
|
||||
Success bool `json:"success"`
|
||||
Message string `json:"message"`
|
||||
LatencyMs int64 `json:"latency_ms,omitempty"`
|
||||
IPAddress string `json:"ip_address,omitempty"`
|
||||
City string `json:"city,omitempty"`
|
||||
Region string `json:"region,omitempty"`
|
||||
Country string `json:"country,omitempty"`
|
||||
}
|
||||
|
||||
// adminServiceImpl implements AdminService
|
||||
type adminServiceImpl struct {
|
||||
userRepo *repository.UserRepository
|
||||
groupRepo *repository.GroupRepository
|
||||
accountRepo *repository.AccountRepository
|
||||
proxyRepo *repository.ProxyRepository
|
||||
apiKeyRepo *repository.ApiKeyRepository
|
||||
redeemCodeRepo *repository.RedeemCodeRepository
|
||||
usageLogRepo *repository.UsageLogRepository
|
||||
userSubRepo *repository.UserSubscriptionRepository
|
||||
billingCacheService *BillingCacheService
|
||||
}
|
||||
|
||||
// NewAdminService creates a new AdminService
|
||||
func NewAdminService(repos *repository.Repositories) AdminService {
|
||||
return &adminServiceImpl{
|
||||
userRepo: repos.User,
|
||||
groupRepo: repos.Group,
|
||||
accountRepo: repos.Account,
|
||||
proxyRepo: repos.Proxy,
|
||||
apiKeyRepo: repos.ApiKey,
|
||||
redeemCodeRepo: repos.RedeemCode,
|
||||
usageLogRepo: repos.UsageLog,
|
||||
userSubRepo: repos.UserSubscription,
|
||||
}
|
||||
}
|
||||
|
||||
// SetBillingCacheService 设置计费缓存服务(用于缓存失效)
|
||||
// 注意:AdminService是接口,需要类型断言
|
||||
func SetAdminServiceBillingCache(adminService AdminService, billingCacheService *BillingCacheService) {
|
||||
if impl, ok := adminService.(*adminServiceImpl); ok {
|
||||
impl.billingCacheService = billingCacheService
|
||||
}
|
||||
}
|
||||
|
||||
// User management implementations
|
||||
func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, status, role, search string) ([]model.User, int64, error) {
|
||||
params := repository.PaginationParams{Page: page, PageSize: pageSize}
|
||||
users, result, err := s.userRepo.ListWithFilters(ctx, params, status, role, search)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
return users, result.Total, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*model.User, error) {
|
||||
return s.userRepo.GetByID(ctx, id)
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInput) (*model.User, error) {
|
||||
user := &model.User{
|
||||
Email: input.Email,
|
||||
Role: "user", // Always create as regular user, never admin
|
||||
Balance: input.Balance,
|
||||
Concurrency: input.Concurrency,
|
||||
Status: model.StatusActive,
|
||||
}
|
||||
if err := user.SetPassword(input.Password); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := s.userRepo.Create(ctx, user); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*model.User, error) {
|
||||
user, err := s.userRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Protect admin users: cannot disable admin accounts
|
||||
if user.Role == "admin" && input.Status == "disabled" {
|
||||
return nil, errors.New("cannot disable admin user")
|
||||
}
|
||||
|
||||
// Track balance and concurrency changes for logging
|
||||
oldBalance := user.Balance
|
||||
oldConcurrency := user.Concurrency
|
||||
|
||||
if input.Email != "" {
|
||||
user.Email = input.Email
|
||||
}
|
||||
if input.Password != "" {
|
||||
if err := user.SetPassword(input.Password); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
// Role is not allowed to be changed via API to prevent privilege escalation
|
||||
if input.Status != "" {
|
||||
user.Status = input.Status
|
||||
}
|
||||
|
||||
// 只在指针非 nil 时更新 Balance(支持设置为 0)
|
||||
if input.Balance != nil {
|
||||
user.Balance = *input.Balance
|
||||
}
|
||||
|
||||
// 只在指针非 nil 时更新 Concurrency(支持设置为任意值)
|
||||
if input.Concurrency != nil {
|
||||
user.Concurrency = *input.Concurrency
|
||||
}
|
||||
|
||||
// 只在指针非 nil 时更新 AllowedGroups
|
||||
if input.AllowedGroups != nil {
|
||||
user.AllowedGroups = *input.AllowedGroups
|
||||
}
|
||||
|
||||
if err := s.userRepo.Update(ctx, user); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 余额变化时失效缓存
|
||||
if input.Balance != nil && *input.Balance != oldBalance {
|
||||
if s.billingCacheService != nil {
|
||||
go func() {
|
||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
s.billingCacheService.InvalidateUserBalance(cacheCtx, id)
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
// Create adjustment records for balance/concurrency changes
|
||||
balanceDiff := user.Balance - oldBalance
|
||||
if balanceDiff != 0 {
|
||||
adjustmentRecord := &model.RedeemCode{
|
||||
Code: model.GenerateRedeemCode(),
|
||||
Type: model.AdjustmentTypeAdminBalance,
|
||||
Value: balanceDiff,
|
||||
Status: model.StatusUsed,
|
||||
UsedBy: &user.ID,
|
||||
}
|
||||
now := time.Now()
|
||||
adjustmentRecord.UsedAt = &now
|
||||
if err := s.redeemCodeRepo.Create(ctx, adjustmentRecord); err != nil {
|
||||
// Log error but don't fail the update
|
||||
// The user update has already succeeded
|
||||
}
|
||||
}
|
||||
|
||||
concurrencyDiff := user.Concurrency - oldConcurrency
|
||||
if concurrencyDiff != 0 {
|
||||
adjustmentRecord := &model.RedeemCode{
|
||||
Code: model.GenerateRedeemCode(),
|
||||
Type: model.AdjustmentTypeAdminConcurrency,
|
||||
Value: float64(concurrencyDiff),
|
||||
Status: model.StatusUsed,
|
||||
UsedBy: &user.ID,
|
||||
}
|
||||
now := time.Now()
|
||||
adjustmentRecord.UsedAt = &now
|
||||
if err := s.redeemCodeRepo.Create(ctx, adjustmentRecord); err != nil {
|
||||
// Log error but don't fail the update
|
||||
// The user update has already succeeded
|
||||
}
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) DeleteUser(ctx context.Context, id int64) error {
|
||||
// Protect admin users: cannot delete admin accounts
|
||||
user, err := s.userRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if user.Role == "admin" {
|
||||
return errors.New("cannot delete admin user")
|
||||
}
|
||||
return s.userRepo.Delete(ctx, id)
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string) (*model.User, error) {
|
||||
user, err := s.userRepo.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch operation {
|
||||
case "set":
|
||||
user.Balance = balance
|
||||
case "add":
|
||||
user.Balance += balance
|
||||
case "subtract":
|
||||
user.Balance -= balance
|
||||
}
|
||||
|
||||
if err := s.userRepo.Update(ctx, user); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 失效余额缓存
|
||||
if s.billingCacheService != nil {
|
||||
go func() {
|
||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
s.billingCacheService.InvalidateUserBalance(cacheCtx, userID)
|
||||
}()
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]model.ApiKey, int64, error) {
|
||||
params := repository.PaginationParams{Page: page, PageSize: pageSize}
|
||||
keys, result, err := s.apiKeyRepo.ListByUserID(ctx, userID, params)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
return keys, result.Total, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) GetUserUsageStats(ctx context.Context, userID int64, period string) (interface{}, error) {
|
||||
// Return mock data for now
|
||||
return map[string]interface{}{
|
||||
"period": period,
|
||||
"total_requests": 0,
|
||||
"total_cost": 0.0,
|
||||
"total_tokens": 0,
|
||||
"avg_duration_ms": 0,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Group management implementations
|
||||
func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, platform, status string, isExclusive *bool) ([]model.Group, int64, error) {
|
||||
params := repository.PaginationParams{Page: page, PageSize: pageSize}
|
||||
groups, result, err := s.groupRepo.ListWithFilters(ctx, params, platform, status, isExclusive)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
return groups, result.Total, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) GetAllGroups(ctx context.Context) ([]model.Group, error) {
|
||||
return s.groupRepo.ListActive(ctx)
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) GetAllGroupsByPlatform(ctx context.Context, platform string) ([]model.Group, error) {
|
||||
return s.groupRepo.ListActiveByPlatform(ctx, platform)
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) GetGroup(ctx context.Context, id int64) (*model.Group, error) {
|
||||
return s.groupRepo.GetByID(ctx, id)
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupInput) (*model.Group, error) {
|
||||
platform := input.Platform
|
||||
if platform == "" {
|
||||
platform = model.PlatformAnthropic
|
||||
}
|
||||
|
||||
subscriptionType := input.SubscriptionType
|
||||
if subscriptionType == "" {
|
||||
subscriptionType = model.SubscriptionTypeStandard
|
||||
}
|
||||
|
||||
group := &model.Group{
|
||||
Name: input.Name,
|
||||
Description: input.Description,
|
||||
Platform: platform,
|
||||
RateMultiplier: input.RateMultiplier,
|
||||
IsExclusive: input.IsExclusive,
|
||||
Status: model.StatusActive,
|
||||
SubscriptionType: subscriptionType,
|
||||
DailyLimitUSD: input.DailyLimitUSD,
|
||||
WeeklyLimitUSD: input.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: input.MonthlyLimitUSD,
|
||||
}
|
||||
if err := s.groupRepo.Create(ctx, group); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return group, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*model.Group, error) {
|
||||
group, err := s.groupRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if input.Name != "" {
|
||||
group.Name = input.Name
|
||||
}
|
||||
if input.Description != "" {
|
||||
group.Description = input.Description
|
||||
}
|
||||
if input.Platform != "" {
|
||||
group.Platform = input.Platform
|
||||
}
|
||||
if input.RateMultiplier != nil {
|
||||
group.RateMultiplier = *input.RateMultiplier
|
||||
}
|
||||
if input.IsExclusive != nil {
|
||||
group.IsExclusive = *input.IsExclusive
|
||||
}
|
||||
if input.Status != "" {
|
||||
group.Status = input.Status
|
||||
}
|
||||
|
||||
// 订阅相关字段
|
||||
if input.SubscriptionType != "" {
|
||||
group.SubscriptionType = input.SubscriptionType
|
||||
}
|
||||
// 限额字段支持设置为nil(清除限额)或具体值
|
||||
if input.DailyLimitUSD != nil {
|
||||
group.DailyLimitUSD = input.DailyLimitUSD
|
||||
}
|
||||
if input.WeeklyLimitUSD != nil {
|
||||
group.WeeklyLimitUSD = input.WeeklyLimitUSD
|
||||
}
|
||||
if input.MonthlyLimitUSD != nil {
|
||||
group.MonthlyLimitUSD = input.MonthlyLimitUSD
|
||||
}
|
||||
|
||||
if err := s.groupRepo.Update(ctx, group); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return group, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error {
|
||||
// 先获取分组信息,检查是否存在
|
||||
group, err := s.groupRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("group not found: %w", err)
|
||||
}
|
||||
|
||||
// 订阅类型分组:先获取受影响的用户ID列表(用于事务后失效缓存)
|
||||
var affectedUserIDs []int64
|
||||
if group.IsSubscriptionType() && s.billingCacheService != nil {
|
||||
var subscriptions []model.UserSubscription
|
||||
if err := s.groupRepo.DB().WithContext(ctx).
|
||||
Where("group_id = ?", id).
|
||||
Select("user_id").
|
||||
Find(&subscriptions).Error; err == nil {
|
||||
for _, sub := range subscriptions {
|
||||
affectedUserIDs = append(affectedUserIDs, sub.UserID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 使用事务处理所有级联删除
|
||||
db := s.groupRepo.DB()
|
||||
err = db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
// 1. 如果是订阅类型分组,删除 user_subscriptions 中的相关记录
|
||||
if group.IsSubscriptionType() {
|
||||
if err := tx.Where("group_id = ?", id).Delete(&model.UserSubscription{}).Error; err != nil {
|
||||
return fmt.Errorf("delete user subscriptions: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 将 api_keys 中绑定该分组的 group_id 设为 nil(任何类型的分组都需要)
|
||||
if err := tx.Model(&model.ApiKey{}).Where("group_id = ?", id).Update("group_id", nil).Error; err != nil {
|
||||
return fmt.Errorf("clear api key group_id: %w", err)
|
||||
}
|
||||
|
||||
// 3. 从 users.allowed_groups 数组中移除该分组 ID
|
||||
if err := tx.Model(&model.User{}).
|
||||
Where("? = ANY(allowed_groups)", id).
|
||||
Update("allowed_groups", gorm.Expr("array_remove(allowed_groups, ?)", id)).Error; err != nil {
|
||||
return fmt.Errorf("remove from allowed_groups: %w", err)
|
||||
}
|
||||
|
||||
// 4. 删除 account_groups 中间表的数据
|
||||
if err := tx.Where("group_id = ?", id).Delete(&model.AccountGroup{}).Error; err != nil {
|
||||
return fmt.Errorf("delete account groups: %w", err)
|
||||
}
|
||||
|
||||
// 5. 删除分组本身
|
||||
if err := tx.Delete(&model.Group{}, id).Error; err != nil {
|
||||
return fmt.Errorf("delete group: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 事务成功后,异步失效受影响用户的订阅缓存
|
||||
if len(affectedUserIDs) > 0 && s.billingCacheService != nil {
|
||||
groupID := id
|
||||
go func() {
|
||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
for _, userID := range affectedUserIDs {
|
||||
s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]model.ApiKey, int64, error) {
|
||||
params := repository.PaginationParams{Page: page, PageSize: pageSize}
|
||||
keys, result, err := s.apiKeyRepo.ListByGroupID(ctx, groupID, params)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
return keys, result.Total, nil
|
||||
}
|
||||
|
||||
// Account management implementations
|
||||
func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]model.Account, int64, error) {
|
||||
params := repository.PaginationParams{Page: page, PageSize: pageSize}
|
||||
accounts, result, err := s.accountRepo.ListWithFilters(ctx, params, platform, accountType, status, search)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
return accounts, result.Total, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) GetAccount(ctx context.Context, id int64) (*model.Account, error) {
|
||||
return s.accountRepo.GetByID(ctx, id)
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccountInput) (*model.Account, error) {
|
||||
account := &model.Account{
|
||||
Name: input.Name,
|
||||
Platform: input.Platform,
|
||||
Type: input.Type,
|
||||
Credentials: model.JSONB(input.Credentials),
|
||||
Extra: model.JSONB(input.Extra),
|
||||
ProxyID: input.ProxyID,
|
||||
Concurrency: input.Concurrency,
|
||||
Priority: input.Priority,
|
||||
Status: model.StatusActive,
|
||||
}
|
||||
if err := s.accountRepo.Create(ctx, account); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 绑定分组
|
||||
if len(input.GroupIDs) > 0 {
|
||||
if err := s.accountRepo.BindGroups(ctx, account.ID, input.GroupIDs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return account, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *UpdateAccountInput) (*model.Account, error) {
|
||||
account, err := s.accountRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if input.Name != "" {
|
||||
account.Name = input.Name
|
||||
}
|
||||
if input.Type != "" {
|
||||
account.Type = input.Type
|
||||
}
|
||||
if input.Credentials != nil && len(input.Credentials) > 0 {
|
||||
account.Credentials = model.JSONB(input.Credentials)
|
||||
}
|
||||
if input.Extra != nil && len(input.Extra) > 0 {
|
||||
account.Extra = model.JSONB(input.Extra)
|
||||
}
|
||||
if input.ProxyID != nil {
|
||||
account.ProxyID = input.ProxyID
|
||||
}
|
||||
// 只在指针非 nil 时更新 Concurrency(支持设置为 0)
|
||||
if input.Concurrency != nil {
|
||||
account.Concurrency = *input.Concurrency
|
||||
}
|
||||
// 只在指针非 nil 时更新 Priority(支持设置为 0)
|
||||
if input.Priority != nil {
|
||||
account.Priority = *input.Priority
|
||||
}
|
||||
if input.Status != "" {
|
||||
account.Status = input.Status
|
||||
}
|
||||
|
||||
if err := s.accountRepo.Update(ctx, account); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 更新分组绑定
|
||||
if input.GroupIDs != nil {
|
||||
if err := s.accountRepo.BindGroups(ctx, account.ID, *input.GroupIDs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return account, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) DeleteAccount(ctx context.Context, id int64) error {
|
||||
return s.accountRepo.Delete(ctx, id)
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) RefreshAccountCredentials(ctx context.Context, id int64) (*model.Account, error) {
|
||||
account, err := s.accountRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// TODO: Implement refresh logic
|
||||
return account, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) ClearAccountError(ctx context.Context, id int64) (*model.Account, error) {
|
||||
account, err := s.accountRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
account.Status = model.StatusActive
|
||||
account.ErrorMessage = ""
|
||||
if err := s.accountRepo.Update(ctx, account); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return account, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*model.Account, error) {
|
||||
if err := s.accountRepo.SetSchedulable(ctx, id, schedulable); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s.accountRepo.GetByID(ctx, id)
|
||||
}
|
||||
|
||||
// Proxy management implementations
|
||||
func (s *adminServiceImpl) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]model.Proxy, int64, error) {
|
||||
params := repository.PaginationParams{Page: page, PageSize: pageSize}
|
||||
proxies, result, err := s.proxyRepo.ListWithFilters(ctx, params, protocol, status, search)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
return proxies, result.Total, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) GetAllProxies(ctx context.Context) ([]model.Proxy, error) {
|
||||
return s.proxyRepo.ListActive(ctx)
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) GetAllProxiesWithAccountCount(ctx context.Context) ([]model.ProxyWithAccountCount, error) {
|
||||
return s.proxyRepo.ListActiveWithAccountCount(ctx)
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) GetProxy(ctx context.Context, id int64) (*model.Proxy, error) {
|
||||
return s.proxyRepo.GetByID(ctx, id)
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) CreateProxy(ctx context.Context, input *CreateProxyInput) (*model.Proxy, error) {
|
||||
proxy := &model.Proxy{
|
||||
Name: input.Name,
|
||||
Protocol: input.Protocol,
|
||||
Host: input.Host,
|
||||
Port: input.Port,
|
||||
Username: input.Username,
|
||||
Password: input.Password,
|
||||
Status: model.StatusActive,
|
||||
}
|
||||
if err := s.proxyRepo.Create(ctx, proxy); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return proxy, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) UpdateProxy(ctx context.Context, id int64, input *UpdateProxyInput) (*model.Proxy, error) {
|
||||
proxy, err := s.proxyRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if input.Name != "" {
|
||||
proxy.Name = input.Name
|
||||
}
|
||||
if input.Protocol != "" {
|
||||
proxy.Protocol = input.Protocol
|
||||
}
|
||||
if input.Host != "" {
|
||||
proxy.Host = input.Host
|
||||
}
|
||||
if input.Port != 0 {
|
||||
proxy.Port = input.Port
|
||||
}
|
||||
if input.Username != "" {
|
||||
proxy.Username = input.Username
|
||||
}
|
||||
if input.Password != "" {
|
||||
proxy.Password = input.Password
|
||||
}
|
||||
if input.Status != "" {
|
||||
proxy.Status = input.Status
|
||||
}
|
||||
|
||||
if err := s.proxyRepo.Update(ctx, proxy); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return proxy, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) DeleteProxy(ctx context.Context, id int64) error {
|
||||
return s.proxyRepo.Delete(ctx, id)
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) GetProxyAccounts(ctx context.Context, proxyID int64, page, pageSize int) ([]model.Account, int64, error) {
|
||||
// Return mock data for now - would need a dedicated repository method
|
||||
return []model.Account{}, 0, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) CheckProxyExists(ctx context.Context, host string, port int, username, password string) (bool, error) {
|
||||
return s.proxyRepo.ExistsByHostPortAuth(ctx, host, port, username, password)
|
||||
}
|
||||
|
||||
// Redeem code management implementations
|
||||
func (s *adminServiceImpl) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]model.RedeemCode, int64, error) {
|
||||
params := repository.PaginationParams{Page: page, PageSize: pageSize}
|
||||
codes, result, err := s.redeemCodeRepo.ListWithFilters(ctx, params, codeType, status, search)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
return codes, result.Total, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) GetRedeemCode(ctx context.Context, id int64) (*model.RedeemCode, error) {
|
||||
return s.redeemCodeRepo.GetByID(ctx, id)
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) GenerateRedeemCodes(ctx context.Context, input *GenerateRedeemCodesInput) ([]model.RedeemCode, error) {
|
||||
// 如果是订阅类型,验证必须有 GroupID
|
||||
if input.Type == model.RedeemTypeSubscription {
|
||||
if input.GroupID == nil {
|
||||
return nil, errors.New("group_id is required for subscription type")
|
||||
}
|
||||
// 验证分组存在且为订阅类型
|
||||
group, err := s.groupRepo.GetByID(ctx, *input.GroupID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("group not found: %w", err)
|
||||
}
|
||||
if !group.IsSubscriptionType() {
|
||||
return nil, errors.New("group must be subscription type")
|
||||
}
|
||||
}
|
||||
|
||||
codes := make([]model.RedeemCode, 0, input.Count)
|
||||
for i := 0; i < input.Count; i++ {
|
||||
code := model.RedeemCode{
|
||||
Code: model.GenerateRedeemCode(),
|
||||
Type: input.Type,
|
||||
Value: input.Value,
|
||||
Status: model.StatusUnused,
|
||||
}
|
||||
// 订阅类型专用字段
|
||||
if input.Type == model.RedeemTypeSubscription {
|
||||
code.GroupID = input.GroupID
|
||||
code.ValidityDays = input.ValidityDays
|
||||
if code.ValidityDays <= 0 {
|
||||
code.ValidityDays = 30 // 默认30天
|
||||
}
|
||||
}
|
||||
if err := s.redeemCodeRepo.Create(ctx, &code); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
codes = append(codes, code)
|
||||
}
|
||||
return codes, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) DeleteRedeemCode(ctx context.Context, id int64) error {
|
||||
return s.redeemCodeRepo.Delete(ctx, id)
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) BatchDeleteRedeemCodes(ctx context.Context, ids []int64) (int64, error) {
|
||||
var deleted int64
|
||||
for _, id := range ids {
|
||||
if err := s.redeemCodeRepo.Delete(ctx, id); err == nil {
|
||||
deleted++
|
||||
}
|
||||
}
|
||||
return deleted, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) ExpireRedeemCode(ctx context.Context, id int64) (*model.RedeemCode, error) {
|
||||
code, err := s.redeemCodeRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
code.Status = model.StatusExpired
|
||||
if err := s.redeemCodeRepo.Update(ctx, code); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return code, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) TestProxy(ctx context.Context, id int64) (*ProxyTestResult, error) {
|
||||
proxy, err := s.proxyRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return testProxyConnection(ctx, proxy)
|
||||
}
|
||||
|
||||
// testProxyConnection tests proxy connectivity by requesting ipinfo.io/json
|
||||
func testProxyConnection(ctx context.Context, proxy *model.Proxy) (*ProxyTestResult, error) {
|
||||
proxyURL := proxy.URL()
|
||||
|
||||
// Create HTTP client with proxy
|
||||
transport, err := createProxyTransport(proxyURL)
|
||||
if err != nil {
|
||||
return &ProxyTestResult{
|
||||
Success: false,
|
||||
Message: fmt.Sprintf("Failed to create proxy transport: %v", err),
|
||||
}, nil
|
||||
}
|
||||
|
||||
client := &http.Client{
|
||||
Transport: transport,
|
||||
Timeout: 15 * time.Second,
|
||||
}
|
||||
|
||||
// Measure latency
|
||||
startTime := time.Now()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", "https://ipinfo.io/json", nil)
|
||||
if err != nil {
|
||||
return &ProxyTestResult{
|
||||
Success: false,
|
||||
Message: fmt.Sprintf("Failed to create request: %v", err),
|
||||
}, nil
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return &ProxyTestResult{
|
||||
Success: false,
|
||||
Message: fmt.Sprintf("Proxy connection failed: %v", err),
|
||||
}, nil
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
latencyMs := time.Since(startTime).Milliseconds()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return &ProxyTestResult{
|
||||
Success: false,
|
||||
Message: fmt.Sprintf("Request failed with status: %d", resp.StatusCode),
|
||||
LatencyMs: latencyMs,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Parse ipinfo.io response
|
||||
var ipInfo struct {
|
||||
IP string `json:"ip"`
|
||||
City string `json:"city"`
|
||||
Region string `json:"region"`
|
||||
Country string `json:"country"`
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return &ProxyTestResult{
|
||||
Success: true,
|
||||
Message: "Proxy is accessible but failed to read response",
|
||||
LatencyMs: latencyMs,
|
||||
}, nil
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(body, &ipInfo); err != nil {
|
||||
return &ProxyTestResult{
|
||||
Success: true,
|
||||
Message: "Proxy is accessible but failed to parse response",
|
||||
LatencyMs: latencyMs,
|
||||
}, nil
|
||||
}
|
||||
|
||||
return &ProxyTestResult{
|
||||
Success: true,
|
||||
Message: "Proxy is accessible",
|
||||
LatencyMs: latencyMs,
|
||||
IPAddress: ipInfo.IP,
|
||||
City: ipInfo.City,
|
||||
Region: ipInfo.Region,
|
||||
Country: ipInfo.Country,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// createProxyTransport creates an HTTP transport with the given proxy URL
|
||||
func createProxyTransport(proxyURL string) (*http.Transport, error) {
|
||||
parsedURL, err := url.Parse(proxyURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid proxy URL: %w", err)
|
||||
}
|
||||
|
||||
transport := &http.Transport{
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
}
|
||||
|
||||
switch parsedURL.Scheme {
|
||||
case "http", "https":
|
||||
transport.Proxy = http.ProxyURL(parsedURL)
|
||||
case "socks5":
|
||||
dialer, err := proxy.FromURL(parsedURL, proxy.Direct)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create socks5 dialer: %w", err)
|
||||
}
|
||||
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return dialer.Dial(network, addr)
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported proxy protocol: %s", parsedURL.Scheme)
|
||||
}
|
||||
|
||||
return transport, nil
|
||||
}
|
||||
464
backend/internal/service/api_key_service.go
Normal file
464
backend/internal/service/api_key_service.go
Normal file
@@ -0,0 +1,464 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sub2api/internal/config"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/timezone"
|
||||
"sub2api/internal/repository"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrApiKeyNotFound = errors.New("api key not found")
|
||||
ErrGroupNotAllowed = errors.New("user is not allowed to bind this group")
|
||||
ErrApiKeyExists = errors.New("api key already exists")
|
||||
ErrApiKeyTooShort = errors.New("api key must be at least 16 characters")
|
||||
ErrApiKeyInvalidChars = errors.New("api key can only contain letters, numbers, underscores, and hyphens")
|
||||
ErrApiKeyRateLimited = errors.New("too many failed attempts, please try again later")
|
||||
)
|
||||
|
||||
const (
|
||||
apiKeyRateLimitKeyPrefix = "apikey:create_rate_limit:"
|
||||
apiKeyMaxErrorsPerHour = 20
|
||||
apiKeyRateLimitDuration = time.Hour
|
||||
)
|
||||
|
||||
// CreateApiKeyRequest 创建API Key请求
|
||||
type CreateApiKeyRequest struct {
|
||||
Name string `json:"name"`
|
||||
GroupID *int64 `json:"group_id"`
|
||||
CustomKey *string `json:"custom_key"` // 可选的自定义key
|
||||
}
|
||||
|
||||
// UpdateApiKeyRequest 更新API Key请求
|
||||
type UpdateApiKeyRequest struct {
|
||||
Name *string `json:"name"`
|
||||
GroupID *int64 `json:"group_id"`
|
||||
Status *string `json:"status"`
|
||||
}
|
||||
|
||||
// ApiKeyService API Key服务
|
||||
type ApiKeyService struct {
|
||||
apiKeyRepo *repository.ApiKeyRepository
|
||||
userRepo *repository.UserRepository
|
||||
groupRepo *repository.GroupRepository
|
||||
userSubRepo *repository.UserSubscriptionRepository
|
||||
rdb *redis.Client
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
// NewApiKeyService 创建API Key服务实例
|
||||
func NewApiKeyService(
|
||||
apiKeyRepo *repository.ApiKeyRepository,
|
||||
userRepo *repository.UserRepository,
|
||||
groupRepo *repository.GroupRepository,
|
||||
userSubRepo *repository.UserSubscriptionRepository,
|
||||
rdb *redis.Client,
|
||||
cfg *config.Config,
|
||||
) *ApiKeyService {
|
||||
return &ApiKeyService{
|
||||
apiKeyRepo: apiKeyRepo,
|
||||
userRepo: userRepo,
|
||||
groupRepo: groupRepo,
|
||||
userSubRepo: userSubRepo,
|
||||
rdb: rdb,
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateKey 生成随机API Key
|
||||
func (s *ApiKeyService) GenerateKey() (string, error) {
|
||||
// 生成32字节随机数据
|
||||
bytes := make([]byte, 32)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
return "", fmt.Errorf("generate random bytes: %w", err)
|
||||
}
|
||||
|
||||
// 转换为十六进制字符串并添加前缀
|
||||
prefix := s.cfg.Default.ApiKeyPrefix
|
||||
if prefix == "" {
|
||||
prefix = "sk-"
|
||||
}
|
||||
|
||||
key := prefix + hex.EncodeToString(bytes)
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// ValidateCustomKey 验证自定义API Key格式
|
||||
func (s *ApiKeyService) ValidateCustomKey(key string) error {
|
||||
// 检查长度
|
||||
if len(key) < 16 {
|
||||
return ErrApiKeyTooShort
|
||||
}
|
||||
|
||||
// 检查字符:只允许字母、数字、下划线、连字符
|
||||
for _, c := range key {
|
||||
if !((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') ||
|
||||
(c >= '0' && c <= '9') || c == '_' || c == '-') {
|
||||
return ErrApiKeyInvalidChars
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkApiKeyRateLimit 检查用户创建自定义Key的错误次数是否超限
|
||||
func (s *ApiKeyService) checkApiKeyRateLimit(ctx context.Context, userID int64) error {
|
||||
if s.rdb == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
|
||||
|
||||
count, err := s.rdb.Get(ctx, key).Int()
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
// Redis 出错时不阻止用户操作
|
||||
return nil
|
||||
}
|
||||
|
||||
if count >= apiKeyMaxErrorsPerHour {
|
||||
return ErrApiKeyRateLimited
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// incrementApiKeyErrorCount 增加用户创建自定义Key的错误计数
|
||||
func (s *ApiKeyService) incrementApiKeyErrorCount(ctx context.Context, userID int64) {
|
||||
if s.rdb == nil {
|
||||
return
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
|
||||
|
||||
pipe := s.rdb.Pipeline()
|
||||
pipe.Incr(ctx, key)
|
||||
pipe.Expire(ctx, key, apiKeyRateLimitDuration)
|
||||
_, _ = pipe.Exec(ctx)
|
||||
}
|
||||
|
||||
// canUserBindGroup 检查用户是否可以绑定指定分组
|
||||
// 对于订阅类型分组:检查用户是否有有效订阅
|
||||
// 对于标准类型分组:使用原有的 AllowedGroups 和 IsExclusive 逻辑
|
||||
func (s *ApiKeyService) canUserBindGroup(ctx context.Context, user *model.User, group *model.Group) bool {
|
||||
// 订阅类型分组:需要有效订阅
|
||||
if group.IsSubscriptionType() {
|
||||
_, err := s.userSubRepo.GetActiveByUserIDAndGroupID(ctx, user.ID, group.ID)
|
||||
return err == nil // 有有效订阅则允许
|
||||
}
|
||||
// 标准类型分组:使用原有逻辑
|
||||
return user.CanBindGroup(group.ID, group.IsExclusive)
|
||||
}
|
||||
|
||||
// Create 创建API Key
|
||||
func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiKeyRequest) (*model.ApiKey, error) {
|
||||
// 验证用户存在
|
||||
user, err := s.userRepo.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrUserNotFound
|
||||
}
|
||||
return nil, fmt.Errorf("get user: %w", err)
|
||||
}
|
||||
|
||||
// 验证分组权限(如果指定了分组)
|
||||
if req.GroupID != nil {
|
||||
group, err := s.groupRepo.GetByID(ctx, *req.GroupID)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, errors.New("group not found")
|
||||
}
|
||||
return nil, fmt.Errorf("get group: %w", err)
|
||||
}
|
||||
|
||||
// 检查用户是否可以绑定该分组
|
||||
if !s.canUserBindGroup(ctx, user, group) {
|
||||
return nil, ErrGroupNotAllowed
|
||||
}
|
||||
}
|
||||
|
||||
var key string
|
||||
|
||||
// 判断是否使用自定义Key
|
||||
if req.CustomKey != nil && *req.CustomKey != "" {
|
||||
// 检查限流(仅对自定义key进行限流)
|
||||
if err := s.checkApiKeyRateLimit(ctx, userID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 验证自定义Key格式
|
||||
if err := s.ValidateCustomKey(*req.CustomKey); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 检查Key是否已存在
|
||||
exists, err := s.apiKeyRepo.ExistsByKey(ctx, *req.CustomKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("check key exists: %w", err)
|
||||
}
|
||||
if exists {
|
||||
// Key已存在,增加错误计数
|
||||
s.incrementApiKeyErrorCount(ctx, userID)
|
||||
return nil, ErrApiKeyExists
|
||||
}
|
||||
|
||||
key = *req.CustomKey
|
||||
} else {
|
||||
// 生成随机API Key
|
||||
var err error
|
||||
key, err = s.GenerateKey()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate key: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 创建API Key记录
|
||||
apiKey := &model.ApiKey{
|
||||
UserID: userID,
|
||||
Key: key,
|
||||
Name: req.Name,
|
||||
GroupID: req.GroupID,
|
||||
Status: model.StatusActive,
|
||||
}
|
||||
|
||||
if err := s.apiKeyRepo.Create(ctx, apiKey); err != nil {
|
||||
return nil, fmt.Errorf("create api key: %w", err)
|
||||
}
|
||||
|
||||
return apiKey, nil
|
||||
}
|
||||
|
||||
// List 获取用户的API Key列表
|
||||
func (s *ApiKeyService) List(ctx context.Context, userID int64, params repository.PaginationParams) ([]model.ApiKey, *repository.PaginationResult, error) {
|
||||
keys, pagination, err := s.apiKeyRepo.ListByUserID(ctx, userID, params)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("list api keys: %w", err)
|
||||
}
|
||||
return keys, pagination, nil
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取API Key
|
||||
func (s *ApiKeyService) GetByID(ctx context.Context, id int64) (*model.ApiKey, error) {
|
||||
apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrApiKeyNotFound
|
||||
}
|
||||
return nil, fmt.Errorf("get api key: %w", err)
|
||||
}
|
||||
return apiKey, nil
|
||||
}
|
||||
|
||||
// GetByKey 根据Key字符串获取API Key(用于认证)
|
||||
func (s *ApiKeyService) GetByKey(ctx context.Context, key string) (*model.ApiKey, error) {
|
||||
// 尝试从Redis缓存获取
|
||||
cacheKey := fmt.Sprintf("apikey:%s", key)
|
||||
|
||||
// 这里可以添加Redis缓存逻辑,暂时直接查询数据库
|
||||
apiKey, err := s.apiKeyRepo.GetByKey(ctx, key)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrApiKeyNotFound
|
||||
}
|
||||
return nil, fmt.Errorf("get api key: %w", err)
|
||||
}
|
||||
|
||||
// 缓存到Redis(可选,TTL设置为5分钟)
|
||||
if s.rdb != nil {
|
||||
// 这里可以序列化并缓存API Key
|
||||
_ = cacheKey // 使用变量避免未使用错误
|
||||
}
|
||||
|
||||
return apiKey, nil
|
||||
}
|
||||
|
||||
// Update 更新API Key
|
||||
func (s *ApiKeyService) Update(ctx context.Context, id int64, userID int64, req UpdateApiKeyRequest) (*model.ApiKey, error) {
|
||||
apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrApiKeyNotFound
|
||||
}
|
||||
return nil, fmt.Errorf("get api key: %w", err)
|
||||
}
|
||||
|
||||
// 验证所有权
|
||||
if apiKey.UserID != userID {
|
||||
return nil, ErrInsufficientPerms
|
||||
}
|
||||
|
||||
// 更新字段
|
||||
if req.Name != nil {
|
||||
apiKey.Name = *req.Name
|
||||
}
|
||||
|
||||
if req.GroupID != nil {
|
||||
// 验证分组权限
|
||||
user, err := s.userRepo.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get user: %w", err)
|
||||
}
|
||||
|
||||
group, err := s.groupRepo.GetByID(ctx, *req.GroupID)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, errors.New("group not found")
|
||||
}
|
||||
return nil, fmt.Errorf("get group: %w", err)
|
||||
}
|
||||
|
||||
if !s.canUserBindGroup(ctx, user, group) {
|
||||
return nil, ErrGroupNotAllowed
|
||||
}
|
||||
|
||||
apiKey.GroupID = req.GroupID
|
||||
}
|
||||
|
||||
if req.Status != nil {
|
||||
apiKey.Status = *req.Status
|
||||
// 如果状态改变,清除Redis缓存
|
||||
if s.rdb != nil {
|
||||
cacheKey := fmt.Sprintf("apikey:%s", apiKey.Key)
|
||||
_ = s.rdb.Del(ctx, cacheKey)
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.apiKeyRepo.Update(ctx, apiKey); err != nil {
|
||||
return nil, fmt.Errorf("update api key: %w", err)
|
||||
}
|
||||
|
||||
return apiKey, nil
|
||||
}
|
||||
|
||||
// Delete 删除API Key
|
||||
func (s *ApiKeyService) Delete(ctx context.Context, id int64, userID int64) error {
|
||||
apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return ErrApiKeyNotFound
|
||||
}
|
||||
return fmt.Errorf("get api key: %w", err)
|
||||
}
|
||||
|
||||
// 验证所有权
|
||||
if apiKey.UserID != userID {
|
||||
return ErrInsufficientPerms
|
||||
}
|
||||
|
||||
// 清除Redis缓存
|
||||
if s.rdb != nil {
|
||||
cacheKey := fmt.Sprintf("apikey:%s", apiKey.Key)
|
||||
_ = s.rdb.Del(ctx, cacheKey)
|
||||
}
|
||||
|
||||
if err := s.apiKeyRepo.Delete(ctx, id); err != nil {
|
||||
return fmt.Errorf("delete api key: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateKey 验证API Key是否有效(用于认证中间件)
|
||||
func (s *ApiKeyService) ValidateKey(ctx context.Context, key string) (*model.ApiKey, *model.User, error) {
|
||||
// 获取API Key
|
||||
apiKey, err := s.GetByKey(ctx, key)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// 检查API Key状态
|
||||
if !apiKey.IsActive() {
|
||||
return nil, nil, errors.New("api key is not active")
|
||||
}
|
||||
|
||||
// 获取用户信息
|
||||
user, err := s.userRepo.GetByID(ctx, apiKey.UserID)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil, ErrUserNotFound
|
||||
}
|
||||
return nil, nil, fmt.Errorf("get user: %w", err)
|
||||
}
|
||||
|
||||
// 检查用户状态
|
||||
if !user.IsActive() {
|
||||
return nil, nil, ErrUserNotActive
|
||||
}
|
||||
|
||||
return apiKey, user, nil
|
||||
}
|
||||
|
||||
// IncrementUsage 增加API Key使用次数(可选:用于统计)
|
||||
func (s *ApiKeyService) IncrementUsage(ctx context.Context, keyID int64) error {
|
||||
// 使用Redis计数器
|
||||
if s.rdb != nil {
|
||||
cacheKey := fmt.Sprintf("apikey:usage:%d:%s", keyID, timezone.Now().Format("2006-01-02"))
|
||||
if err := s.rdb.Incr(ctx, cacheKey).Err(); err != nil {
|
||||
return fmt.Errorf("increment usage: %w", err)
|
||||
}
|
||||
// 设置24小时过期
|
||||
_ = s.rdb.Expire(ctx, cacheKey, 24*time.Hour)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAvailableGroups 获取用户有权限绑定的分组列表
|
||||
// 返回用户可以选择的分组:
|
||||
// - 标准类型分组:公开的(非专属)或用户被明确允许的
|
||||
// - 订阅类型分组:用户有有效订阅的
|
||||
func (s *ApiKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([]model.Group, error) {
|
||||
// 获取用户信息
|
||||
user, err := s.userRepo.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrUserNotFound
|
||||
}
|
||||
return nil, fmt.Errorf("get user: %w", err)
|
||||
}
|
||||
|
||||
// 获取所有活跃分组
|
||||
allGroups, err := s.groupRepo.ListActive(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list active groups: %w", err)
|
||||
}
|
||||
|
||||
// 获取用户的所有有效订阅
|
||||
activeSubscriptions, err := s.userSubRepo.ListActiveByUserID(ctx, userID)
|
||||
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, fmt.Errorf("list active subscriptions: %w", err)
|
||||
}
|
||||
|
||||
// 构建订阅分组 ID 集合
|
||||
subscribedGroupIDs := make(map[int64]bool)
|
||||
for _, sub := range activeSubscriptions {
|
||||
subscribedGroupIDs[sub.GroupID] = true
|
||||
}
|
||||
|
||||
// 过滤出用户有权限的分组
|
||||
availableGroups := make([]model.Group, 0)
|
||||
for _, group := range allGroups {
|
||||
if s.canUserBindGroupInternal(user, &group, subscribedGroupIDs) {
|
||||
availableGroups = append(availableGroups, group)
|
||||
}
|
||||
}
|
||||
|
||||
return availableGroups, nil
|
||||
}
|
||||
|
||||
// canUserBindGroupInternal 内部方法,检查用户是否可以绑定分组(使用预加载的订阅数据)
|
||||
func (s *ApiKeyService) canUserBindGroupInternal(user *model.User, group *model.Group, subscribedGroupIDs map[int64]bool) bool {
|
||||
// 订阅类型分组:需要有效订阅
|
||||
if group.IsSubscriptionType() {
|
||||
return subscribedGroupIDs[group.ID]
|
||||
}
|
||||
// 标准类型分组:使用原有逻辑
|
||||
return user.CanBindGroup(group.ID, group.IsExclusive)
|
||||
}
|
||||
376
backend/internal/service/auth_service.go
Normal file
376
backend/internal/service/auth_service.go
Normal file
@@ -0,0 +1,376 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"sub2api/internal/config"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/repository"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidCredentials = errors.New("invalid email or password")
|
||||
ErrUserNotActive = errors.New("user is not active")
|
||||
ErrEmailExists = errors.New("email already exists")
|
||||
ErrInvalidToken = errors.New("invalid token")
|
||||
ErrTokenExpired = errors.New("token has expired")
|
||||
ErrEmailVerifyRequired = errors.New("email verification is required")
|
||||
ErrRegDisabled = errors.New("registration is currently disabled")
|
||||
)
|
||||
|
||||
// JWTClaims JWT载荷数据
|
||||
type JWTClaims struct {
|
||||
UserID int64 `json:"user_id"`
|
||||
Email string `json:"email"`
|
||||
Role string `json:"role"`
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
// AuthService 认证服务
|
||||
type AuthService struct {
|
||||
userRepo *repository.UserRepository
|
||||
cfg *config.Config
|
||||
settingService *SettingService
|
||||
emailService *EmailService
|
||||
turnstileService *TurnstileService
|
||||
emailQueueService *EmailQueueService
|
||||
}
|
||||
|
||||
// NewAuthService 创建认证服务实例
|
||||
func NewAuthService(userRepo *repository.UserRepository, cfg *config.Config) *AuthService {
|
||||
return &AuthService{
|
||||
userRepo: userRepo,
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
// SetSettingService 设置系统设置服务(用于检查注册开关和邮件验证)
|
||||
func (s *AuthService) SetSettingService(settingService *SettingService) {
|
||||
s.settingService = settingService
|
||||
}
|
||||
|
||||
// SetEmailService 设置邮件服务(用于邮件验证)
|
||||
func (s *AuthService) SetEmailService(emailService *EmailService) {
|
||||
s.emailService = emailService
|
||||
}
|
||||
|
||||
// SetTurnstileService 设置Turnstile服务(用于验证码校验)
|
||||
func (s *AuthService) SetTurnstileService(turnstileService *TurnstileService) {
|
||||
s.turnstileService = turnstileService
|
||||
}
|
||||
|
||||
// SetEmailQueueService 设置邮件队列服务(用于异步发送邮件)
|
||||
func (s *AuthService) SetEmailQueueService(emailQueueService *EmailQueueService) {
|
||||
s.emailQueueService = emailQueueService
|
||||
}
|
||||
|
||||
// Register 用户注册,返回token和用户
|
||||
func (s *AuthService) Register(ctx context.Context, email, password string) (string, *model.User, error) {
|
||||
return s.RegisterWithVerification(ctx, email, password, "")
|
||||
}
|
||||
|
||||
// RegisterWithVerification 用户注册(支持邮件验证),返回token和用户
|
||||
func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode string) (string, *model.User, error) {
|
||||
// 检查是否开放注册
|
||||
if s.settingService != nil && !s.settingService.IsRegistrationEnabled(ctx) {
|
||||
return "", nil, ErrRegDisabled
|
||||
}
|
||||
|
||||
// 检查是否需要邮件验证
|
||||
if s.settingService != nil && s.settingService.IsEmailVerifyEnabled(ctx) {
|
||||
if verifyCode == "" {
|
||||
return "", nil, ErrEmailVerifyRequired
|
||||
}
|
||||
// 验证邮箱验证码
|
||||
if s.emailService != nil {
|
||||
if err := s.emailService.VerifyCode(ctx, email, verifyCode); err != nil {
|
||||
return "", nil, fmt.Errorf("verify code: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 检查邮箱是否已存在
|
||||
existsEmail, err := s.userRepo.ExistsByEmail(ctx, email)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("check email exists: %w", err)
|
||||
}
|
||||
if existsEmail {
|
||||
return "", nil, ErrEmailExists
|
||||
}
|
||||
|
||||
// 密码哈希
|
||||
hashedPassword, err := s.HashPassword(password)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("hash password: %w", err)
|
||||
}
|
||||
|
||||
// 获取默认配置
|
||||
defaultBalance := s.cfg.Default.UserBalance
|
||||
defaultConcurrency := s.cfg.Default.UserConcurrency
|
||||
if s.settingService != nil {
|
||||
defaultBalance = s.settingService.GetDefaultBalance(ctx)
|
||||
defaultConcurrency = s.settingService.GetDefaultConcurrency(ctx)
|
||||
}
|
||||
|
||||
// 创建用户
|
||||
user := &model.User{
|
||||
Email: email,
|
||||
PasswordHash: hashedPassword,
|
||||
Role: model.RoleUser,
|
||||
Balance: defaultBalance,
|
||||
Concurrency: defaultConcurrency,
|
||||
Status: model.StatusActive,
|
||||
}
|
||||
|
||||
if err := s.userRepo.Create(ctx, user); err != nil {
|
||||
return "", nil, fmt.Errorf("create user: %w", err)
|
||||
}
|
||||
|
||||
// 生成token
|
||||
token, err := s.GenerateToken(user)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("generate token: %w", err)
|
||||
}
|
||||
|
||||
return token, user, nil
|
||||
}
|
||||
|
||||
// SendVerifyCodeResult 发送验证码返回结果
|
||||
type SendVerifyCodeResult struct {
|
||||
Countdown int `json:"countdown"` // 倒计时秒数
|
||||
}
|
||||
|
||||
// SendVerifyCode 发送邮箱验证码(同步方式)
|
||||
func (s *AuthService) SendVerifyCode(ctx context.Context, email string) error {
|
||||
// 检查是否开放注册
|
||||
if s.settingService != nil && !s.settingService.IsRegistrationEnabled(ctx) {
|
||||
return ErrRegDisabled
|
||||
}
|
||||
|
||||
// 检查邮箱是否已存在
|
||||
existsEmail, err := s.userRepo.ExistsByEmail(ctx, email)
|
||||
if err != nil {
|
||||
return fmt.Errorf("check email exists: %w", err)
|
||||
}
|
||||
if existsEmail {
|
||||
return ErrEmailExists
|
||||
}
|
||||
|
||||
// 发送验证码
|
||||
if s.emailService == nil {
|
||||
return errors.New("email service not configured")
|
||||
}
|
||||
|
||||
// 获取网站名称
|
||||
siteName := "Sub2API"
|
||||
if s.settingService != nil {
|
||||
siteName = s.settingService.GetSiteName(ctx)
|
||||
}
|
||||
|
||||
return s.emailService.SendVerifyCode(ctx, email, siteName)
|
||||
}
|
||||
|
||||
// SendVerifyCodeAsync 异步发送邮箱验证码并返回倒计时
|
||||
func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*SendVerifyCodeResult, error) {
|
||||
log.Printf("[Auth] SendVerifyCodeAsync called for email: %s", email)
|
||||
|
||||
// 检查是否开放注册
|
||||
if s.settingService != nil && !s.settingService.IsRegistrationEnabled(ctx) {
|
||||
log.Println("[Auth] Registration is disabled")
|
||||
return nil, ErrRegDisabled
|
||||
}
|
||||
|
||||
// 检查邮箱是否已存在
|
||||
existsEmail, err := s.userRepo.ExistsByEmail(ctx, email)
|
||||
if err != nil {
|
||||
log.Printf("[Auth] Error checking email exists: %v", err)
|
||||
return nil, fmt.Errorf("check email exists: %w", err)
|
||||
}
|
||||
if existsEmail {
|
||||
log.Printf("[Auth] Email already exists: %s", email)
|
||||
return nil, ErrEmailExists
|
||||
}
|
||||
|
||||
// 检查邮件队列服务是否配置
|
||||
if s.emailQueueService == nil {
|
||||
log.Println("[Auth] Email queue service not configured")
|
||||
return nil, errors.New("email queue service not configured")
|
||||
}
|
||||
|
||||
// 获取网站名称
|
||||
siteName := "Sub2API"
|
||||
if s.settingService != nil {
|
||||
siteName = s.settingService.GetSiteName(ctx)
|
||||
}
|
||||
|
||||
// 异步发送
|
||||
log.Printf("[Auth] Enqueueing verify code for: %s", email)
|
||||
if err := s.emailQueueService.EnqueueVerifyCode(email, siteName); err != nil {
|
||||
log.Printf("[Auth] Failed to enqueue: %v", err)
|
||||
return nil, fmt.Errorf("enqueue verify code: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("[Auth] Verify code enqueued successfully for: %s", email)
|
||||
return &SendVerifyCodeResult{
|
||||
Countdown: 60, // 60秒倒计时
|
||||
}, nil
|
||||
}
|
||||
|
||||
// VerifyTurnstile 验证Turnstile token
|
||||
func (s *AuthService) VerifyTurnstile(ctx context.Context, token string, remoteIP string) error {
|
||||
if s.turnstileService == nil {
|
||||
return nil // 服务未配置则跳过验证
|
||||
}
|
||||
return s.turnstileService.VerifyToken(ctx, token, remoteIP)
|
||||
}
|
||||
|
||||
// IsTurnstileEnabled 检查是否启用Turnstile验证
|
||||
func (s *AuthService) IsTurnstileEnabled(ctx context.Context) bool {
|
||||
if s.turnstileService == nil {
|
||||
return false
|
||||
}
|
||||
return s.turnstileService.IsEnabled(ctx)
|
||||
}
|
||||
|
||||
// IsRegistrationEnabled 检查是否开放注册
|
||||
func (s *AuthService) IsRegistrationEnabled(ctx context.Context) bool {
|
||||
if s.settingService == nil {
|
||||
return true
|
||||
}
|
||||
return s.settingService.IsRegistrationEnabled(ctx)
|
||||
}
|
||||
|
||||
// IsEmailVerifyEnabled 检查是否开启邮件验证
|
||||
func (s *AuthService) IsEmailVerifyEnabled(ctx context.Context) bool {
|
||||
if s.settingService == nil {
|
||||
return false
|
||||
}
|
||||
return s.settingService.IsEmailVerifyEnabled(ctx)
|
||||
}
|
||||
|
||||
// Login 用户登录,返回JWT token
|
||||
func (s *AuthService) Login(ctx context.Context, email, password string) (string, *model.User, error) {
|
||||
// 查找用户
|
||||
user, err := s.userRepo.GetByEmail(ctx, email)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return "", nil, ErrInvalidCredentials
|
||||
}
|
||||
return "", nil, fmt.Errorf("get user by email: %w", err)
|
||||
}
|
||||
|
||||
// 验证密码
|
||||
if !s.CheckPassword(password, user.PasswordHash) {
|
||||
return "", nil, ErrInvalidCredentials
|
||||
}
|
||||
|
||||
// 检查用户状态
|
||||
if !user.IsActive() {
|
||||
return "", nil, ErrUserNotActive
|
||||
}
|
||||
|
||||
// 生成JWT token
|
||||
token, err := s.GenerateToken(user)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("generate token: %w", err)
|
||||
}
|
||||
|
||||
return token, user, nil
|
||||
}
|
||||
|
||||
// ValidateToken 验证JWT token并返回用户声明
|
||||
func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) {
|
||||
token, err := jwt.ParseWithClaims(tokenString, &JWTClaims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
// 验证签名方法
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
}
|
||||
return []byte(s.cfg.JWT.Secret), nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, jwt.ErrTokenExpired) {
|
||||
return nil, ErrTokenExpired
|
||||
}
|
||||
return nil, ErrInvalidToken
|
||||
}
|
||||
|
||||
if claims, ok := token.Claims.(*JWTClaims); ok && token.Valid {
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
return nil, ErrInvalidToken
|
||||
}
|
||||
|
||||
// GenerateToken 生成JWT token
|
||||
func (s *AuthService) GenerateToken(user *model.User) (string, error) {
|
||||
now := time.Now()
|
||||
expiresAt := now.Add(time.Duration(s.cfg.JWT.ExpireHour) * time.Hour)
|
||||
|
||||
claims := &JWTClaims{
|
||||
UserID: user.ID,
|
||||
Email: user.Email,
|
||||
Role: user.Role,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(expiresAt),
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
NotBefore: jwt.NewNumericDate(now),
|
||||
},
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
tokenString, err := token.SignedString([]byte(s.cfg.JWT.Secret))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("sign token: %w", err)
|
||||
}
|
||||
|
||||
return tokenString, nil
|
||||
}
|
||||
|
||||
// HashPassword 使用bcrypt加密密码
|
||||
func (s *AuthService) HashPassword(password string) (string, error) {
|
||||
hashedBytes, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(hashedBytes), nil
|
||||
}
|
||||
|
||||
// CheckPassword 验证密码是否匹配
|
||||
func (s *AuthService) CheckPassword(password, hashedPassword string) bool {
|
||||
err := bcrypt.CompareHashAndPassword([]byte(hashedPassword), []byte(password))
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// RefreshToken 刷新token
|
||||
func (s *AuthService) RefreshToken(ctx context.Context, oldTokenString string) (string, error) {
|
||||
// 验证旧token(即使过期也允许,用于刷新)
|
||||
claims, err := s.ValidateToken(oldTokenString)
|
||||
if err != nil && !errors.Is(err, ErrTokenExpired) {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// 获取最新的用户信息
|
||||
user, err := s.userRepo.GetByID(ctx, claims.UserID)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return "", ErrInvalidToken
|
||||
}
|
||||
return "", fmt.Errorf("get user: %w", err)
|
||||
}
|
||||
|
||||
// 检查用户状态
|
||||
if !user.IsActive() {
|
||||
return "", ErrUserNotActive
|
||||
}
|
||||
|
||||
// 生成新token
|
||||
return s.GenerateToken(user)
|
||||
}
|
||||
422
backend/internal/service/billing_cache_service.go
Normal file
422
backend/internal/service/billing_cache_service.go
Normal file
@@ -0,0 +1,422 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/repository"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// 缓存Key前缀和TTL
|
||||
const (
|
||||
billingBalanceKeyPrefix = "billing:balance:"
|
||||
billingSubKeyPrefix = "billing:sub:"
|
||||
billingCacheTTL = 5 * time.Minute
|
||||
)
|
||||
|
||||
// 订阅缓存Hash字段
|
||||
const (
|
||||
subFieldStatus = "status"
|
||||
subFieldExpiresAt = "expires_at"
|
||||
subFieldDailyUsage = "daily_usage"
|
||||
subFieldWeeklyUsage = "weekly_usage"
|
||||
subFieldMonthlyUsage = "monthly_usage"
|
||||
subFieldVersion = "version"
|
||||
)
|
||||
|
||||
// 错误定义
|
||||
// 注:ErrInsufficientBalance在redeem_service.go中定义
|
||||
// 注:ErrDailyLimitExceeded/ErrWeeklyLimitExceeded/ErrMonthlyLimitExceeded在subscription_service.go中定义
|
||||
var (
|
||||
ErrSubscriptionInvalid = errors.New("subscription is invalid or expired")
|
||||
)
|
||||
|
||||
// 预编译的Lua脚本
|
||||
var (
|
||||
// deductBalanceScript: 扣减余额缓存,key不存在则忽略
|
||||
deductBalanceScript = redis.NewScript(`
|
||||
local current = redis.call('GET', KEYS[1])
|
||||
if current == false then
|
||||
return 0
|
||||
end
|
||||
local newVal = tonumber(current) - tonumber(ARGV[1])
|
||||
redis.call('SET', KEYS[1], newVal)
|
||||
redis.call('EXPIRE', KEYS[1], ARGV[2])
|
||||
return 1
|
||||
`)
|
||||
|
||||
// updateSubUsageScript: 更新订阅用量缓存,key不存在则忽略
|
||||
updateSubUsageScript = redis.NewScript(`
|
||||
local exists = redis.call('EXISTS', KEYS[1])
|
||||
if exists == 0 then
|
||||
return 0
|
||||
end
|
||||
local cost = tonumber(ARGV[1])
|
||||
redis.call('HINCRBYFLOAT', KEYS[1], 'daily_usage', cost)
|
||||
redis.call('HINCRBYFLOAT', KEYS[1], 'weekly_usage', cost)
|
||||
redis.call('HINCRBYFLOAT', KEYS[1], 'monthly_usage', cost)
|
||||
redis.call('EXPIRE', KEYS[1], ARGV[2])
|
||||
return 1
|
||||
`)
|
||||
)
|
||||
|
||||
// subscriptionCacheData 订阅缓存数据结构(内部使用)
|
||||
type subscriptionCacheData struct {
|
||||
Status string
|
||||
ExpiresAt time.Time
|
||||
DailyUsage float64
|
||||
WeeklyUsage float64
|
||||
MonthlyUsage float64
|
||||
Version int64
|
||||
}
|
||||
|
||||
// BillingCacheService 计费缓存服务
|
||||
// 负责余额和订阅数据的缓存管理,提供高性能的计费资格检查
|
||||
type BillingCacheService struct {
|
||||
rdb *redis.Client
|
||||
userRepo *repository.UserRepository
|
||||
subRepo *repository.UserSubscriptionRepository
|
||||
}
|
||||
|
||||
// NewBillingCacheService 创建计费缓存服务
|
||||
func NewBillingCacheService(rdb *redis.Client, userRepo *repository.UserRepository, subRepo *repository.UserSubscriptionRepository) *BillingCacheService {
|
||||
return &BillingCacheService{
|
||||
rdb: rdb,
|
||||
userRepo: userRepo,
|
||||
subRepo: subRepo,
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// 余额缓存方法
|
||||
// ============================================
|
||||
|
||||
// GetUserBalance 获取用户余额(优先从缓存读取)
|
||||
func (s *BillingCacheService) GetUserBalance(ctx context.Context, userID int64) (float64, error) {
|
||||
if s.rdb == nil {
|
||||
// Redis不可用,直接查询数据库
|
||||
return s.getUserBalanceFromDB(ctx, userID)
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
||||
|
||||
// 尝试从缓存读取
|
||||
val, err := s.rdb.Get(ctx, key).Result()
|
||||
if err == nil {
|
||||
balance, parseErr := strconv.ParseFloat(val, 64)
|
||||
if parseErr == nil {
|
||||
return balance, nil
|
||||
}
|
||||
}
|
||||
|
||||
// 缓存未命中或解析错误,从数据库读取
|
||||
balance, err := s.getUserBalanceFromDB(ctx, userID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// 异步建立缓存
|
||||
go func() {
|
||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
s.setBalanceCache(cacheCtx, userID, balance)
|
||||
}()
|
||||
|
||||
return balance, nil
|
||||
}
|
||||
|
||||
// getUserBalanceFromDB 从数据库获取用户余额
|
||||
func (s *BillingCacheService) getUserBalanceFromDB(ctx context.Context, userID int64) (float64, error) {
|
||||
user, err := s.userRepo.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("get user balance: %w", err)
|
||||
}
|
||||
return user.Balance, nil
|
||||
}
|
||||
|
||||
// setBalanceCache 设置余额缓存
|
||||
func (s *BillingCacheService) setBalanceCache(ctx context.Context, userID int64, balance float64) {
|
||||
if s.rdb == nil {
|
||||
return
|
||||
}
|
||||
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
||||
if err := s.rdb.Set(ctx, key, balance, billingCacheTTL).Err(); err != nil {
|
||||
log.Printf("Warning: set balance cache failed for user %d: %v", userID, err)
|
||||
}
|
||||
}
|
||||
|
||||
// DeductBalanceCache 扣减余额缓存(异步调用,用于扣费后更新缓存)
|
||||
func (s *BillingCacheService) DeductBalanceCache(ctx context.Context, userID int64, amount float64) error {
|
||||
if s.rdb == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
||||
|
||||
// 使用预编译的Lua脚本原子性扣减,如果key不存在则忽略
|
||||
_, err := deductBalanceScript.Run(ctx, s.rdb, []string{key}, amount, int(billingCacheTTL.Seconds())).Result()
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
log.Printf("Warning: deduct balance cache failed for user %d: %v", userID, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// InvalidateUserBalance 失效用户余额缓存
|
||||
func (s *BillingCacheService) InvalidateUserBalance(ctx context.Context, userID int64) error {
|
||||
if s.rdb == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
||||
if err := s.rdb.Del(ctx, key).Err(); err != nil {
|
||||
log.Printf("Warning: invalidate balance cache failed for user %d: %v", userID, err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// 订阅缓存方法
|
||||
// ============================================
|
||||
|
||||
// GetSubscriptionStatus 获取订阅状态(优先从缓存读取)
|
||||
func (s *BillingCacheService) GetSubscriptionStatus(ctx context.Context, userID, groupID int64) (*subscriptionCacheData, error) {
|
||||
if s.rdb == nil {
|
||||
return s.getSubscriptionFromDB(ctx, userID, groupID)
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
||||
|
||||
// 尝试从缓存读取
|
||||
result, err := s.rdb.HGetAll(ctx, key).Result()
|
||||
if err == nil && len(result) > 0 {
|
||||
data, parseErr := s.parseSubscriptionCache(result)
|
||||
if parseErr == nil {
|
||||
return data, nil
|
||||
}
|
||||
}
|
||||
|
||||
// 缓存未命中,从数据库读取
|
||||
data, err := s.getSubscriptionFromDB(ctx, userID, groupID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 异步建立缓存
|
||||
go func() {
|
||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
s.setSubscriptionCache(cacheCtx, userID, groupID, data)
|
||||
}()
|
||||
|
||||
return data, nil
|
||||
}
|
||||
|
||||
// getSubscriptionFromDB 从数据库获取订阅数据
|
||||
func (s *BillingCacheService) getSubscriptionFromDB(ctx context.Context, userID, groupID int64) (*subscriptionCacheData, error) {
|
||||
sub, err := s.subRepo.GetActiveByUserIDAndGroupID(ctx, userID, groupID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get subscription: %w", err)
|
||||
}
|
||||
|
||||
return &subscriptionCacheData{
|
||||
Status: sub.Status,
|
||||
ExpiresAt: sub.ExpiresAt,
|
||||
DailyUsage: sub.DailyUsageUSD,
|
||||
WeeklyUsage: sub.WeeklyUsageUSD,
|
||||
MonthlyUsage: sub.MonthlyUsageUSD,
|
||||
Version: sub.UpdatedAt.Unix(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// parseSubscriptionCache 解析订阅缓存数据
|
||||
func (s *BillingCacheService) parseSubscriptionCache(data map[string]string) (*subscriptionCacheData, error) {
|
||||
result := &subscriptionCacheData{}
|
||||
|
||||
result.Status = data[subFieldStatus]
|
||||
if result.Status == "" {
|
||||
return nil, errors.New("invalid cache: missing status")
|
||||
}
|
||||
|
||||
if expiresStr, ok := data[subFieldExpiresAt]; ok {
|
||||
expiresAt, err := strconv.ParseInt(expiresStr, 10, 64)
|
||||
if err == nil {
|
||||
result.ExpiresAt = time.Unix(expiresAt, 0)
|
||||
}
|
||||
}
|
||||
|
||||
if dailyStr, ok := data[subFieldDailyUsage]; ok {
|
||||
result.DailyUsage, _ = strconv.ParseFloat(dailyStr, 64)
|
||||
}
|
||||
|
||||
if weeklyStr, ok := data[subFieldWeeklyUsage]; ok {
|
||||
result.WeeklyUsage, _ = strconv.ParseFloat(weeklyStr, 64)
|
||||
}
|
||||
|
||||
if monthlyStr, ok := data[subFieldMonthlyUsage]; ok {
|
||||
result.MonthlyUsage, _ = strconv.ParseFloat(monthlyStr, 64)
|
||||
}
|
||||
|
||||
if versionStr, ok := data[subFieldVersion]; ok {
|
||||
result.Version, _ = strconv.ParseInt(versionStr, 10, 64)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// setSubscriptionCache 设置订阅缓存
|
||||
func (s *BillingCacheService) setSubscriptionCache(ctx context.Context, userID, groupID int64, data *subscriptionCacheData) {
|
||||
if s.rdb == nil || data == nil {
|
||||
return
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
||||
|
||||
fields := map[string]interface{}{
|
||||
subFieldStatus: data.Status,
|
||||
subFieldExpiresAt: data.ExpiresAt.Unix(),
|
||||
subFieldDailyUsage: data.DailyUsage,
|
||||
subFieldWeeklyUsage: data.WeeklyUsage,
|
||||
subFieldMonthlyUsage: data.MonthlyUsage,
|
||||
subFieldVersion: data.Version,
|
||||
}
|
||||
|
||||
pipe := s.rdb.Pipeline()
|
||||
pipe.HSet(ctx, key, fields)
|
||||
pipe.Expire(ctx, key, billingCacheTTL)
|
||||
if _, err := pipe.Exec(ctx); err != nil {
|
||||
log.Printf("Warning: set subscription cache failed for user %d group %d: %v", userID, groupID, err)
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateSubscriptionUsage 更新订阅用量缓存(异步调用,用于扣费后更新缓存)
|
||||
func (s *BillingCacheService) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, costUSD float64) error {
|
||||
if s.rdb == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
||||
|
||||
// 使用预编译的Lua脚本原子性增加用量,如果key不存在则忽略
|
||||
_, err := updateSubUsageScript.Run(ctx, s.rdb, []string{key}, costUSD, int(billingCacheTTL.Seconds())).Result()
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
log.Printf("Warning: update subscription usage cache failed for user %d group %d: %v", userID, groupID, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// InvalidateSubscription 失效指定订阅缓存
|
||||
func (s *BillingCacheService) InvalidateSubscription(ctx context.Context, userID, groupID int64) error {
|
||||
if s.rdb == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
||||
if err := s.rdb.Del(ctx, key).Err(); err != nil {
|
||||
log.Printf("Warning: invalidate subscription cache failed for user %d group %d: %v", userID, groupID, err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// 统一检查方法
|
||||
// ============================================
|
||||
|
||||
// CheckBillingEligibility 检查用户是否有资格发起请求
|
||||
// 余额模式:检查缓存余额 > 0
|
||||
// 订阅模式:检查缓存用量未超过限额(Group限额从参数传入)
|
||||
func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user *model.User, apiKey *model.ApiKey, group *model.Group, subscription *model.UserSubscription) error {
|
||||
// 判断计费模式
|
||||
isSubscriptionMode := group != nil && group.IsSubscriptionType() && subscription != nil
|
||||
|
||||
if isSubscriptionMode {
|
||||
return s.checkSubscriptionEligibility(ctx, user.ID, group, subscription)
|
||||
}
|
||||
|
||||
return s.checkBalanceEligibility(ctx, user.ID)
|
||||
}
|
||||
|
||||
// checkBalanceEligibility 检查余额模式资格
|
||||
func (s *BillingCacheService) checkBalanceEligibility(ctx context.Context, userID int64) error {
|
||||
balance, err := s.GetUserBalance(ctx, userID)
|
||||
if err != nil {
|
||||
// 缓存/数据库错误,允许通过(降级处理)
|
||||
log.Printf("Warning: get user balance failed, allowing request: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
if balance <= 0 {
|
||||
return ErrInsufficientBalance
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkSubscriptionEligibility 检查订阅模式资格
|
||||
func (s *BillingCacheService) checkSubscriptionEligibility(ctx context.Context, userID int64, group *model.Group, subscription *model.UserSubscription) error {
|
||||
// 获取订阅缓存数据
|
||||
subData, err := s.GetSubscriptionStatus(ctx, userID, group.ID)
|
||||
if err != nil {
|
||||
// 缓存/数据库错误,降级使用传入的subscription进行检查
|
||||
log.Printf("Warning: get subscription cache failed, using fallback: %v", err)
|
||||
return s.checkSubscriptionLimitsFallback(subscription, group)
|
||||
}
|
||||
|
||||
// 检查订阅状态
|
||||
if subData.Status != model.SubscriptionStatusActive {
|
||||
return ErrSubscriptionInvalid
|
||||
}
|
||||
|
||||
// 检查是否过期
|
||||
if time.Now().After(subData.ExpiresAt) {
|
||||
return ErrSubscriptionInvalid
|
||||
}
|
||||
|
||||
// 检查限额(使用传入的Group限额配置)
|
||||
if group.HasDailyLimit() && subData.DailyUsage >= *group.DailyLimitUSD {
|
||||
return ErrDailyLimitExceeded
|
||||
}
|
||||
|
||||
if group.HasWeeklyLimit() && subData.WeeklyUsage >= *group.WeeklyLimitUSD {
|
||||
return ErrWeeklyLimitExceeded
|
||||
}
|
||||
|
||||
if group.HasMonthlyLimit() && subData.MonthlyUsage >= *group.MonthlyLimitUSD {
|
||||
return ErrMonthlyLimitExceeded
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkSubscriptionLimitsFallback 降级检查订阅限额
|
||||
func (s *BillingCacheService) checkSubscriptionLimitsFallback(subscription *model.UserSubscription, group *model.Group) error {
|
||||
if subscription == nil {
|
||||
return ErrSubscriptionInvalid
|
||||
}
|
||||
|
||||
if !subscription.IsActive() {
|
||||
return ErrSubscriptionInvalid
|
||||
}
|
||||
|
||||
if !subscription.CheckDailyLimit(group, 0) {
|
||||
return ErrDailyLimitExceeded
|
||||
}
|
||||
|
||||
if !subscription.CheckWeeklyLimit(group, 0) {
|
||||
return ErrWeeklyLimitExceeded
|
||||
}
|
||||
|
||||
if !subscription.CheckMonthlyLimit(group, 0) {
|
||||
return ErrMonthlyLimitExceeded
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
279
backend/internal/service/billing_service.go
Normal file
279
backend/internal/service/billing_service.go
Normal file
@@ -0,0 +1,279 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
"sub2api/internal/config"
|
||||
)
|
||||
|
||||
// ModelPricing 模型价格配置(per-token价格,与LiteLLM格式一致)
|
||||
type ModelPricing struct {
|
||||
InputPricePerToken float64 // 每token输入价格 (USD)
|
||||
OutputPricePerToken float64 // 每token输出价格 (USD)
|
||||
CacheCreationPricePerToken float64 // 缓存创建每token价格 (USD)
|
||||
CacheReadPricePerToken float64 // 缓存读取每token价格 (USD)
|
||||
CacheCreation5mPrice float64 // 5分钟缓存创建价格(每百万token)- 仅用于硬编码回退
|
||||
CacheCreation1hPrice float64 // 1小时缓存创建价格(每百万token)- 仅用于硬编码回退
|
||||
SupportsCacheBreakdown bool // 是否支持详细的缓存分类
|
||||
}
|
||||
|
||||
// UsageTokens 使用的token数量
|
||||
type UsageTokens struct {
|
||||
InputTokens int
|
||||
OutputTokens int
|
||||
CacheCreationTokens int
|
||||
CacheReadTokens int
|
||||
CacheCreation5mTokens int
|
||||
CacheCreation1hTokens int
|
||||
}
|
||||
|
||||
// CostBreakdown 费用明细
|
||||
type CostBreakdown struct {
|
||||
InputCost float64
|
||||
OutputCost float64
|
||||
CacheCreationCost float64
|
||||
CacheReadCost float64
|
||||
TotalCost float64
|
||||
ActualCost float64 // 应用倍率后的实际费用
|
||||
}
|
||||
|
||||
// BillingService 计费服务
|
||||
type BillingService struct {
|
||||
cfg *config.Config
|
||||
pricingService *PricingService
|
||||
fallbackPrices map[string]*ModelPricing // 硬编码回退价格
|
||||
}
|
||||
|
||||
// NewBillingService 创建计费服务实例
|
||||
func NewBillingService(cfg *config.Config, pricingService *PricingService) *BillingService {
|
||||
s := &BillingService{
|
||||
cfg: cfg,
|
||||
pricingService: pricingService,
|
||||
fallbackPrices: make(map[string]*ModelPricing),
|
||||
}
|
||||
|
||||
// 初始化硬编码回退价格(当动态价格不可用时使用)
|
||||
s.initFallbackPricing()
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// initFallbackPricing 初始化硬编码回退价格(当动态价格不可用时使用)
|
||||
// 价格单位:USD per token(与LiteLLM格式一致)
|
||||
func (s *BillingService) initFallbackPricing() {
|
||||
// Claude 4.5 Opus
|
||||
s.fallbackPrices["claude-opus-4.5"] = &ModelPricing{
|
||||
InputPricePerToken: 5e-6, // $5 per MTok
|
||||
OutputPricePerToken: 25e-6, // $25 per MTok
|
||||
CacheCreationPricePerToken: 6.25e-6, // $6.25 per MTok
|
||||
CacheReadPricePerToken: 0.5e-6, // $0.50 per MTok
|
||||
SupportsCacheBreakdown: false,
|
||||
}
|
||||
|
||||
// Claude 4 Sonnet
|
||||
s.fallbackPrices["claude-sonnet-4"] = &ModelPricing{
|
||||
InputPricePerToken: 3e-6, // $3 per MTok
|
||||
OutputPricePerToken: 15e-6, // $15 per MTok
|
||||
CacheCreationPricePerToken: 3.75e-6, // $3.75 per MTok
|
||||
CacheReadPricePerToken: 0.3e-6, // $0.30 per MTok
|
||||
SupportsCacheBreakdown: false,
|
||||
}
|
||||
|
||||
// Claude 3.5 Sonnet
|
||||
s.fallbackPrices["claude-3-5-sonnet"] = &ModelPricing{
|
||||
InputPricePerToken: 3e-6, // $3 per MTok
|
||||
OutputPricePerToken: 15e-6, // $15 per MTok
|
||||
CacheCreationPricePerToken: 3.75e-6, // $3.75 per MTok
|
||||
CacheReadPricePerToken: 0.3e-6, // $0.30 per MTok
|
||||
SupportsCacheBreakdown: false,
|
||||
}
|
||||
|
||||
// Claude 3.5 Haiku
|
||||
s.fallbackPrices["claude-3-5-haiku"] = &ModelPricing{
|
||||
InputPricePerToken: 1e-6, // $1 per MTok
|
||||
OutputPricePerToken: 5e-6, // $5 per MTok
|
||||
CacheCreationPricePerToken: 1.25e-6, // $1.25 per MTok
|
||||
CacheReadPricePerToken: 0.1e-6, // $0.10 per MTok
|
||||
SupportsCacheBreakdown: false,
|
||||
}
|
||||
|
||||
// Claude 3 Opus
|
||||
s.fallbackPrices["claude-3-opus"] = &ModelPricing{
|
||||
InputPricePerToken: 15e-6, // $15 per MTok
|
||||
OutputPricePerToken: 75e-6, // $75 per MTok
|
||||
CacheCreationPricePerToken: 18.75e-6, // $18.75 per MTok
|
||||
CacheReadPricePerToken: 1.5e-6, // $1.50 per MTok
|
||||
SupportsCacheBreakdown: false,
|
||||
}
|
||||
|
||||
// Claude 3 Haiku
|
||||
s.fallbackPrices["claude-3-haiku"] = &ModelPricing{
|
||||
InputPricePerToken: 0.25e-6, // $0.25 per MTok
|
||||
OutputPricePerToken: 1.25e-6, // $1.25 per MTok
|
||||
CacheCreationPricePerToken: 0.3e-6, // $0.30 per MTok
|
||||
CacheReadPricePerToken: 0.03e-6, // $0.03 per MTok
|
||||
SupportsCacheBreakdown: false,
|
||||
}
|
||||
}
|
||||
|
||||
// getFallbackPricing 根据模型系列获取回退价格
|
||||
func (s *BillingService) getFallbackPricing(model string) *ModelPricing {
|
||||
modelLower := strings.ToLower(model)
|
||||
|
||||
// 按模型系列匹配
|
||||
if strings.Contains(modelLower, "opus") {
|
||||
if strings.Contains(modelLower, "4.5") || strings.Contains(modelLower, "4-5") {
|
||||
return s.fallbackPrices["claude-opus-4.5"]
|
||||
}
|
||||
return s.fallbackPrices["claude-3-opus"]
|
||||
}
|
||||
if strings.Contains(modelLower, "sonnet") {
|
||||
if strings.Contains(modelLower, "4") && !strings.Contains(modelLower, "3") {
|
||||
return s.fallbackPrices["claude-sonnet-4"]
|
||||
}
|
||||
return s.fallbackPrices["claude-3-5-sonnet"]
|
||||
}
|
||||
if strings.Contains(modelLower, "haiku") {
|
||||
if strings.Contains(modelLower, "3-5") || strings.Contains(modelLower, "3.5") {
|
||||
return s.fallbackPrices["claude-3-5-haiku"]
|
||||
}
|
||||
return s.fallbackPrices["claude-3-haiku"]
|
||||
}
|
||||
|
||||
// 默认使用Sonnet价格
|
||||
return s.fallbackPrices["claude-sonnet-4"]
|
||||
}
|
||||
|
||||
// GetModelPricing 获取模型价格配置
|
||||
func (s *BillingService) GetModelPricing(model string) (*ModelPricing, error) {
|
||||
// 标准化模型名称(转小写)
|
||||
model = strings.ToLower(model)
|
||||
|
||||
// 1. 优先从动态价格服务获取
|
||||
if s.pricingService != nil {
|
||||
litellmPricing := s.pricingService.GetModelPricing(model)
|
||||
if litellmPricing != nil {
|
||||
return &ModelPricing{
|
||||
InputPricePerToken: litellmPricing.InputCostPerToken,
|
||||
OutputPricePerToken: litellmPricing.OutputCostPerToken,
|
||||
CacheCreationPricePerToken: litellmPricing.CacheCreationInputTokenCost,
|
||||
CacheReadPricePerToken: litellmPricing.CacheReadInputTokenCost,
|
||||
SupportsCacheBreakdown: false,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 使用硬编码回退价格
|
||||
fallback := s.getFallbackPricing(model)
|
||||
if fallback != nil {
|
||||
log.Printf("[Billing] Using fallback pricing for model: %s", model)
|
||||
return fallback, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("pricing not found for model: %s", model)
|
||||
}
|
||||
|
||||
// CalculateCost 计算使用费用
|
||||
func (s *BillingService) CalculateCost(model string, tokens UsageTokens, rateMultiplier float64) (*CostBreakdown, error) {
|
||||
pricing, err := s.GetModelPricing(model)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
breakdown := &CostBreakdown{}
|
||||
|
||||
// 计算输入token费用(使用per-token价格)
|
||||
breakdown.InputCost = float64(tokens.InputTokens) * pricing.InputPricePerToken
|
||||
|
||||
// 计算输出token费用
|
||||
breakdown.OutputCost = float64(tokens.OutputTokens) * pricing.OutputPricePerToken
|
||||
|
||||
// 计算缓存费用
|
||||
if pricing.SupportsCacheBreakdown && (pricing.CacheCreation5mPrice > 0 || pricing.CacheCreation1hPrice > 0) {
|
||||
// 支持详细缓存分类的模型(5分钟/1小时缓存)
|
||||
breakdown.CacheCreationCost = float64(tokens.CacheCreation5mTokens)/1_000_000*pricing.CacheCreation5mPrice +
|
||||
float64(tokens.CacheCreation1hTokens)/1_000_000*pricing.CacheCreation1hPrice
|
||||
} else {
|
||||
// 标准缓存创建价格(per-token)
|
||||
breakdown.CacheCreationCost = float64(tokens.CacheCreationTokens) * pricing.CacheCreationPricePerToken
|
||||
}
|
||||
|
||||
breakdown.CacheReadCost = float64(tokens.CacheReadTokens) * pricing.CacheReadPricePerToken
|
||||
|
||||
// 计算总费用
|
||||
breakdown.TotalCost = breakdown.InputCost + breakdown.OutputCost +
|
||||
breakdown.CacheCreationCost + breakdown.CacheReadCost
|
||||
|
||||
// 应用倍率计算实际费用
|
||||
if rateMultiplier <= 0 {
|
||||
rateMultiplier = 1.0
|
||||
}
|
||||
breakdown.ActualCost = breakdown.TotalCost * rateMultiplier
|
||||
|
||||
return breakdown, nil
|
||||
}
|
||||
|
||||
// CalculateCostWithConfig 使用配置中的默认倍率计算费用
|
||||
func (s *BillingService) CalculateCostWithConfig(model string, tokens UsageTokens) (*CostBreakdown, error) {
|
||||
multiplier := s.cfg.Default.RateMultiplier
|
||||
if multiplier <= 0 {
|
||||
multiplier = 1.0
|
||||
}
|
||||
return s.CalculateCost(model, tokens, multiplier)
|
||||
}
|
||||
|
||||
// ListSupportedModels 列出所有支持的模型(现在总是返回true,因为有模糊匹配)
|
||||
func (s *BillingService) ListSupportedModels() []string {
|
||||
models := make([]string, 0)
|
||||
// 返回回退价格支持的模型系列
|
||||
for model := range s.fallbackPrices {
|
||||
models = append(models, model)
|
||||
}
|
||||
return models
|
||||
}
|
||||
|
||||
// IsModelSupported 检查模型是否支持(现在总是返回true,因为有模糊匹配回退)
|
||||
func (s *BillingService) IsModelSupported(model string) bool {
|
||||
// 所有Claude模型都有回退价格支持
|
||||
modelLower := strings.ToLower(model)
|
||||
return strings.Contains(modelLower, "claude") ||
|
||||
strings.Contains(modelLower, "opus") ||
|
||||
strings.Contains(modelLower, "sonnet") ||
|
||||
strings.Contains(modelLower, "haiku")
|
||||
}
|
||||
|
||||
// GetEstimatedCost 估算费用(用于前端展示)
|
||||
func (s *BillingService) GetEstimatedCost(model string, estimatedInputTokens, estimatedOutputTokens int) (float64, error) {
|
||||
tokens := UsageTokens{
|
||||
InputTokens: estimatedInputTokens,
|
||||
OutputTokens: estimatedOutputTokens,
|
||||
}
|
||||
|
||||
breakdown, err := s.CalculateCostWithConfig(model, tokens)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return breakdown.ActualCost, nil
|
||||
}
|
||||
|
||||
// GetPricingServiceStatus 获取价格服务状态
|
||||
func (s *BillingService) GetPricingServiceStatus() map[string]interface{} {
|
||||
if s.pricingService != nil {
|
||||
return s.pricingService.GetStatus()
|
||||
}
|
||||
return map[string]interface{}{
|
||||
"model_count": len(s.fallbackPrices),
|
||||
"last_updated": "using fallback",
|
||||
"local_hash": "N/A",
|
||||
}
|
||||
}
|
||||
|
||||
// ForceUpdatePricing 强制更新价格数据
|
||||
func (s *BillingService) ForceUpdatePricing() error {
|
||||
if s.pricingService != nil {
|
||||
return s.pricingService.ForceUpdate()
|
||||
}
|
||||
return fmt.Errorf("pricing service not initialized")
|
||||
}
|
||||
251
backend/internal/service/concurrency_service.go
Normal file
251
backend/internal/service/concurrency_service.go
Normal file
@@ -0,0 +1,251 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const (
|
||||
// Redis key prefixes
|
||||
accountConcurrencyKey = "concurrency:account:"
|
||||
userConcurrencyKey = "concurrency:user:"
|
||||
userWaitCountKey = "concurrency:wait:"
|
||||
|
||||
// TTL for concurrency keys (auto-release safety net)
|
||||
concurrencyKeyTTL = 10 * time.Minute
|
||||
|
||||
// Wait polling interval
|
||||
waitPollInterval = 100 * time.Millisecond
|
||||
|
||||
// Default max wait time
|
||||
defaultMaxWait = 60 * time.Second
|
||||
|
||||
// Default extra wait slots beyond concurrency limit
|
||||
defaultExtraWaitSlots = 20
|
||||
)
|
||||
|
||||
// Pre-compiled Lua scripts for better performance
|
||||
var (
|
||||
// acquireScript: increment counter if below max, return 1 if successful
|
||||
acquireScript = redis.NewScript(`
|
||||
local current = redis.call('GET', KEYS[1])
|
||||
if current == false then
|
||||
current = 0
|
||||
else
|
||||
current = tonumber(current)
|
||||
end
|
||||
if current < tonumber(ARGV[1]) then
|
||||
redis.call('INCR', KEYS[1])
|
||||
redis.call('EXPIRE', KEYS[1], ARGV[2])
|
||||
return 1
|
||||
end
|
||||
return 0
|
||||
`)
|
||||
|
||||
// releaseScript: decrement counter, but don't go below 0
|
||||
releaseScript = redis.NewScript(`
|
||||
local current = redis.call('GET', KEYS[1])
|
||||
if current ~= false and tonumber(current) > 0 then
|
||||
redis.call('DECR', KEYS[1])
|
||||
end
|
||||
return 1
|
||||
`)
|
||||
|
||||
// incrementWaitScript: increment wait counter if below max, return 1 if successful
|
||||
incrementWaitScript = redis.NewScript(`
|
||||
local waitKey = KEYS[1]
|
||||
local maxWait = tonumber(ARGV[1])
|
||||
local ttl = tonumber(ARGV[2])
|
||||
local current = redis.call('GET', waitKey)
|
||||
if current == false then
|
||||
current = 0
|
||||
else
|
||||
current = tonumber(current)
|
||||
end
|
||||
if current >= maxWait then
|
||||
return 0
|
||||
end
|
||||
redis.call('INCR', waitKey)
|
||||
redis.call('EXPIRE', waitKey, ttl)
|
||||
return 1
|
||||
`)
|
||||
|
||||
// decrementWaitScript: decrement wait counter, but don't go below 0
|
||||
decrementWaitScript = redis.NewScript(`
|
||||
local current = redis.call('GET', KEYS[1])
|
||||
if current ~= false and tonumber(current) > 0 then
|
||||
redis.call('DECR', KEYS[1])
|
||||
end
|
||||
return 1
|
||||
`)
|
||||
)
|
||||
|
||||
// ConcurrencyService manages concurrent request limiting for accounts and users
|
||||
type ConcurrencyService struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
// NewConcurrencyService creates a new ConcurrencyService
|
||||
func NewConcurrencyService(rdb *redis.Client) *ConcurrencyService {
|
||||
return &ConcurrencyService{rdb: rdb}
|
||||
}
|
||||
|
||||
// AcquireResult represents the result of acquiring a concurrency slot
|
||||
type AcquireResult struct {
|
||||
Acquired bool
|
||||
ReleaseFunc func() // Must be called when done (typically via defer)
|
||||
}
|
||||
|
||||
// AcquireAccountSlot attempts to acquire a concurrency slot for an account.
|
||||
// If the account is at max concurrency, it waits until a slot is available or timeout.
|
||||
// Returns a release function that MUST be called when the request completes.
|
||||
func (s *ConcurrencyService) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (*AcquireResult, error) {
|
||||
key := fmt.Sprintf("%s%d", accountConcurrencyKey, accountID)
|
||||
return s.acquireSlot(ctx, key, maxConcurrency)
|
||||
}
|
||||
|
||||
// AcquireUserSlot attempts to acquire a concurrency slot for a user.
|
||||
// If the user is at max concurrency, it waits until a slot is available or timeout.
|
||||
// Returns a release function that MUST be called when the request completes.
|
||||
func (s *ConcurrencyService) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int) (*AcquireResult, error) {
|
||||
key := fmt.Sprintf("%s%d", userConcurrencyKey, userID)
|
||||
return s.acquireSlot(ctx, key, maxConcurrency)
|
||||
}
|
||||
|
||||
// acquireSlot is the core implementation for acquiring a concurrency slot
|
||||
func (s *ConcurrencyService) acquireSlot(ctx context.Context, key string, maxConcurrency int) (*AcquireResult, error) {
|
||||
// If maxConcurrency is 0 or negative, no limit
|
||||
if maxConcurrency <= 0 {
|
||||
return &AcquireResult{
|
||||
Acquired: true,
|
||||
ReleaseFunc: func() {}, // no-op
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Try to acquire immediately
|
||||
acquired, err := s.tryAcquire(ctx, key, maxConcurrency)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if acquired {
|
||||
return &AcquireResult{
|
||||
Acquired: true,
|
||||
ReleaseFunc: s.makeReleaseFunc(key),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Not acquired, return with Acquired=false
|
||||
// The caller (gateway handler) will handle waiting with ping support
|
||||
return &AcquireResult{
|
||||
Acquired: false,
|
||||
ReleaseFunc: nil,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// tryAcquire attempts to increment the counter if below max
|
||||
// Uses pre-compiled Lua script for atomicity and performance
|
||||
func (s *ConcurrencyService) tryAcquire(ctx context.Context, key string, maxConcurrency int) (bool, error) {
|
||||
result, err := acquireScript.Run(ctx, s.rdb, []string{key}, maxConcurrency, int(concurrencyKeyTTL.Seconds())).Int()
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("acquire slot failed: %w", err)
|
||||
}
|
||||
return result == 1, nil
|
||||
}
|
||||
|
||||
// makeReleaseFunc creates a function to release a concurrency slot
|
||||
func (s *ConcurrencyService) makeReleaseFunc(key string) func() {
|
||||
return func() {
|
||||
// Use background context to ensure release even if original context is cancelled
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := releaseScript.Run(ctx, s.rdb, []string{key}).Err(); err != nil {
|
||||
// Log error but don't panic - TTL will eventually clean up
|
||||
log.Printf("Warning: failed to release concurrency slot for %s: %v", key, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetCurrentCount returns the current concurrency count for debugging/monitoring
|
||||
func (s *ConcurrencyService) GetCurrentCount(ctx context.Context, key string) (int, error) {
|
||||
val, err := s.rdb.Get(ctx, key).Int()
|
||||
if err == redis.Nil {
|
||||
return 0, nil
|
||||
}
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
|
||||
// GetAccountCurrentCount returns current concurrency count for an account
|
||||
func (s *ConcurrencyService) GetAccountCurrentCount(ctx context.Context, accountID int64) (int, error) {
|
||||
key := fmt.Sprintf("%s%d", accountConcurrencyKey, accountID)
|
||||
return s.GetCurrentCount(ctx, key)
|
||||
}
|
||||
|
||||
// GetUserCurrentCount returns current concurrency count for a user
|
||||
func (s *ConcurrencyService) GetUserCurrentCount(ctx context.Context, userID int64) (int, error) {
|
||||
key := fmt.Sprintf("%s%d", userConcurrencyKey, userID)
|
||||
return s.GetCurrentCount(ctx, key)
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// Wait Queue Count Methods
|
||||
// ============================================
|
||||
|
||||
// IncrementWaitCount attempts to increment the wait queue counter for a user.
|
||||
// Returns true if successful, false if the wait queue is full.
|
||||
// maxWait should be user.Concurrency + defaultExtraWaitSlots
|
||||
func (s *ConcurrencyService) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
|
||||
if s.rdb == nil {
|
||||
// Redis not available, allow request
|
||||
return true, nil
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("%s%d", userWaitCountKey, userID)
|
||||
result, err := incrementWaitScript.Run(ctx, s.rdb, []string{key}, maxWait, int(concurrencyKeyTTL.Seconds())).Int()
|
||||
if err != nil {
|
||||
// On error, allow the request to proceed (fail open)
|
||||
log.Printf("Warning: increment wait count failed for user %d: %v", userID, err)
|
||||
return true, nil
|
||||
}
|
||||
return result == 1, nil
|
||||
}
|
||||
|
||||
// DecrementWaitCount decrements the wait queue counter for a user.
|
||||
// Should be called when a request completes or exits the wait queue.
|
||||
func (s *ConcurrencyService) DecrementWaitCount(ctx context.Context, userID int64) {
|
||||
if s.rdb == nil {
|
||||
return
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("%s%d", userWaitCountKey, userID)
|
||||
// Use background context to ensure decrement even if original context is cancelled
|
||||
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := decrementWaitScript.Run(bgCtx, s.rdb, []string{key}).Err(); err != nil {
|
||||
log.Printf("Warning: decrement wait count failed for user %d: %v", userID, err)
|
||||
}
|
||||
}
|
||||
|
||||
// GetUserWaitCount returns current wait queue count for a user
|
||||
func (s *ConcurrencyService) GetUserWaitCount(ctx context.Context, userID int64) (int, error) {
|
||||
key := fmt.Sprintf("%s%d", userWaitCountKey, userID)
|
||||
return s.GetCurrentCount(ctx, key)
|
||||
}
|
||||
|
||||
// CalculateMaxWait calculates the maximum wait queue size for a user
|
||||
// maxWait = userConcurrency + defaultExtraWaitSlots
|
||||
func CalculateMaxWait(userConcurrency int) int {
|
||||
if userConcurrency <= 0 {
|
||||
userConcurrency = 1
|
||||
}
|
||||
return userConcurrency + defaultExtraWaitSlots
|
||||
}
|
||||
109
backend/internal/service/email_queue_service.go
Normal file
109
backend/internal/service/email_queue_service.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// EmailTask 邮件发送任务
|
||||
type EmailTask struct {
|
||||
Email string
|
||||
SiteName string
|
||||
TaskType string // "verify_code"
|
||||
}
|
||||
|
||||
// EmailQueueService 异步邮件队列服务
|
||||
type EmailQueueService struct {
|
||||
emailService *EmailService
|
||||
taskChan chan EmailTask
|
||||
wg sync.WaitGroup
|
||||
stopChan chan struct{}
|
||||
workers int
|
||||
}
|
||||
|
||||
// NewEmailQueueService 创建邮件队列服务
|
||||
func NewEmailQueueService(emailService *EmailService, workers int) *EmailQueueService {
|
||||
if workers <= 0 {
|
||||
workers = 3 // 默认3个工作协程
|
||||
}
|
||||
|
||||
service := &EmailQueueService{
|
||||
emailService: emailService,
|
||||
taskChan: make(chan EmailTask, 100), // 缓冲100个任务
|
||||
stopChan: make(chan struct{}),
|
||||
workers: workers,
|
||||
}
|
||||
|
||||
// 启动工作协程
|
||||
service.start()
|
||||
|
||||
return service
|
||||
}
|
||||
|
||||
// start 启动工作协程
|
||||
func (s *EmailQueueService) start() {
|
||||
for i := 0; i < s.workers; i++ {
|
||||
s.wg.Add(1)
|
||||
go s.worker(i)
|
||||
}
|
||||
log.Printf("[EmailQueue] Started %d workers", s.workers)
|
||||
}
|
||||
|
||||
// worker 工作协程
|
||||
func (s *EmailQueueService) worker(id int) {
|
||||
defer s.wg.Done()
|
||||
|
||||
for {
|
||||
select {
|
||||
case task := <-s.taskChan:
|
||||
s.processTask(id, task)
|
||||
case <-s.stopChan:
|
||||
log.Printf("[EmailQueue] Worker %d stopping", id)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// processTask 处理任务
|
||||
func (s *EmailQueueService) processTask(workerID int, task EmailTask) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
switch task.TaskType {
|
||||
case "verify_code":
|
||||
if err := s.emailService.SendVerifyCode(ctx, task.Email, task.SiteName); err != nil {
|
||||
log.Printf("[EmailQueue] Worker %d failed to send verify code to %s: %v", workerID, task.Email, err)
|
||||
} else {
|
||||
log.Printf("[EmailQueue] Worker %d sent verify code to %s", workerID, task.Email)
|
||||
}
|
||||
default:
|
||||
log.Printf("[EmailQueue] Worker %d unknown task type: %s", workerID, task.TaskType)
|
||||
}
|
||||
}
|
||||
|
||||
// EnqueueVerifyCode 将验证码发送任务加入队列
|
||||
func (s *EmailQueueService) EnqueueVerifyCode(email, siteName string) error {
|
||||
task := EmailTask{
|
||||
Email: email,
|
||||
SiteName: siteName,
|
||||
TaskType: "verify_code",
|
||||
}
|
||||
|
||||
select {
|
||||
case s.taskChan <- task:
|
||||
log.Printf("[EmailQueue] Enqueued verify code task for %s", email)
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("email queue is full")
|
||||
}
|
||||
}
|
||||
|
||||
// Stop 停止队列服务
|
||||
func (s *EmailQueueService) Stop() {
|
||||
close(s.stopChan)
|
||||
s.wg.Wait()
|
||||
log.Println("[EmailQueue] All workers stopped")
|
||||
}
|
||||
372
backend/internal/service/email_service.go
Normal file
372
backend/internal/service/email_service.go
Normal file
@@ -0,0 +1,372 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net/smtp"
|
||||
"strconv"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/repository"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrEmailNotConfigured = errors.New("email service not configured")
|
||||
ErrInvalidVerifyCode = errors.New("invalid or expired verification code")
|
||||
ErrVerifyCodeTooFrequent = errors.New("please wait before requesting a new code")
|
||||
ErrVerifyCodeMaxAttempts = errors.New("too many failed attempts, please request a new code")
|
||||
)
|
||||
|
||||
const (
|
||||
verifyCodeKeyPrefix = "email_verify:"
|
||||
verifyCodeTTL = 15 * time.Minute
|
||||
verifyCodeCooldown = 1 * time.Minute
|
||||
maxVerifyCodeAttempts = 5
|
||||
)
|
||||
|
||||
// verifyCodeData Redis 中存储的验证码数据
|
||||
type verifyCodeData struct {
|
||||
Code string `json:"code"`
|
||||
Attempts int `json:"attempts"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
// SmtpConfig SMTP配置
|
||||
type SmtpConfig struct {
|
||||
Host string
|
||||
Port int
|
||||
Username string
|
||||
Password string
|
||||
From string
|
||||
FromName string
|
||||
UseTLS bool
|
||||
}
|
||||
|
||||
// EmailService 邮件服务
|
||||
type EmailService struct {
|
||||
settingRepo *repository.SettingRepository
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
// NewEmailService 创建邮件服务实例
|
||||
func NewEmailService(settingRepo *repository.SettingRepository, rdb *redis.Client) *EmailService {
|
||||
return &EmailService{
|
||||
settingRepo: settingRepo,
|
||||
rdb: rdb,
|
||||
}
|
||||
}
|
||||
|
||||
// GetSmtpConfig 从数据库获取SMTP配置
|
||||
func (s *EmailService) GetSmtpConfig(ctx context.Context) (*SmtpConfig, error) {
|
||||
keys := []string{
|
||||
model.SettingKeySmtpHost,
|
||||
model.SettingKeySmtpPort,
|
||||
model.SettingKeySmtpUsername,
|
||||
model.SettingKeySmtpPassword,
|
||||
model.SettingKeySmtpFrom,
|
||||
model.SettingKeySmtpFromName,
|
||||
model.SettingKeySmtpUseTLS,
|
||||
}
|
||||
|
||||
settings, err := s.settingRepo.GetMultiple(ctx, keys)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get smtp settings: %w", err)
|
||||
}
|
||||
|
||||
host := settings[model.SettingKeySmtpHost]
|
||||
if host == "" {
|
||||
return nil, ErrEmailNotConfigured
|
||||
}
|
||||
|
||||
port := 587 // 默认端口
|
||||
if portStr := settings[model.SettingKeySmtpPort]; portStr != "" {
|
||||
if p, err := strconv.Atoi(portStr); err == nil {
|
||||
port = p
|
||||
}
|
||||
}
|
||||
|
||||
useTLS := settings[model.SettingKeySmtpUseTLS] == "true"
|
||||
|
||||
return &SmtpConfig{
|
||||
Host: host,
|
||||
Port: port,
|
||||
Username: settings[model.SettingKeySmtpUsername],
|
||||
Password: settings[model.SettingKeySmtpPassword],
|
||||
From: settings[model.SettingKeySmtpFrom],
|
||||
FromName: settings[model.SettingKeySmtpFromName],
|
||||
UseTLS: useTLS,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SendEmail 发送邮件(使用数据库中保存的配置)
|
||||
func (s *EmailService) SendEmail(ctx context.Context, to, subject, body string) error {
|
||||
config, err := s.GetSmtpConfig(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return s.SendEmailWithConfig(config, to, subject, body)
|
||||
}
|
||||
|
||||
// SendEmailWithConfig 使用指定配置发送邮件
|
||||
func (s *EmailService) SendEmailWithConfig(config *SmtpConfig, to, subject, body string) error {
|
||||
from := config.From
|
||||
if config.FromName != "" {
|
||||
from = fmt.Sprintf("%s <%s>", config.FromName, config.From)
|
||||
}
|
||||
|
||||
msg := fmt.Sprintf("From: %s\r\nTo: %s\r\nSubject: %s\r\nMIME-Version: 1.0\r\nContent-Type: text/html; charset=UTF-8\r\n\r\n%s",
|
||||
from, to, subject, body)
|
||||
|
||||
addr := fmt.Sprintf("%s:%d", config.Host, config.Port)
|
||||
auth := smtp.PlainAuth("", config.Username, config.Password, config.Host)
|
||||
|
||||
if config.UseTLS {
|
||||
return s.sendMailTLS(addr, auth, config.From, to, []byte(msg), config.Host)
|
||||
}
|
||||
|
||||
return smtp.SendMail(addr, auth, config.From, []string{to}, []byte(msg))
|
||||
}
|
||||
|
||||
// sendMailTLS 使用TLS发送邮件
|
||||
func (s *EmailService) sendMailTLS(addr string, auth smtp.Auth, from, to string, msg []byte, host string) error {
|
||||
tlsConfig := &tls.Config{
|
||||
ServerName: host,
|
||||
}
|
||||
|
||||
conn, err := tls.Dial("tcp", addr, tlsConfig)
|
||||
if err != nil {
|
||||
return fmt.Errorf("tls dial: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
client, err := smtp.NewClient(conn, host)
|
||||
if err != nil {
|
||||
return fmt.Errorf("new smtp client: %w", err)
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
if err = client.Auth(auth); err != nil {
|
||||
return fmt.Errorf("smtp auth: %w", err)
|
||||
}
|
||||
|
||||
if err = client.Mail(from); err != nil {
|
||||
return fmt.Errorf("smtp mail: %w", err)
|
||||
}
|
||||
|
||||
if err = client.Rcpt(to); err != nil {
|
||||
return fmt.Errorf("smtp rcpt: %w", err)
|
||||
}
|
||||
|
||||
w, err := client.Data()
|
||||
if err != nil {
|
||||
return fmt.Errorf("smtp data: %w", err)
|
||||
}
|
||||
|
||||
_, err = w.Write(msg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("write msg: %w", err)
|
||||
}
|
||||
|
||||
err = w.Close()
|
||||
if err != nil {
|
||||
return fmt.Errorf("close writer: %w", err)
|
||||
}
|
||||
|
||||
// Email is sent successfully after w.Close(), ignore Quit errors
|
||||
// Some SMTP servers return non-standard responses on QUIT
|
||||
_ = client.Quit()
|
||||
return nil
|
||||
}
|
||||
|
||||
// GenerateVerifyCode 生成6位数字验证码
|
||||
func (s *EmailService) GenerateVerifyCode() (string, error) {
|
||||
const digits = "0123456789"
|
||||
code := make([]byte, 6)
|
||||
for i := range code {
|
||||
num, err := rand.Int(rand.Reader, big.NewInt(int64(len(digits))))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
code[i] = digits[num.Int64()]
|
||||
}
|
||||
return string(code), nil
|
||||
}
|
||||
|
||||
// SendVerifyCode 发送验证码邮件
|
||||
func (s *EmailService) SendVerifyCode(ctx context.Context, email, siteName string) error {
|
||||
key := verifyCodeKeyPrefix + email
|
||||
|
||||
// 检查是否在冷却期内
|
||||
existing, err := s.getVerifyCodeData(ctx, key)
|
||||
if err == nil && existing != nil {
|
||||
if time.Since(existing.CreatedAt) < verifyCodeCooldown {
|
||||
return ErrVerifyCodeTooFrequent
|
||||
}
|
||||
}
|
||||
|
||||
// 生成验证码
|
||||
code, err := s.GenerateVerifyCode()
|
||||
if err != nil {
|
||||
return fmt.Errorf("generate code: %w", err)
|
||||
}
|
||||
|
||||
// 保存验证码到 Redis
|
||||
data := &verifyCodeData{
|
||||
Code: code,
|
||||
Attempts: 0,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
if err := s.setVerifyCodeData(ctx, key, data); err != nil {
|
||||
return fmt.Errorf("save verify code: %w", err)
|
||||
}
|
||||
|
||||
// 构建邮件内容
|
||||
subject := fmt.Sprintf("[%s] Email Verification Code", siteName)
|
||||
body := s.buildVerifyCodeEmailBody(code, siteName)
|
||||
|
||||
// 发送邮件
|
||||
if err := s.SendEmail(ctx, email, subject, body); err != nil {
|
||||
return fmt.Errorf("send email: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// VerifyCode 验证验证码
|
||||
func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error {
|
||||
key := verifyCodeKeyPrefix + email
|
||||
|
||||
data, err := s.getVerifyCodeData(ctx, key)
|
||||
if err != nil || data == nil {
|
||||
return ErrInvalidVerifyCode
|
||||
}
|
||||
|
||||
// 检查是否已达到最大尝试次数
|
||||
if data.Attempts >= maxVerifyCodeAttempts {
|
||||
return ErrVerifyCodeMaxAttempts
|
||||
}
|
||||
|
||||
// 验证码不匹配
|
||||
if data.Code != code {
|
||||
data.Attempts++
|
||||
_ = s.setVerifyCodeData(ctx, key, data)
|
||||
if data.Attempts >= maxVerifyCodeAttempts {
|
||||
return ErrVerifyCodeMaxAttempts
|
||||
}
|
||||
return ErrInvalidVerifyCode
|
||||
}
|
||||
|
||||
// 验证成功,删除验证码
|
||||
s.rdb.Del(ctx, key)
|
||||
return nil
|
||||
}
|
||||
|
||||
// getVerifyCodeData 从 Redis 获取验证码数据
|
||||
func (s *EmailService) getVerifyCodeData(ctx context.Context, key string) (*verifyCodeData, error) {
|
||||
val, err := s.rdb.Get(ctx, key).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var data verifyCodeData
|
||||
if err := json.Unmarshal([]byte(val), &data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &data, nil
|
||||
}
|
||||
|
||||
// setVerifyCodeData 保存验证码数据到 Redis
|
||||
func (s *EmailService) setVerifyCodeData(ctx context.Context, key string, data *verifyCodeData) error {
|
||||
val, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return s.rdb.Set(ctx, key, val, verifyCodeTTL).Err()
|
||||
}
|
||||
|
||||
// buildVerifyCodeEmailBody 构建验证码邮件HTML内容
|
||||
func (s *EmailService) buildVerifyCodeEmailBody(code, siteName string) string {
|
||||
return fmt.Sprintf(`
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<style>
|
||||
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, sans-serif; background-color: #f5f5f5; margin: 0; padding: 20px; }
|
||||
.container { max-width: 600px; margin: 0 auto; background-color: #ffffff; border-radius: 8px; overflow: hidden; box-shadow: 0 2px 8px rgba(0,0,0,0.1); }
|
||||
.header { background: linear-gradient(135deg, #667eea 0%%, #764ba2 100%%); color: white; padding: 30px; text-align: center; }
|
||||
.header h1 { margin: 0; font-size: 24px; }
|
||||
.content { padding: 40px 30px; text-align: center; }
|
||||
.code { font-size: 36px; font-weight: bold; letter-spacing: 8px; color: #333; background-color: #f8f9fa; padding: 20px 30px; border-radius: 8px; display: inline-block; margin: 20px 0; font-family: monospace; }
|
||||
.info { color: #666; font-size: 14px; line-height: 1.6; margin-top: 20px; }
|
||||
.footer { background-color: #f8f9fa; padding: 20px; text-align: center; color: #999; font-size: 12px; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<div class="header">
|
||||
<h1>%s</h1>
|
||||
</div>
|
||||
<div class="content">
|
||||
<p style="font-size: 18px; color: #333;">Your verification code is:</p>
|
||||
<div class="code">%s</div>
|
||||
<div class="info">
|
||||
<p>This code will expire in <strong>15 minutes</strong>.</p>
|
||||
<p>If you did not request this code, please ignore this email.</p>
|
||||
</div>
|
||||
</div>
|
||||
<div class="footer">
|
||||
<p>This is an automated message, please do not reply.</p>
|
||||
</div>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
`, siteName, code)
|
||||
}
|
||||
|
||||
// TestSmtpConnectionWithConfig 使用指定配置测试SMTP连接
|
||||
func (s *EmailService) TestSmtpConnectionWithConfig(config *SmtpConfig) error {
|
||||
addr := fmt.Sprintf("%s:%d", config.Host, config.Port)
|
||||
|
||||
if config.UseTLS {
|
||||
tlsConfig := &tls.Config{ServerName: config.Host}
|
||||
conn, err := tls.Dial("tcp", addr, tlsConfig)
|
||||
if err != nil {
|
||||
return fmt.Errorf("tls connection failed: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
client, err := smtp.NewClient(conn, config.Host)
|
||||
if err != nil {
|
||||
return fmt.Errorf("smtp client creation failed: %w", err)
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
auth := smtp.PlainAuth("", config.Username, config.Password, config.Host)
|
||||
if err = client.Auth(auth); err != nil {
|
||||
return fmt.Errorf("smtp authentication failed: %w", err)
|
||||
}
|
||||
|
||||
return client.Quit()
|
||||
}
|
||||
|
||||
// 非TLS连接测试
|
||||
client, err := smtp.Dial(addr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("smtp connection failed: %w", err)
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
auth := smtp.PlainAuth("", config.Username, config.Password, config.Host)
|
||||
if err = client.Auth(auth); err != nil {
|
||||
return fmt.Errorf("smtp authentication failed: %w", err)
|
||||
}
|
||||
|
||||
return client.Quit()
|
||||
}
|
||||
1022
backend/internal/service/gateway_service.go
Normal file
1022
backend/internal/service/gateway_service.go
Normal file
File diff suppressed because it is too large
Load Diff
194
backend/internal/service/group_service.go
Normal file
194
backend/internal/service/group_service.go
Normal file
@@ -0,0 +1,194 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/repository"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrGroupNotFound = errors.New("group not found")
|
||||
ErrGroupExists = errors.New("group name already exists")
|
||||
)
|
||||
|
||||
// CreateGroupRequest 创建分组请求
|
||||
type CreateGroupRequest struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
RateMultiplier float64 `json:"rate_multiplier"`
|
||||
IsExclusive bool `json:"is_exclusive"`
|
||||
}
|
||||
|
||||
// UpdateGroupRequest 更新分组请求
|
||||
type UpdateGroupRequest struct {
|
||||
Name *string `json:"name"`
|
||||
Description *string `json:"description"`
|
||||
RateMultiplier *float64 `json:"rate_multiplier"`
|
||||
IsExclusive *bool `json:"is_exclusive"`
|
||||
Status *string `json:"status"`
|
||||
}
|
||||
|
||||
// GroupService 分组管理服务
|
||||
type GroupService struct {
|
||||
groupRepo *repository.GroupRepository
|
||||
}
|
||||
|
||||
// NewGroupService 创建分组服务实例
|
||||
func NewGroupService(groupRepo *repository.GroupRepository) *GroupService {
|
||||
return &GroupService{
|
||||
groupRepo: groupRepo,
|
||||
}
|
||||
}
|
||||
|
||||
// Create 创建分组
|
||||
func (s *GroupService) Create(ctx context.Context, req CreateGroupRequest) (*model.Group, error) {
|
||||
// 检查名称是否已存在
|
||||
exists, err := s.groupRepo.ExistsByName(ctx, req.Name)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("check group exists: %w", err)
|
||||
}
|
||||
if exists {
|
||||
return nil, ErrGroupExists
|
||||
}
|
||||
|
||||
// 创建分组
|
||||
group := &model.Group{
|
||||
Name: req.Name,
|
||||
Description: req.Description,
|
||||
RateMultiplier: req.RateMultiplier,
|
||||
IsExclusive: req.IsExclusive,
|
||||
Status: model.StatusActive,
|
||||
}
|
||||
|
||||
if err := s.groupRepo.Create(ctx, group); err != nil {
|
||||
return nil, fmt.Errorf("create group: %w", err)
|
||||
}
|
||||
|
||||
return group, nil
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取分组
|
||||
func (s *GroupService) GetByID(ctx context.Context, id int64) (*model.Group, error) {
|
||||
group, err := s.groupRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrGroupNotFound
|
||||
}
|
||||
return nil, fmt.Errorf("get group: %w", err)
|
||||
}
|
||||
return group, nil
|
||||
}
|
||||
|
||||
// List 获取分组列表
|
||||
func (s *GroupService) List(ctx context.Context, params repository.PaginationParams) ([]model.Group, *repository.PaginationResult, error) {
|
||||
groups, pagination, err := s.groupRepo.List(ctx, params)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("list groups: %w", err)
|
||||
}
|
||||
return groups, pagination, nil
|
||||
}
|
||||
|
||||
// ListActive 获取活跃分组列表
|
||||
func (s *GroupService) ListActive(ctx context.Context) ([]model.Group, error) {
|
||||
groups, err := s.groupRepo.ListActive(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list active groups: %w", err)
|
||||
}
|
||||
return groups, nil
|
||||
}
|
||||
|
||||
// Update 更新分组
|
||||
func (s *GroupService) Update(ctx context.Context, id int64, req UpdateGroupRequest) (*model.Group, error) {
|
||||
group, err := s.groupRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrGroupNotFound
|
||||
}
|
||||
return nil, fmt.Errorf("get group: %w", err)
|
||||
}
|
||||
|
||||
// 更新字段
|
||||
if req.Name != nil && *req.Name != group.Name {
|
||||
// 检查新名称是否已存在
|
||||
exists, err := s.groupRepo.ExistsByName(ctx, *req.Name)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("check group exists: %w", err)
|
||||
}
|
||||
if exists {
|
||||
return nil, ErrGroupExists
|
||||
}
|
||||
group.Name = *req.Name
|
||||
}
|
||||
|
||||
if req.Description != nil {
|
||||
group.Description = *req.Description
|
||||
}
|
||||
|
||||
if req.RateMultiplier != nil {
|
||||
group.RateMultiplier = *req.RateMultiplier
|
||||
}
|
||||
|
||||
if req.IsExclusive != nil {
|
||||
group.IsExclusive = *req.IsExclusive
|
||||
}
|
||||
|
||||
if req.Status != nil {
|
||||
group.Status = *req.Status
|
||||
}
|
||||
|
||||
if err := s.groupRepo.Update(ctx, group); err != nil {
|
||||
return nil, fmt.Errorf("update group: %w", err)
|
||||
}
|
||||
|
||||
return group, nil
|
||||
}
|
||||
|
||||
// Delete 删除分组
|
||||
func (s *GroupService) Delete(ctx context.Context, id int64) error {
|
||||
// 检查分组是否存在
|
||||
_, err := s.groupRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return ErrGroupNotFound
|
||||
}
|
||||
return fmt.Errorf("get group: %w", err)
|
||||
}
|
||||
|
||||
if err := s.groupRepo.Delete(ctx, id); err != nil {
|
||||
return fmt.Errorf("delete group: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetStats 获取分组统计信息
|
||||
func (s *GroupService) GetStats(ctx context.Context, id int64) (map[string]interface{}, error) {
|
||||
group, err := s.groupRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrGroupNotFound
|
||||
}
|
||||
return nil, fmt.Errorf("get group: %w", err)
|
||||
}
|
||||
|
||||
// 获取账号数量
|
||||
accountCount, err := s.groupRepo.GetAccountCount(ctx, id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get account count: %w", err)
|
||||
}
|
||||
|
||||
stats := map[string]interface{}{
|
||||
"id": group.ID,
|
||||
"name": group.Name,
|
||||
"rate_multiplier": group.RateMultiplier,
|
||||
"is_exclusive": group.IsExclusive,
|
||||
"status": group.Status,
|
||||
"account_count": accountCount,
|
||||
}
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
282
backend/internal/service/identity_service.go
Normal file
282
backend/internal/service/identity_service.go
Normal file
@@ -0,0 +1,282 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const (
|
||||
// Redis key prefix
|
||||
identityFingerprintKey = "identity:fingerprint:"
|
||||
)
|
||||
|
||||
// 预编译正则表达式(避免每次调用重新编译)
|
||||
var (
|
||||
// 匹配 user_id 格式: user_{64位hex}_account__session_{uuid}
|
||||
userIDRegex = regexp.MustCompile(`^user_[a-f0-9]{64}_account__session_([a-f0-9-]{36})$`)
|
||||
// 匹配 User-Agent 版本号: xxx/x.y.z
|
||||
userAgentVersionRegex = regexp.MustCompile(`/(\d+)\.(\d+)\.(\d+)`)
|
||||
)
|
||||
|
||||
// Fingerprint 存储的指纹数据结构
|
||||
type Fingerprint struct {
|
||||
ClientID string `json:"client_id"` // 64位hex客户端ID(首次随机生成)
|
||||
UserAgent string `json:"user_agent"` // User-Agent
|
||||
StainlessLang string `json:"x_stainless_lang"` // x-stainless-lang
|
||||
StainlessPackageVersion string `json:"x_stainless_package_version"` // x-stainless-package-version
|
||||
StainlessOS string `json:"x_stainless_os"` // x-stainless-os
|
||||
StainlessArch string `json:"x_stainless_arch"` // x-stainless-arch
|
||||
StainlessRuntime string `json:"x_stainless_runtime"` // x-stainless-runtime
|
||||
StainlessRuntimeVersion string `json:"x_stainless_runtime_version"` // x-stainless-runtime-version
|
||||
}
|
||||
|
||||
// 默认指纹值(当客户端未提供时使用)
|
||||
var defaultFingerprint = Fingerprint{
|
||||
UserAgent: "claude-cli/2.0.62 (external, cli)",
|
||||
StainlessLang: "js",
|
||||
StainlessPackageVersion: "0.52.0",
|
||||
StainlessOS: "Linux",
|
||||
StainlessArch: "x64",
|
||||
StainlessRuntime: "node",
|
||||
StainlessRuntimeVersion: "v22.14.0",
|
||||
}
|
||||
|
||||
// IdentityService 管理OAuth账号的请求身份指纹
|
||||
type IdentityService struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
// NewIdentityService 创建新的IdentityService
|
||||
func NewIdentityService(rdb *redis.Client) *IdentityService {
|
||||
return &IdentityService{rdb: rdb}
|
||||
}
|
||||
|
||||
// GetOrCreateFingerprint 获取或创建账号的指纹
|
||||
// 如果缓存存在,检测user-agent版本,新版本则更新
|
||||
// 如果缓存不存在,生成随机ClientID并从请求头创建指纹,然后缓存
|
||||
func (s *IdentityService) GetOrCreateFingerprint(ctx context.Context, accountID int64, headers http.Header) (*Fingerprint, error) {
|
||||
key := identityFingerprintKey + strconv.FormatInt(accountID, 10)
|
||||
|
||||
// 尝试从Redis获取缓存的指纹
|
||||
data, err := s.rdb.Get(ctx, key).Bytes()
|
||||
if err == nil && len(data) > 0 {
|
||||
// 缓存存在,解析指纹
|
||||
var cached Fingerprint
|
||||
if err := json.Unmarshal(data, &cached); err == nil {
|
||||
// 检查客户端的user-agent是否是更新版本
|
||||
clientUA := headers.Get("User-Agent")
|
||||
if clientUA != "" && isNewerVersion(clientUA, cached.UserAgent) {
|
||||
// 更新user-agent
|
||||
cached.UserAgent = clientUA
|
||||
// 保存更新后的指纹
|
||||
if newData, err := json.Marshal(cached); err == nil {
|
||||
s.rdb.Set(ctx, key, newData, 0) // 永不过期
|
||||
}
|
||||
log.Printf("Updated fingerprint user-agent for account %d: %s", accountID, clientUA)
|
||||
}
|
||||
return &cached, nil
|
||||
}
|
||||
}
|
||||
|
||||
// 缓存不存在或解析失败,创建新指纹
|
||||
fp := s.createFingerprintFromHeaders(headers)
|
||||
|
||||
// 生成随机ClientID
|
||||
fp.ClientID = generateClientID()
|
||||
|
||||
// 保存到Redis(永不过期)
|
||||
if data, err := json.Marshal(fp); err == nil {
|
||||
if err := s.rdb.Set(ctx, key, data, 0).Err(); err != nil {
|
||||
log.Printf("Warning: failed to cache fingerprint for account %d: %v", accountID, err)
|
||||
}
|
||||
}
|
||||
|
||||
log.Printf("Created new fingerprint for account %d with client_id: %s", accountID, fp.ClientID)
|
||||
return fp, nil
|
||||
}
|
||||
|
||||
// createFingerprintFromHeaders 从请求头创建指纹
|
||||
func (s *IdentityService) createFingerprintFromHeaders(headers http.Header) *Fingerprint {
|
||||
fp := &Fingerprint{}
|
||||
|
||||
// 获取User-Agent
|
||||
if ua := headers.Get("User-Agent"); ua != "" {
|
||||
fp.UserAgent = ua
|
||||
} else {
|
||||
fp.UserAgent = defaultFingerprint.UserAgent
|
||||
}
|
||||
|
||||
// 获取x-stainless-*头,如果没有则使用默认值
|
||||
fp.StainlessLang = getHeaderOrDefault(headers, "X-Stainless-Lang", defaultFingerprint.StainlessLang)
|
||||
fp.StainlessPackageVersion = getHeaderOrDefault(headers, "X-Stainless-Package-Version", defaultFingerprint.StainlessPackageVersion)
|
||||
fp.StainlessOS = getHeaderOrDefault(headers, "X-Stainless-OS", defaultFingerprint.StainlessOS)
|
||||
fp.StainlessArch = getHeaderOrDefault(headers, "X-Stainless-Arch", defaultFingerprint.StainlessArch)
|
||||
fp.StainlessRuntime = getHeaderOrDefault(headers, "X-Stainless-Runtime", defaultFingerprint.StainlessRuntime)
|
||||
fp.StainlessRuntimeVersion = getHeaderOrDefault(headers, "X-Stainless-Runtime-Version", defaultFingerprint.StainlessRuntimeVersion)
|
||||
|
||||
return fp
|
||||
}
|
||||
|
||||
// getHeaderOrDefault 获取header值,如果不存在则返回默认值
|
||||
func getHeaderOrDefault(headers http.Header, key, defaultValue string) string {
|
||||
if v := headers.Get(key); v != "" {
|
||||
return v
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
// ApplyFingerprint 将指纹应用到请求头(覆盖原有的x-stainless-*头)
|
||||
func (s *IdentityService) ApplyFingerprint(req *http.Request, fp *Fingerprint) {
|
||||
if fp == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 设置User-Agent
|
||||
if fp.UserAgent != "" {
|
||||
req.Header.Set("User-Agent", fp.UserAgent)
|
||||
}
|
||||
|
||||
// 设置x-stainless-*头(使用正确的大小写)
|
||||
if fp.StainlessLang != "" {
|
||||
req.Header.Set("X-Stainless-Lang", fp.StainlessLang)
|
||||
}
|
||||
if fp.StainlessPackageVersion != "" {
|
||||
req.Header.Set("X-Stainless-Package-Version", fp.StainlessPackageVersion)
|
||||
}
|
||||
if fp.StainlessOS != "" {
|
||||
req.Header.Set("X-Stainless-OS", fp.StainlessOS)
|
||||
}
|
||||
if fp.StainlessArch != "" {
|
||||
req.Header.Set("X-Stainless-Arch", fp.StainlessArch)
|
||||
}
|
||||
if fp.StainlessRuntime != "" {
|
||||
req.Header.Set("X-Stainless-Runtime", fp.StainlessRuntime)
|
||||
}
|
||||
if fp.StainlessRuntimeVersion != "" {
|
||||
req.Header.Set("X-Stainless-Runtime-Version", fp.StainlessRuntimeVersion)
|
||||
}
|
||||
}
|
||||
|
||||
// RewriteUserID 重写body中的metadata.user_id
|
||||
// 输入格式:user_{clientId}_account__session_{sessionUUID}
|
||||
// 输出格式:user_{cachedClientID}_account_{accountUUID}_session_{newHash}
|
||||
func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUID, cachedClientID string) ([]byte, error) {
|
||||
if len(body) == 0 || accountUUID == "" || cachedClientID == "" {
|
||||
return body, nil
|
||||
}
|
||||
|
||||
// 解析JSON
|
||||
var reqMap map[string]interface{}
|
||||
if err := json.Unmarshal(body, &reqMap); err != nil {
|
||||
return body, nil
|
||||
}
|
||||
|
||||
metadata, ok := reqMap["metadata"].(map[string]interface{})
|
||||
if !ok {
|
||||
return body, nil
|
||||
}
|
||||
|
||||
userID, ok := metadata["user_id"].(string)
|
||||
if !ok || userID == "" {
|
||||
return body, nil
|
||||
}
|
||||
|
||||
// 匹配格式: user_{64位hex}_account__session_{uuid}
|
||||
matches := userIDRegex.FindStringSubmatch(userID)
|
||||
if matches == nil {
|
||||
return body, nil
|
||||
}
|
||||
|
||||
sessionTail := matches[1] // 原始session UUID
|
||||
|
||||
// 生成新的session hash: SHA256(accountID::sessionTail) -> UUID格式
|
||||
seed := fmt.Sprintf("%d::%s", accountID, sessionTail)
|
||||
newSessionHash := generateUUIDFromSeed(seed)
|
||||
|
||||
// 构建新的user_id
|
||||
// 格式: user_{cachedClientID}_account_{account_uuid}_session_{newSessionHash}
|
||||
newUserID := fmt.Sprintf("user_%s_account_%s_session_%s", cachedClientID, accountUUID, newSessionHash)
|
||||
|
||||
metadata["user_id"] = newUserID
|
||||
reqMap["metadata"] = metadata
|
||||
|
||||
return json.Marshal(reqMap)
|
||||
}
|
||||
|
||||
// generateClientID 生成64位十六进制客户端ID(32字节随机数)
|
||||
func generateClientID() string {
|
||||
b := make([]byte, 32)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
// 极罕见的情况,使用时间戳+固定值作为fallback
|
||||
log.Printf("Warning: crypto/rand.Read failed: %v, using fallback", err)
|
||||
// 使用SHA256(当前纳秒时间)作为fallback
|
||||
h := sha256.Sum256([]byte(fmt.Sprintf("%d", time.Now().UnixNano())))
|
||||
return hex.EncodeToString(h[:])
|
||||
}
|
||||
return hex.EncodeToString(b)
|
||||
}
|
||||
|
||||
// generateUUIDFromSeed 从种子生成确定性UUID v4格式字符串
|
||||
func generateUUIDFromSeed(seed string) string {
|
||||
hash := sha256.Sum256([]byte(seed))
|
||||
bytes := hash[:16]
|
||||
|
||||
// 设置UUID v4版本和变体位
|
||||
bytes[6] = (bytes[6] & 0x0f) | 0x40
|
||||
bytes[8] = (bytes[8] & 0x3f) | 0x80
|
||||
|
||||
return fmt.Sprintf("%x-%x-%x-%x-%x",
|
||||
bytes[0:4], bytes[4:6], bytes[6:8], bytes[8:10], bytes[10:16])
|
||||
}
|
||||
|
||||
// parseUserAgentVersion 解析user-agent版本号
|
||||
// 例如:claude-cli/2.0.62 -> (2, 0, 62)
|
||||
func parseUserAgentVersion(ua string) (major, minor, patch int, ok bool) {
|
||||
// 匹配 xxx/x.y.z 格式
|
||||
matches := userAgentVersionRegex.FindStringSubmatch(ua)
|
||||
if len(matches) != 4 {
|
||||
return 0, 0, 0, false
|
||||
}
|
||||
major, _ = strconv.Atoi(matches[1])
|
||||
minor, _ = strconv.Atoi(matches[2])
|
||||
patch, _ = strconv.Atoi(matches[3])
|
||||
return major, minor, patch, true
|
||||
}
|
||||
|
||||
// isNewerVersion 比较版本号,判断newUA是否比cachedUA更新
|
||||
func isNewerVersion(newUA, cachedUA string) bool {
|
||||
newMajor, newMinor, newPatch, newOk := parseUserAgentVersion(newUA)
|
||||
cachedMajor, cachedMinor, cachedPatch, cachedOk := parseUserAgentVersion(cachedUA)
|
||||
|
||||
if !newOk || !cachedOk {
|
||||
return false
|
||||
}
|
||||
|
||||
// 比较版本号
|
||||
if newMajor > cachedMajor {
|
||||
return true
|
||||
}
|
||||
if newMajor < cachedMajor {
|
||||
return false
|
||||
}
|
||||
|
||||
if newMinor > cachedMinor {
|
||||
return true
|
||||
}
|
||||
if newMinor < cachedMinor {
|
||||
return false
|
||||
}
|
||||
|
||||
return newPatch > cachedPatch
|
||||
}
|
||||
471
backend/internal/service/oauth_service.go
Normal file
471
backend/internal/service/oauth_service.go
Normal file
@@ -0,0 +1,471 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/oauth"
|
||||
"sub2api/internal/repository"
|
||||
|
||||
"github.com/imroc/req/v3"
|
||||
)
|
||||
|
||||
// OAuthService handles OAuth authentication flows
|
||||
type OAuthService struct {
|
||||
sessionStore *oauth.SessionStore
|
||||
proxyRepo *repository.ProxyRepository
|
||||
}
|
||||
|
||||
// NewOAuthService creates a new OAuth service
|
||||
func NewOAuthService(proxyRepo *repository.ProxyRepository) *OAuthService {
|
||||
return &OAuthService{
|
||||
sessionStore: oauth.NewSessionStore(),
|
||||
proxyRepo: proxyRepo,
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateAuthURLResult contains the authorization URL and session info
|
||||
type GenerateAuthURLResult struct {
|
||||
AuthURL string `json:"auth_url"`
|
||||
SessionID string `json:"session_id"`
|
||||
}
|
||||
|
||||
// GenerateAuthURL generates an OAuth authorization URL with full scope
|
||||
func (s *OAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64) (*GenerateAuthURLResult, error) {
|
||||
scope := fmt.Sprintf("%s %s", oauth.ScopeProfile, oauth.ScopeInference)
|
||||
return s.generateAuthURLWithScope(ctx, scope, proxyID)
|
||||
}
|
||||
|
||||
// GenerateSetupTokenURL generates an OAuth authorization URL for setup token (inference only)
|
||||
func (s *OAuthService) GenerateSetupTokenURL(ctx context.Context, proxyID *int64) (*GenerateAuthURLResult, error) {
|
||||
scope := oauth.ScopeInference
|
||||
return s.generateAuthURLWithScope(ctx, scope, proxyID)
|
||||
}
|
||||
|
||||
func (s *OAuthService) generateAuthURLWithScope(ctx context.Context, scope string, proxyID *int64) (*GenerateAuthURLResult, error) {
|
||||
// Generate PKCE values
|
||||
state, err := oauth.GenerateState()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate state: %w", err)
|
||||
}
|
||||
|
||||
codeVerifier, err := oauth.GenerateCodeVerifier()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate code verifier: %w", err)
|
||||
}
|
||||
|
||||
codeChallenge := oauth.GenerateCodeChallenge(codeVerifier)
|
||||
|
||||
// Generate session ID
|
||||
sessionID, err := oauth.GenerateSessionID()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate session ID: %w", err)
|
||||
}
|
||||
|
||||
// Get proxy URL if specified
|
||||
var proxyURL string
|
||||
if proxyID != nil {
|
||||
proxy, err := s.proxyRepo.GetByID(ctx, *proxyID)
|
||||
if err == nil && proxy != nil {
|
||||
proxyURL = proxy.URL()
|
||||
}
|
||||
}
|
||||
|
||||
// Store session
|
||||
session := &oauth.OAuthSession{
|
||||
State: state,
|
||||
CodeVerifier: codeVerifier,
|
||||
Scope: scope,
|
||||
ProxyURL: proxyURL,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
s.sessionStore.Set(sessionID, session)
|
||||
|
||||
// Build authorization URL
|
||||
authURL := oauth.BuildAuthorizationURL(state, codeChallenge, scope)
|
||||
|
||||
return &GenerateAuthURLResult{
|
||||
AuthURL: authURL,
|
||||
SessionID: sessionID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ExchangeCodeInput represents the input for code exchange
|
||||
type ExchangeCodeInput struct {
|
||||
SessionID string
|
||||
Code string
|
||||
ProxyID *int64
|
||||
}
|
||||
|
||||
// TokenInfo represents the token information stored in credentials
|
||||
type TokenInfo struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
ExpiresAt int64 `json:"expires_at"`
|
||||
RefreshToken string `json:"refresh_token,omitempty"`
|
||||
Scope string `json:"scope,omitempty"`
|
||||
OrgUUID string `json:"org_uuid,omitempty"`
|
||||
AccountUUID string `json:"account_uuid,omitempty"`
|
||||
}
|
||||
|
||||
// ExchangeCode exchanges authorization code for tokens
|
||||
func (s *OAuthService) ExchangeCode(ctx context.Context, input *ExchangeCodeInput) (*TokenInfo, error) {
|
||||
// Get session
|
||||
session, ok := s.sessionStore.Get(input.SessionID)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("session not found or expired")
|
||||
}
|
||||
|
||||
// Get proxy URL
|
||||
proxyURL := session.ProxyURL
|
||||
if input.ProxyID != nil {
|
||||
proxy, err := s.proxyRepo.GetByID(ctx, *input.ProxyID)
|
||||
if err == nil && proxy != nil {
|
||||
proxyURL = proxy.URL()
|
||||
}
|
||||
}
|
||||
|
||||
// Exchange code for token
|
||||
tokenInfo, err := s.exchangeCodeForToken(ctx, input.Code, session.CodeVerifier, session.State, proxyURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Delete session after successful exchange
|
||||
s.sessionStore.Delete(input.SessionID)
|
||||
|
||||
return tokenInfo, nil
|
||||
}
|
||||
|
||||
// CookieAuthInput represents the input for cookie-based authentication
|
||||
type CookieAuthInput struct {
|
||||
SessionKey string
|
||||
ProxyID *int64
|
||||
Scope string // "full" or "inference"
|
||||
}
|
||||
|
||||
// CookieAuth performs OAuth using sessionKey (cookie-based auto-auth)
|
||||
func (s *OAuthService) CookieAuth(ctx context.Context, input *CookieAuthInput) (*TokenInfo, error) {
|
||||
// Get proxy URL if specified
|
||||
var proxyURL string
|
||||
if input.ProxyID != nil {
|
||||
proxy, err := s.proxyRepo.GetByID(ctx, *input.ProxyID)
|
||||
if err == nil && proxy != nil {
|
||||
proxyURL = proxy.URL()
|
||||
}
|
||||
}
|
||||
|
||||
// Determine scope
|
||||
scope := fmt.Sprintf("%s %s", oauth.ScopeProfile, oauth.ScopeInference)
|
||||
if input.Scope == "inference" {
|
||||
scope = oauth.ScopeInference
|
||||
}
|
||||
|
||||
// Step 1: Get organization info using sessionKey
|
||||
orgUUID, err := s.getOrganizationUUID(ctx, input.SessionKey, proxyURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get organization info: %w", err)
|
||||
}
|
||||
|
||||
// Step 2: Generate PKCE values
|
||||
codeVerifier, err := oauth.GenerateCodeVerifier()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate code verifier: %w", err)
|
||||
}
|
||||
codeChallenge := oauth.GenerateCodeChallenge(codeVerifier)
|
||||
|
||||
state, err := oauth.GenerateState()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate state: %w", err)
|
||||
}
|
||||
|
||||
// Step 3: Get authorization code using cookie
|
||||
authCode, err := s.getAuthorizationCode(ctx, input.SessionKey, orgUUID, scope, codeChallenge, state, proxyURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get authorization code: %w", err)
|
||||
}
|
||||
|
||||
// Step 4: Exchange code for token
|
||||
tokenInfo, err := s.exchangeCodeForToken(ctx, authCode, codeVerifier, state, proxyURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to exchange code: %w", err)
|
||||
}
|
||||
|
||||
// Ensure org_uuid is set (from step 1 if not from token response)
|
||||
if tokenInfo.OrgUUID == "" && orgUUID != "" {
|
||||
tokenInfo.OrgUUID = orgUUID
|
||||
log.Printf("[OAuth] Set org_uuid from cookie auth: %s", orgUUID)
|
||||
}
|
||||
|
||||
return tokenInfo, nil
|
||||
}
|
||||
|
||||
// getOrganizationUUID gets the organization UUID from claude.ai using sessionKey
|
||||
func (s *OAuthService) getOrganizationUUID(ctx context.Context, sessionKey, proxyURL string) (string, error) {
|
||||
client := s.createReqClient(proxyURL)
|
||||
|
||||
var orgs []struct {
|
||||
UUID string `json:"uuid"`
|
||||
}
|
||||
|
||||
targetURL := "https://claude.ai/api/organizations"
|
||||
log.Printf("[OAuth] Step 1: Getting organization UUID from %s", targetURL)
|
||||
|
||||
resp, err := client.R().
|
||||
SetContext(ctx).
|
||||
SetCookies(&http.Cookie{
|
||||
Name: "sessionKey",
|
||||
Value: sessionKey,
|
||||
}).
|
||||
SetSuccessResult(&orgs).
|
||||
Get(targetURL)
|
||||
|
||||
if err != nil {
|
||||
log.Printf("[OAuth] Step 1 FAILED - Request error: %v", err)
|
||||
return "", fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("[OAuth] Step 1 Response - Status: %d, Body: %s", resp.StatusCode, resp.String())
|
||||
|
||||
if !resp.IsSuccessState() {
|
||||
return "", fmt.Errorf("failed to get organizations: status %d, body: %s", resp.StatusCode, resp.String())
|
||||
}
|
||||
|
||||
if len(orgs) == 0 {
|
||||
return "", fmt.Errorf("no organizations found")
|
||||
}
|
||||
|
||||
log.Printf("[OAuth] Step 1 SUCCESS - Got org UUID: %s", orgs[0].UUID)
|
||||
return orgs[0].UUID, nil
|
||||
}
|
||||
|
||||
// getAuthorizationCode gets the authorization code using sessionKey
|
||||
func (s *OAuthService) getAuthorizationCode(ctx context.Context, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL string) (string, error) {
|
||||
client := s.createReqClient(proxyURL)
|
||||
|
||||
authURL := fmt.Sprintf("https://claude.ai/v1/oauth/%s/authorize", orgUUID)
|
||||
|
||||
// Build request body - must include organization_uuid as per CRS
|
||||
reqBody := map[string]interface{}{
|
||||
"response_type": "code",
|
||||
"client_id": oauth.ClientID,
|
||||
"organization_uuid": orgUUID, // Required field!
|
||||
"redirect_uri": oauth.RedirectURI,
|
||||
"scope": scope,
|
||||
"state": state,
|
||||
"code_challenge": codeChallenge,
|
||||
"code_challenge_method": "S256",
|
||||
}
|
||||
|
||||
reqBodyJSON, _ := json.Marshal(reqBody)
|
||||
log.Printf("[OAuth] Step 2: Getting authorization code from %s", authURL)
|
||||
log.Printf("[OAuth] Step 2 Request Body: %s", string(reqBodyJSON))
|
||||
|
||||
// Response contains redirect_uri with code, not direct code field
|
||||
var result struct {
|
||||
RedirectURI string `json:"redirect_uri"`
|
||||
}
|
||||
|
||||
resp, err := client.R().
|
||||
SetContext(ctx).
|
||||
SetCookies(&http.Cookie{
|
||||
Name: "sessionKey",
|
||||
Value: sessionKey,
|
||||
}).
|
||||
SetHeader("Accept", "application/json").
|
||||
SetHeader("Accept-Language", "en-US,en;q=0.9").
|
||||
SetHeader("Cache-Control", "no-cache").
|
||||
SetHeader("Origin", "https://claude.ai").
|
||||
SetHeader("Referer", "https://claude.ai/new").
|
||||
SetHeader("Content-Type", "application/json").
|
||||
SetBody(reqBody).
|
||||
SetSuccessResult(&result).
|
||||
Post(authURL)
|
||||
|
||||
if err != nil {
|
||||
log.Printf("[OAuth] Step 2 FAILED - Request error: %v", err)
|
||||
return "", fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("[OAuth] Step 2 Response - Status: %d, Body: %s", resp.StatusCode, resp.String())
|
||||
|
||||
if !resp.IsSuccessState() {
|
||||
return "", fmt.Errorf("failed to get authorization code: status %d, body: %s", resp.StatusCode, resp.String())
|
||||
}
|
||||
|
||||
if result.RedirectURI == "" {
|
||||
return "", fmt.Errorf("no redirect_uri in response")
|
||||
}
|
||||
|
||||
// Parse redirect_uri to extract code and state
|
||||
parsedURL, err := url.Parse(result.RedirectURI)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to parse redirect_uri: %w", err)
|
||||
}
|
||||
|
||||
queryParams := parsedURL.Query()
|
||||
authCode := queryParams.Get("code")
|
||||
responseState := queryParams.Get("state")
|
||||
|
||||
if authCode == "" {
|
||||
return "", fmt.Errorf("no authorization code in redirect_uri")
|
||||
}
|
||||
|
||||
// Combine code with state if present (as CRS does)
|
||||
fullCode := authCode
|
||||
if responseState != "" {
|
||||
fullCode = authCode + "#" + responseState
|
||||
}
|
||||
|
||||
log.Printf("[OAuth] Step 2 SUCCESS - Got authorization code: %s...", authCode[:20])
|
||||
return fullCode, nil
|
||||
}
|
||||
|
||||
// exchangeCodeForToken exchanges authorization code for tokens
|
||||
func (s *OAuthService) exchangeCodeForToken(ctx context.Context, code, codeVerifier, state, proxyURL string) (*TokenInfo, error) {
|
||||
client := s.createReqClient(proxyURL)
|
||||
|
||||
// Parse code#state format if present
|
||||
authCode := code
|
||||
codeState := ""
|
||||
if parts := strings.Split(code, "#"); len(parts) > 1 {
|
||||
authCode = parts[0]
|
||||
codeState = parts[1]
|
||||
}
|
||||
|
||||
// Build JSON body as CRS does (not form data!)
|
||||
reqBody := map[string]interface{}{
|
||||
"code": authCode,
|
||||
"grant_type": "authorization_code",
|
||||
"client_id": oauth.ClientID,
|
||||
"redirect_uri": oauth.RedirectURI,
|
||||
"code_verifier": codeVerifier,
|
||||
}
|
||||
|
||||
// Add state if present
|
||||
if codeState != "" {
|
||||
reqBody["state"] = codeState
|
||||
}
|
||||
|
||||
reqBodyJSON, _ := json.Marshal(reqBody)
|
||||
log.Printf("[OAuth] Step 3: Exchanging code for token at %s", oauth.TokenURL)
|
||||
log.Printf("[OAuth] Step 3 Request Body: %s", string(reqBodyJSON))
|
||||
|
||||
var tokenResp oauth.TokenResponse
|
||||
|
||||
resp, err := client.R().
|
||||
SetContext(ctx).
|
||||
SetHeader("Content-Type", "application/json").
|
||||
SetBody(reqBody).
|
||||
SetSuccessResult(&tokenResp).
|
||||
Post(oauth.TokenURL)
|
||||
|
||||
if err != nil {
|
||||
log.Printf("[OAuth] Step 3 FAILED - Request error: %v", err)
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("[OAuth] Step 3 Response - Status: %d, Body: %s", resp.StatusCode, resp.String())
|
||||
|
||||
if !resp.IsSuccessState() {
|
||||
return nil, fmt.Errorf("token exchange failed: status %d, body: %s", resp.StatusCode, resp.String())
|
||||
}
|
||||
|
||||
log.Printf("[OAuth] Step 3 SUCCESS - Got access token")
|
||||
|
||||
tokenInfo := &TokenInfo{
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
TokenType: tokenResp.TokenType,
|
||||
ExpiresIn: tokenResp.ExpiresIn,
|
||||
ExpiresAt: time.Now().Unix() + tokenResp.ExpiresIn,
|
||||
RefreshToken: tokenResp.RefreshToken,
|
||||
Scope: tokenResp.Scope,
|
||||
}
|
||||
|
||||
// Extract org_uuid and account_uuid from response
|
||||
if tokenResp.Organization != nil && tokenResp.Organization.UUID != "" {
|
||||
tokenInfo.OrgUUID = tokenResp.Organization.UUID
|
||||
log.Printf("[OAuth] Got org_uuid: %s", tokenInfo.OrgUUID)
|
||||
}
|
||||
if tokenResp.Account != nil && tokenResp.Account.UUID != "" {
|
||||
tokenInfo.AccountUUID = tokenResp.Account.UUID
|
||||
log.Printf("[OAuth] Got account_uuid: %s", tokenInfo.AccountUUID)
|
||||
}
|
||||
|
||||
return tokenInfo, nil
|
||||
}
|
||||
|
||||
// RefreshToken refreshes an OAuth token
|
||||
func (s *OAuthService) RefreshToken(ctx context.Context, refreshToken string, proxyURL string) (*TokenInfo, error) {
|
||||
client := s.createReqClient(proxyURL)
|
||||
|
||||
formData := url.Values{}
|
||||
formData.Set("grant_type", "refresh_token")
|
||||
formData.Set("refresh_token", refreshToken)
|
||||
formData.Set("client_id", oauth.ClientID)
|
||||
|
||||
var tokenResp oauth.TokenResponse
|
||||
|
||||
resp, err := client.R().
|
||||
SetContext(ctx).
|
||||
SetFormDataFromValues(formData).
|
||||
SetSuccessResult(&tokenResp).
|
||||
Post(oauth.TokenURL)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
|
||||
if !resp.IsSuccessState() {
|
||||
return nil, fmt.Errorf("token refresh failed: status %d, body: %s", resp.StatusCode, resp.String())
|
||||
}
|
||||
|
||||
return &TokenInfo{
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
TokenType: tokenResp.TokenType,
|
||||
ExpiresIn: tokenResp.ExpiresIn,
|
||||
ExpiresAt: time.Now().Unix() + tokenResp.ExpiresIn,
|
||||
RefreshToken: tokenResp.RefreshToken,
|
||||
Scope: tokenResp.Scope,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// RefreshAccountToken refreshes token for an account
|
||||
func (s *OAuthService) RefreshAccountToken(ctx context.Context, account *model.Account) (*TokenInfo, error) {
|
||||
refreshToken := account.GetCredential("refresh_token")
|
||||
if refreshToken == "" {
|
||||
return nil, fmt.Errorf("no refresh token available")
|
||||
}
|
||||
|
||||
var proxyURL string
|
||||
if account.ProxyID != nil {
|
||||
proxy, err := s.proxyRepo.GetByID(ctx, *account.ProxyID)
|
||||
if err == nil && proxy != nil {
|
||||
proxyURL = proxy.URL()
|
||||
}
|
||||
}
|
||||
|
||||
return s.RefreshToken(ctx, refreshToken, proxyURL)
|
||||
}
|
||||
|
||||
// createReqClient creates a req client with Chrome impersonation and optional proxy
|
||||
func (s *OAuthService) createReqClient(proxyURL string) *req.Client {
|
||||
client := req.C().
|
||||
ImpersonateChrome(). // Impersonate Chrome browser to bypass Cloudflare
|
||||
SetTimeout(60 * time.Second)
|
||||
|
||||
// Set proxy if specified
|
||||
if proxyURL != "" {
|
||||
client.SetProxyURL(proxyURL)
|
||||
}
|
||||
|
||||
return client
|
||||
}
|
||||
572
backend/internal/service/pricing_service.go
Normal file
572
backend/internal/service/pricing_service.go
Normal file
@@ -0,0 +1,572 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/config"
|
||||
)
|
||||
|
||||
// LiteLLMModelPricing LiteLLM价格数据结构
|
||||
// 只保留我们需要的字段,使用指针来处理可能缺失的值
|
||||
type LiteLLMModelPricing struct {
|
||||
InputCostPerToken float64 `json:"input_cost_per_token"`
|
||||
OutputCostPerToken float64 `json:"output_cost_per_token"`
|
||||
CacheCreationInputTokenCost float64 `json:"cache_creation_input_token_cost"`
|
||||
CacheReadInputTokenCost float64 `json:"cache_read_input_token_cost"`
|
||||
LiteLLMProvider string `json:"litellm_provider"`
|
||||
Mode string `json:"mode"`
|
||||
SupportsPromptCaching bool `json:"supports_prompt_caching"`
|
||||
}
|
||||
|
||||
// LiteLLMRawEntry 用于解析原始JSON数据
|
||||
type LiteLLMRawEntry struct {
|
||||
InputCostPerToken *float64 `json:"input_cost_per_token"`
|
||||
OutputCostPerToken *float64 `json:"output_cost_per_token"`
|
||||
CacheCreationInputTokenCost *float64 `json:"cache_creation_input_token_cost"`
|
||||
CacheReadInputTokenCost *float64 `json:"cache_read_input_token_cost"`
|
||||
LiteLLMProvider string `json:"litellm_provider"`
|
||||
Mode string `json:"mode"`
|
||||
SupportsPromptCaching bool `json:"supports_prompt_caching"`
|
||||
}
|
||||
|
||||
// PricingService 动态价格服务
|
||||
type PricingService struct {
|
||||
cfg *config.Config
|
||||
mu sync.RWMutex
|
||||
pricingData map[string]*LiteLLMModelPricing
|
||||
lastUpdated time.Time
|
||||
localHash string
|
||||
|
||||
// 停止信号
|
||||
stopCh chan struct{}
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
// NewPricingService 创建价格服务
|
||||
func NewPricingService(cfg *config.Config) *PricingService {
|
||||
s := &PricingService{
|
||||
cfg: cfg,
|
||||
pricingData: make(map[string]*LiteLLMModelPricing),
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// Initialize 初始化价格服务
|
||||
func (s *PricingService) Initialize() error {
|
||||
// 确保数据目录存在
|
||||
if err := os.MkdirAll(s.cfg.Pricing.DataDir, 0755); err != nil {
|
||||
log.Printf("[Pricing] Failed to create data directory: %v", err)
|
||||
}
|
||||
|
||||
// 首次加载价格数据
|
||||
if err := s.checkAndUpdatePricing(); err != nil {
|
||||
log.Printf("[Pricing] Initial load failed, using fallback: %v", err)
|
||||
if err := s.useFallbackPricing(); err != nil {
|
||||
return fmt.Errorf("failed to load pricing data: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 启动定时更新
|
||||
s.startUpdateScheduler()
|
||||
|
||||
log.Printf("[Pricing] Service initialized with %d models", len(s.pricingData))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop 停止价格服务
|
||||
func (s *PricingService) Stop() {
|
||||
close(s.stopCh)
|
||||
s.wg.Wait()
|
||||
log.Println("[Pricing] Service stopped")
|
||||
}
|
||||
|
||||
// startUpdateScheduler 启动定时更新调度器
|
||||
func (s *PricingService) startUpdateScheduler() {
|
||||
// 定期检查哈希更新
|
||||
hashInterval := time.Duration(s.cfg.Pricing.HashCheckIntervalMinutes) * time.Minute
|
||||
if hashInterval < time.Minute {
|
||||
hashInterval = 10 * time.Minute
|
||||
}
|
||||
|
||||
s.wg.Add(1)
|
||||
go func() {
|
||||
defer s.wg.Done()
|
||||
ticker := time.NewTicker(hashInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
if err := s.syncWithRemote(); err != nil {
|
||||
log.Printf("[Pricing] Sync failed: %v", err)
|
||||
}
|
||||
case <-s.stopCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
log.Printf("[Pricing] Update scheduler started (check every %v)", hashInterval)
|
||||
}
|
||||
|
||||
// checkAndUpdatePricing 检查并更新价格数据
|
||||
func (s *PricingService) checkAndUpdatePricing() error {
|
||||
pricingFile := s.getPricingFilePath()
|
||||
|
||||
// 检查本地文件是否存在
|
||||
if _, err := os.Stat(pricingFile); os.IsNotExist(err) {
|
||||
log.Println("[Pricing] Local pricing file not found, downloading...")
|
||||
return s.downloadPricingData()
|
||||
}
|
||||
|
||||
// 检查文件是否过期
|
||||
info, err := os.Stat(pricingFile)
|
||||
if err != nil {
|
||||
return s.downloadPricingData()
|
||||
}
|
||||
|
||||
fileAge := time.Since(info.ModTime())
|
||||
maxAge := time.Duration(s.cfg.Pricing.UpdateIntervalHours) * time.Hour
|
||||
|
||||
if fileAge > maxAge {
|
||||
log.Printf("[Pricing] Local file is %v old, updating...", fileAge.Round(time.Hour))
|
||||
if err := s.downloadPricingData(); err != nil {
|
||||
log.Printf("[Pricing] Download failed, using existing file: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 加载本地文件
|
||||
return s.loadPricingData(pricingFile)
|
||||
}
|
||||
|
||||
// syncWithRemote 与远程同步(基于哈希校验)
|
||||
func (s *PricingService) syncWithRemote() error {
|
||||
pricingFile := s.getPricingFilePath()
|
||||
|
||||
// 计算本地文件哈希
|
||||
localHash, err := s.computeFileHash(pricingFile)
|
||||
if err != nil {
|
||||
log.Printf("[Pricing] Failed to compute local hash: %v", err)
|
||||
return s.downloadPricingData()
|
||||
}
|
||||
|
||||
// 如果配置了哈希URL,从远程获取哈希进行比对
|
||||
if s.cfg.Pricing.HashURL != "" {
|
||||
remoteHash, err := s.fetchRemoteHash()
|
||||
if err != nil {
|
||||
log.Printf("[Pricing] Failed to fetch remote hash: %v", err)
|
||||
return nil // 哈希获取失败不影响正常使用
|
||||
}
|
||||
|
||||
if remoteHash != localHash {
|
||||
log.Println("[Pricing] Remote hash differs, downloading new version...")
|
||||
return s.downloadPricingData()
|
||||
}
|
||||
log.Println("[Pricing] Hash check passed, no update needed")
|
||||
return nil
|
||||
}
|
||||
|
||||
// 没有哈希URL时,基于时间检查
|
||||
info, err := os.Stat(pricingFile)
|
||||
if err != nil {
|
||||
return s.downloadPricingData()
|
||||
}
|
||||
|
||||
fileAge := time.Since(info.ModTime())
|
||||
maxAge := time.Duration(s.cfg.Pricing.UpdateIntervalHours) * time.Hour
|
||||
|
||||
if fileAge > maxAge {
|
||||
log.Printf("[Pricing] File is %v old, downloading...", fileAge.Round(time.Hour))
|
||||
return s.downloadPricingData()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// downloadPricingData 从远程下载价格数据
|
||||
func (s *PricingService) downloadPricingData() error {
|
||||
log.Printf("[Pricing] Downloading from %s", s.cfg.Pricing.RemoteURL)
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
resp, err := client.Get(s.cfg.Pricing.RemoteURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("download failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("download failed: HTTP %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read response failed: %w", err)
|
||||
}
|
||||
|
||||
// 解析JSON数据(使用灵活的解析方式)
|
||||
data, err := s.parsePricingData(body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse pricing data: %w", err)
|
||||
}
|
||||
|
||||
// 保存到本地文件
|
||||
pricingFile := s.getPricingFilePath()
|
||||
if err := os.WriteFile(pricingFile, body, 0644); err != nil {
|
||||
log.Printf("[Pricing] Failed to save file: %v", err)
|
||||
}
|
||||
|
||||
// 保存哈希
|
||||
hash := sha256.Sum256(body)
|
||||
hashStr := hex.EncodeToString(hash[:])
|
||||
hashFile := s.getHashFilePath()
|
||||
if err := os.WriteFile(hashFile, []byte(hashStr+"\n"), 0644); err != nil {
|
||||
log.Printf("[Pricing] Failed to save hash: %v", err)
|
||||
}
|
||||
|
||||
// 更新内存数据
|
||||
s.mu.Lock()
|
||||
s.pricingData = data
|
||||
s.lastUpdated = time.Now()
|
||||
s.localHash = hashStr
|
||||
s.mu.Unlock()
|
||||
|
||||
log.Printf("[Pricing] Downloaded %d models successfully", len(data))
|
||||
return nil
|
||||
}
|
||||
|
||||
// parsePricingData 解析价格数据(处理各种格式)
|
||||
func (s *PricingService) parsePricingData(body []byte) (map[string]*LiteLLMModelPricing, error) {
|
||||
// 首先解析为 map[string]json.RawMessage
|
||||
var rawData map[string]json.RawMessage
|
||||
if err := json.Unmarshal(body, &rawData); err != nil {
|
||||
return nil, fmt.Errorf("parse raw JSON: %w", err)
|
||||
}
|
||||
|
||||
result := make(map[string]*LiteLLMModelPricing)
|
||||
skipped := 0
|
||||
|
||||
for modelName, rawEntry := range rawData {
|
||||
// 跳过 sample_spec 等文档条目
|
||||
if modelName == "sample_spec" {
|
||||
continue
|
||||
}
|
||||
|
||||
// 尝试解析每个条目
|
||||
var entry LiteLLMRawEntry
|
||||
if err := json.Unmarshal(rawEntry, &entry); err != nil {
|
||||
skipped++
|
||||
continue
|
||||
}
|
||||
|
||||
// 只保留有有效价格的条目
|
||||
if entry.InputCostPerToken == nil && entry.OutputCostPerToken == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
pricing := &LiteLLMModelPricing{
|
||||
LiteLLMProvider: entry.LiteLLMProvider,
|
||||
Mode: entry.Mode,
|
||||
SupportsPromptCaching: entry.SupportsPromptCaching,
|
||||
}
|
||||
|
||||
if entry.InputCostPerToken != nil {
|
||||
pricing.InputCostPerToken = *entry.InputCostPerToken
|
||||
}
|
||||
if entry.OutputCostPerToken != nil {
|
||||
pricing.OutputCostPerToken = *entry.OutputCostPerToken
|
||||
}
|
||||
if entry.CacheCreationInputTokenCost != nil {
|
||||
pricing.CacheCreationInputTokenCost = *entry.CacheCreationInputTokenCost
|
||||
}
|
||||
if entry.CacheReadInputTokenCost != nil {
|
||||
pricing.CacheReadInputTokenCost = *entry.CacheReadInputTokenCost
|
||||
}
|
||||
|
||||
result[modelName] = pricing
|
||||
}
|
||||
|
||||
if skipped > 0 {
|
||||
log.Printf("[Pricing] Skipped %d invalid entries", skipped)
|
||||
}
|
||||
|
||||
if len(result) == 0 {
|
||||
return nil, fmt.Errorf("no valid pricing entries found")
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// loadPricingData 从本地文件加载价格数据
|
||||
func (s *PricingService) loadPricingData(filePath string) error {
|
||||
data, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read file failed: %w", err)
|
||||
}
|
||||
|
||||
// 使用灵活的解析方式
|
||||
pricingData, err := s.parsePricingData(data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse pricing data: %w", err)
|
||||
}
|
||||
|
||||
// 计算哈希
|
||||
hash := sha256.Sum256(data)
|
||||
hashStr := hex.EncodeToString(hash[:])
|
||||
|
||||
s.mu.Lock()
|
||||
s.pricingData = pricingData
|
||||
s.localHash = hashStr
|
||||
|
||||
info, _ := os.Stat(filePath)
|
||||
if info != nil {
|
||||
s.lastUpdated = info.ModTime()
|
||||
} else {
|
||||
s.lastUpdated = time.Now()
|
||||
}
|
||||
s.mu.Unlock()
|
||||
|
||||
log.Printf("[Pricing] Loaded %d models from %s", len(pricingData), filePath)
|
||||
return nil
|
||||
}
|
||||
|
||||
// useFallbackPricing 使用回退价格文件
|
||||
func (s *PricingService) useFallbackPricing() error {
|
||||
fallbackFile := s.cfg.Pricing.FallbackFile
|
||||
|
||||
if _, err := os.Stat(fallbackFile); os.IsNotExist(err) {
|
||||
return fmt.Errorf("fallback file not found: %s", fallbackFile)
|
||||
}
|
||||
|
||||
log.Printf("[Pricing] Using fallback file: %s", fallbackFile)
|
||||
|
||||
// 复制到数据目录
|
||||
data, err := os.ReadFile(fallbackFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read fallback failed: %w", err)
|
||||
}
|
||||
|
||||
pricingFile := s.getPricingFilePath()
|
||||
if err := os.WriteFile(pricingFile, data, 0644); err != nil {
|
||||
log.Printf("[Pricing] Failed to copy fallback: %v", err)
|
||||
}
|
||||
|
||||
return s.loadPricingData(fallbackFile)
|
||||
}
|
||||
|
||||
// fetchRemoteHash 从远程获取哈希值
|
||||
func (s *PricingService) fetchRemoteHash() (string, error) {
|
||||
client := &http.Client{Timeout: 10 * time.Second}
|
||||
resp, err := client.Get(s.cfg.Pricing.HashURL)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("HTTP %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// 哈希文件格式:hash filename 或者纯 hash
|
||||
hash := strings.TrimSpace(string(body))
|
||||
parts := strings.Fields(hash)
|
||||
if len(parts) > 0 {
|
||||
return parts[0], nil
|
||||
}
|
||||
return hash, nil
|
||||
}
|
||||
|
||||
// computeFileHash 计算文件哈希
|
||||
func (s *PricingService) computeFileHash(filePath string) (string, error) {
|
||||
data, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
hash := sha256.Sum256(data)
|
||||
return hex.EncodeToString(hash[:]), nil
|
||||
}
|
||||
|
||||
// GetModelPricing 获取模型价格(带模糊匹配)
|
||||
func (s *PricingService) GetModelPricing(modelName string) *LiteLLMModelPricing {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
if modelName == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 标准化模型名称
|
||||
modelLower := strings.ToLower(modelName)
|
||||
|
||||
// 1. 精确匹配
|
||||
if pricing, ok := s.pricingData[modelLower]; ok {
|
||||
return pricing
|
||||
}
|
||||
if pricing, ok := s.pricingData[modelName]; ok {
|
||||
return pricing
|
||||
}
|
||||
|
||||
// 2. 处理常见的模型名称变体
|
||||
// claude-opus-4-5-20251101 -> claude-opus-4.5-20251101
|
||||
normalized := strings.ReplaceAll(modelLower, "-4-5-", "-4.5-")
|
||||
if pricing, ok := s.pricingData[normalized]; ok {
|
||||
return pricing
|
||||
}
|
||||
|
||||
// 3. 尝试模糊匹配(去掉版本号后缀)
|
||||
// claude-opus-4-5-20251101 -> claude-opus-4.5
|
||||
baseName := s.extractBaseName(modelLower)
|
||||
for key, pricing := range s.pricingData {
|
||||
keyBase := s.extractBaseName(strings.ToLower(key))
|
||||
if keyBase == baseName {
|
||||
return pricing
|
||||
}
|
||||
}
|
||||
|
||||
// 4. 基于模型系列匹配
|
||||
return s.matchByModelFamily(modelLower)
|
||||
}
|
||||
|
||||
// extractBaseName 提取基础模型名称(去掉日期版本号)
|
||||
func (s *PricingService) extractBaseName(model string) string {
|
||||
// 移除日期后缀 (如 -20251101, -20241022)
|
||||
parts := strings.Split(model, "-")
|
||||
result := make([]string, 0, len(parts))
|
||||
for _, part := range parts {
|
||||
// 跳过看起来像日期的部分(8位数字)
|
||||
if len(part) == 8 && isNumeric(part) {
|
||||
continue
|
||||
}
|
||||
// 跳过版本号(如 v1:0)
|
||||
if strings.Contains(part, ":") {
|
||||
continue
|
||||
}
|
||||
result = append(result, part)
|
||||
}
|
||||
return strings.Join(result, "-")
|
||||
}
|
||||
|
||||
// matchByModelFamily 基于模型系列匹配
|
||||
func (s *PricingService) matchByModelFamily(model string) *LiteLLMModelPricing {
|
||||
// Claude模型系列匹配规则
|
||||
familyPatterns := map[string][]string{
|
||||
"opus-4.5": {"claude-opus-4.5", "claude-opus-4-5"},
|
||||
"opus-4": {"claude-opus-4", "claude-3-opus"},
|
||||
"sonnet-4.5": {"claude-sonnet-4.5", "claude-sonnet-4-5"},
|
||||
"sonnet-4": {"claude-sonnet-4", "claude-3-5-sonnet"},
|
||||
"sonnet-3.5": {"claude-3-5-sonnet", "claude-3.5-sonnet"},
|
||||
"sonnet-3": {"claude-3-sonnet"},
|
||||
"haiku-3.5": {"claude-3-5-haiku", "claude-3.5-haiku"},
|
||||
"haiku-3": {"claude-3-haiku"},
|
||||
}
|
||||
|
||||
// 确定模型属于哪个系列
|
||||
var matchedFamily string
|
||||
for family, patterns := range familyPatterns {
|
||||
for _, pattern := range patterns {
|
||||
if strings.Contains(model, pattern) || strings.Contains(model, strings.ReplaceAll(pattern, "-", "")) {
|
||||
matchedFamily = family
|
||||
break
|
||||
}
|
||||
}
|
||||
if matchedFamily != "" {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if matchedFamily == "" {
|
||||
// 简单的系列匹配
|
||||
if strings.Contains(model, "opus") {
|
||||
if strings.Contains(model, "4.5") || strings.Contains(model, "4-5") {
|
||||
matchedFamily = "opus-4.5"
|
||||
} else {
|
||||
matchedFamily = "opus-4"
|
||||
}
|
||||
} else if strings.Contains(model, "sonnet") {
|
||||
if strings.Contains(model, "4.5") || strings.Contains(model, "4-5") {
|
||||
matchedFamily = "sonnet-4.5"
|
||||
} else if strings.Contains(model, "3-5") || strings.Contains(model, "3.5") {
|
||||
matchedFamily = "sonnet-3.5"
|
||||
} else {
|
||||
matchedFamily = "sonnet-4"
|
||||
}
|
||||
} else if strings.Contains(model, "haiku") {
|
||||
if strings.Contains(model, "3-5") || strings.Contains(model, "3.5") {
|
||||
matchedFamily = "haiku-3.5"
|
||||
} else {
|
||||
matchedFamily = "haiku-3"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if matchedFamily == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 在价格数据中查找该系列的模型
|
||||
patterns := familyPatterns[matchedFamily]
|
||||
for _, pattern := range patterns {
|
||||
for key, pricing := range s.pricingData {
|
||||
keyLower := strings.ToLower(key)
|
||||
if strings.Contains(keyLower, pattern) {
|
||||
log.Printf("[Pricing] Fuzzy matched %s -> %s", model, key)
|
||||
return pricing
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetStatus 获取服务状态
|
||||
func (s *PricingService) GetStatus() map[string]interface{} {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
return map[string]interface{}{
|
||||
"model_count": len(s.pricingData),
|
||||
"last_updated": s.lastUpdated,
|
||||
"local_hash": s.localHash[:min(8, len(s.localHash))],
|
||||
}
|
||||
}
|
||||
|
||||
// ForceUpdate 强制更新
|
||||
func (s *PricingService) ForceUpdate() error {
|
||||
return s.downloadPricingData()
|
||||
}
|
||||
|
||||
// getPricingFilePath 获取价格文件路径
|
||||
func (s *PricingService) getPricingFilePath() string {
|
||||
return filepath.Join(s.cfg.Pricing.DataDir, "model_pricing.json")
|
||||
}
|
||||
|
||||
// getHashFilePath 获取哈希文件路径
|
||||
func (s *PricingService) getHashFilePath() string {
|
||||
return filepath.Join(s.cfg.Pricing.DataDir, "model_pricing.sha256")
|
||||
}
|
||||
|
||||
// isNumeric 检查字符串是否为纯数字
|
||||
func isNumeric(s string) bool {
|
||||
for _, c := range s {
|
||||
if c < '0' || c > '9' {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
192
backend/internal/service/proxy_service.go
Normal file
192
backend/internal/service/proxy_service.go
Normal file
@@ -0,0 +1,192 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/repository"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrProxyNotFound = errors.New("proxy not found")
|
||||
)
|
||||
|
||||
// CreateProxyRequest 创建代理请求
|
||||
type CreateProxyRequest struct {
|
||||
Name string `json:"name"`
|
||||
Protocol string `json:"protocol"`
|
||||
Host string `json:"host"`
|
||||
Port int `json:"port"`
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
// UpdateProxyRequest 更新代理请求
|
||||
type UpdateProxyRequest struct {
|
||||
Name *string `json:"name"`
|
||||
Protocol *string `json:"protocol"`
|
||||
Host *string `json:"host"`
|
||||
Port *int `json:"port"`
|
||||
Username *string `json:"username"`
|
||||
Password *string `json:"password"`
|
||||
Status *string `json:"status"`
|
||||
}
|
||||
|
||||
// ProxyService 代理管理服务
|
||||
type ProxyService struct {
|
||||
proxyRepo *repository.ProxyRepository
|
||||
}
|
||||
|
||||
// NewProxyService 创建代理服务实例
|
||||
func NewProxyService(proxyRepo *repository.ProxyRepository) *ProxyService {
|
||||
return &ProxyService{
|
||||
proxyRepo: proxyRepo,
|
||||
}
|
||||
}
|
||||
|
||||
// Create 创建代理
|
||||
func (s *ProxyService) Create(ctx context.Context, req CreateProxyRequest) (*model.Proxy, error) {
|
||||
// 创建代理
|
||||
proxy := &model.Proxy{
|
||||
Name: req.Name,
|
||||
Protocol: req.Protocol,
|
||||
Host: req.Host,
|
||||
Port: req.Port,
|
||||
Username: req.Username,
|
||||
Password: req.Password,
|
||||
Status: model.StatusActive,
|
||||
}
|
||||
|
||||
if err := s.proxyRepo.Create(ctx, proxy); err != nil {
|
||||
return nil, fmt.Errorf("create proxy: %w", err)
|
||||
}
|
||||
|
||||
return proxy, nil
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取代理
|
||||
func (s *ProxyService) GetByID(ctx context.Context, id int64) (*model.Proxy, error) {
|
||||
proxy, err := s.proxyRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrProxyNotFound
|
||||
}
|
||||
return nil, fmt.Errorf("get proxy: %w", err)
|
||||
}
|
||||
return proxy, nil
|
||||
}
|
||||
|
||||
// List 获取代理列表
|
||||
func (s *ProxyService) List(ctx context.Context, params repository.PaginationParams) ([]model.Proxy, *repository.PaginationResult, error) {
|
||||
proxies, pagination, err := s.proxyRepo.List(ctx, params)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("list proxies: %w", err)
|
||||
}
|
||||
return proxies, pagination, nil
|
||||
}
|
||||
|
||||
// ListActive 获取活跃代理列表
|
||||
func (s *ProxyService) ListActive(ctx context.Context) ([]model.Proxy, error) {
|
||||
proxies, err := s.proxyRepo.ListActive(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list active proxies: %w", err)
|
||||
}
|
||||
return proxies, nil
|
||||
}
|
||||
|
||||
// Update 更新代理
|
||||
func (s *ProxyService) Update(ctx context.Context, id int64, req UpdateProxyRequest) (*model.Proxy, error) {
|
||||
proxy, err := s.proxyRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrProxyNotFound
|
||||
}
|
||||
return nil, fmt.Errorf("get proxy: %w", err)
|
||||
}
|
||||
|
||||
// 更新字段
|
||||
if req.Name != nil {
|
||||
proxy.Name = *req.Name
|
||||
}
|
||||
|
||||
if req.Protocol != nil {
|
||||
proxy.Protocol = *req.Protocol
|
||||
}
|
||||
|
||||
if req.Host != nil {
|
||||
proxy.Host = *req.Host
|
||||
}
|
||||
|
||||
if req.Port != nil {
|
||||
proxy.Port = *req.Port
|
||||
}
|
||||
|
||||
if req.Username != nil {
|
||||
proxy.Username = *req.Username
|
||||
}
|
||||
|
||||
if req.Password != nil {
|
||||
proxy.Password = *req.Password
|
||||
}
|
||||
|
||||
if req.Status != nil {
|
||||
proxy.Status = *req.Status
|
||||
}
|
||||
|
||||
if err := s.proxyRepo.Update(ctx, proxy); err != nil {
|
||||
return nil, fmt.Errorf("update proxy: %w", err)
|
||||
}
|
||||
|
||||
return proxy, nil
|
||||
}
|
||||
|
||||
// Delete 删除代理
|
||||
func (s *ProxyService) Delete(ctx context.Context, id int64) error {
|
||||
// 检查代理是否存在
|
||||
_, err := s.proxyRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return ErrProxyNotFound
|
||||
}
|
||||
return fmt.Errorf("get proxy: %w", err)
|
||||
}
|
||||
|
||||
if err := s.proxyRepo.Delete(ctx, id); err != nil {
|
||||
return fmt.Errorf("delete proxy: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// TestConnection 测试代理连接(需要实现具体测试逻辑)
|
||||
func (s *ProxyService) TestConnection(ctx context.Context, id int64) error {
|
||||
proxy, err := s.proxyRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return ErrProxyNotFound
|
||||
}
|
||||
return fmt.Errorf("get proxy: %w", err)
|
||||
}
|
||||
|
||||
// TODO: 实现代理连接测试逻辑
|
||||
// 可以尝试通过代理发送测试请求
|
||||
_ = proxy
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetURL 获取代理URL
|
||||
func (s *ProxyService) GetURL(ctx context.Context, id int64) (string, error) {
|
||||
proxy, err := s.proxyRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return "", ErrProxyNotFound
|
||||
}
|
||||
return "", fmt.Errorf("get proxy: %w", err)
|
||||
}
|
||||
|
||||
return proxy.URL(), nil
|
||||
}
|
||||
170
backend/internal/service/ratelimit_service.go
Normal file
170
backend/internal/service/ratelimit_service.go
Normal file
@@ -0,0 +1,170 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/config"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/repository"
|
||||
)
|
||||
|
||||
// RateLimitService 处理限流和过载状态管理
|
||||
type RateLimitService struct {
|
||||
repos *repository.Repositories
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
// NewRateLimitService 创建RateLimitService实例
|
||||
func NewRateLimitService(repos *repository.Repositories, cfg *config.Config) *RateLimitService {
|
||||
return &RateLimitService{
|
||||
repos: repos,
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
// HandleUpstreamError 处理上游错误响应,标记账号状态
|
||||
// 返回是否应该停止该账号的调度
|
||||
func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *model.Account, statusCode int, headers http.Header, responseBody []byte) (shouldDisable bool) {
|
||||
// apikey 类型账号:检查自定义错误码配置
|
||||
// 如果启用且错误码不在列表中,则不处理(不停止调度、不标记限流/过载)
|
||||
if !account.ShouldHandleErrorCode(statusCode) {
|
||||
log.Printf("Account %d: error %d skipped (not in custom error codes)", account.ID, statusCode)
|
||||
return false
|
||||
}
|
||||
|
||||
switch statusCode {
|
||||
case 401:
|
||||
// 认证失败:停止调度,记录错误
|
||||
s.handleAuthError(ctx, account, "Authentication failed (401): invalid or expired credentials")
|
||||
return true
|
||||
case 403:
|
||||
// 禁止访问:停止调度,记录错误
|
||||
s.handleAuthError(ctx, account, "Access forbidden (403): account may be suspended or lack permissions")
|
||||
return true
|
||||
case 429:
|
||||
s.handle429(ctx, account, headers)
|
||||
return false
|
||||
case 529:
|
||||
s.handle529(ctx, account)
|
||||
return false
|
||||
default:
|
||||
// 其他5xx错误:记录但不停止调度
|
||||
if statusCode >= 500 {
|
||||
log.Printf("Account %d received upstream error %d", account.ID, statusCode)
|
||||
}
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// handleAuthError 处理认证类错误(401/403),停止账号调度
|
||||
func (s *RateLimitService) handleAuthError(ctx context.Context, account *model.Account, errorMsg string) {
|
||||
if err := s.repos.Account.SetError(ctx, account.ID, errorMsg); err != nil {
|
||||
log.Printf("SetError failed for account %d: %v", account.ID, err)
|
||||
return
|
||||
}
|
||||
log.Printf("Account %d disabled due to auth error: %s", account.ID, errorMsg)
|
||||
}
|
||||
|
||||
// handle429 处理429限流错误
|
||||
// 解析响应头获取重置时间,标记账号为限流状态
|
||||
func (s *RateLimitService) handle429(ctx context.Context, account *model.Account, headers http.Header) {
|
||||
// 解析重置时间戳
|
||||
resetTimestamp := headers.Get("anthropic-ratelimit-unified-reset")
|
||||
if resetTimestamp == "" {
|
||||
// 没有重置时间,使用默认5分钟
|
||||
resetAt := time.Now().Add(5 * time.Minute)
|
||||
if err := s.repos.Account.SetRateLimited(ctx, account.ID, resetAt); err != nil {
|
||||
log.Printf("SetRateLimited failed for account %d: %v", account.ID, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 解析Unix时间戳
|
||||
ts, err := strconv.ParseInt(resetTimestamp, 10, 64)
|
||||
if err != nil {
|
||||
log.Printf("Parse reset timestamp failed: %v", err)
|
||||
resetAt := time.Now().Add(5 * time.Minute)
|
||||
if err := s.repos.Account.SetRateLimited(ctx, account.ID, resetAt); err != nil {
|
||||
log.Printf("SetRateLimited failed for account %d: %v", account.ID, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
resetAt := time.Unix(ts, 0)
|
||||
|
||||
// 标记限流状态
|
||||
if err := s.repos.Account.SetRateLimited(ctx, account.ID, resetAt); err != nil {
|
||||
log.Printf("SetRateLimited failed for account %d: %v", account.ID, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 根据重置时间反推5h窗口
|
||||
windowEnd := resetAt
|
||||
windowStart := resetAt.Add(-5 * time.Hour)
|
||||
if err := s.repos.Account.UpdateSessionWindow(ctx, account.ID, &windowStart, &windowEnd, "rejected"); err != nil {
|
||||
log.Printf("UpdateSessionWindow failed for account %d: %v", account.ID, err)
|
||||
}
|
||||
|
||||
log.Printf("Account %d rate limited until %v", account.ID, resetAt)
|
||||
}
|
||||
|
||||
// handle529 处理529过载错误
|
||||
// 根据配置设置过载冷却时间
|
||||
func (s *RateLimitService) handle529(ctx context.Context, account *model.Account) {
|
||||
cooldownMinutes := s.cfg.RateLimit.OverloadCooldownMinutes
|
||||
if cooldownMinutes <= 0 {
|
||||
cooldownMinutes = 10 // 默认10分钟
|
||||
}
|
||||
|
||||
until := time.Now().Add(time.Duration(cooldownMinutes) * time.Minute)
|
||||
if err := s.repos.Account.SetOverloaded(ctx, account.ID, until); err != nil {
|
||||
log.Printf("SetOverloaded failed for account %d: %v", account.ID, err)
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("Account %d overloaded until %v", account.ID, until)
|
||||
}
|
||||
|
||||
// UpdateSessionWindow 从成功响应更新5h窗口状态
|
||||
func (s *RateLimitService) UpdateSessionWindow(ctx context.Context, account *model.Account, headers http.Header) {
|
||||
status := headers.Get("anthropic-ratelimit-unified-5h-status")
|
||||
if status == "" {
|
||||
return
|
||||
}
|
||||
|
||||
// 检查是否需要初始化时间窗口
|
||||
// 对于 Setup Token 账号,首次成功请求时需要预测时间窗口
|
||||
var windowStart, windowEnd *time.Time
|
||||
needInitWindow := account.SessionWindowEnd == nil || time.Now().After(*account.SessionWindowEnd)
|
||||
|
||||
if needInitWindow && (status == "allowed" || status == "allowed_warning") {
|
||||
// 预测时间窗口:从当前时间的整点开始,+5小时为结束
|
||||
// 例如:现在是 14:30,窗口为 14:00 ~ 19:00
|
||||
now := time.Now()
|
||||
start := time.Date(now.Year(), now.Month(), now.Day(), now.Hour(), 0, 0, 0, now.Location())
|
||||
end := start.Add(5 * time.Hour)
|
||||
windowStart = &start
|
||||
windowEnd = &end
|
||||
log.Printf("Account %d: initializing 5h window from %v to %v (status: %s)", account.ID, start, end, status)
|
||||
}
|
||||
|
||||
if err := s.repos.Account.UpdateSessionWindow(ctx, account.ID, windowStart, windowEnd, status); err != nil {
|
||||
log.Printf("UpdateSessionWindow failed for account %d: %v", account.ID, err)
|
||||
}
|
||||
|
||||
// 如果状态为allowed且之前有限流,说明窗口已重置,清除限流状态
|
||||
if status == "allowed" && account.IsRateLimited() {
|
||||
if err := s.repos.Account.ClearRateLimit(ctx, account.ID); err != nil {
|
||||
log.Printf("ClearRateLimit failed for account %d: %v", account.ID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ClearRateLimit 清除账号的限流状态
|
||||
func (s *RateLimitService) ClearRateLimit(ctx context.Context, accountID int64) error {
|
||||
return s.repos.Account.ClearRateLimit(ctx, accountID)
|
||||
}
|
||||
392
backend/internal/service/redeem_service.go
Normal file
392
backend/internal/service/redeem_service.go
Normal file
@@ -0,0 +1,392 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/repository"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrRedeemCodeNotFound = errors.New("redeem code not found")
|
||||
ErrRedeemCodeUsed = errors.New("redeem code already used")
|
||||
ErrRedeemCodeInvalid = errors.New("invalid redeem code")
|
||||
ErrInsufficientBalance = errors.New("insufficient balance")
|
||||
ErrRedeemRateLimited = errors.New("too many failed attempts, please try again later")
|
||||
ErrRedeemCodeLocked = errors.New("redeem code is being processed, please try again")
|
||||
)
|
||||
|
||||
const (
|
||||
redeemRateLimitKeyPrefix = "redeem:rate_limit:"
|
||||
redeemLockKeyPrefix = "redeem:lock:"
|
||||
redeemMaxErrorsPerHour = 20
|
||||
redeemRateLimitDuration = time.Hour
|
||||
redeemLockDuration = 10 * time.Second // 锁超时时间,防止死锁
|
||||
)
|
||||
|
||||
// GenerateCodesRequest 生成兑换码请求
|
||||
type GenerateCodesRequest struct {
|
||||
Count int `json:"count"`
|
||||
Value float64 `json:"value"`
|
||||
Type string `json:"type"`
|
||||
}
|
||||
|
||||
// RedeemCodeResponse 兑换码响应
|
||||
type RedeemCodeResponse struct {
|
||||
Code string `json:"code"`
|
||||
Value float64 `json:"value"`
|
||||
Status string `json:"status"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
// RedeemService 兑换码服务
|
||||
type RedeemService struct {
|
||||
redeemRepo *repository.RedeemCodeRepository
|
||||
userRepo *repository.UserRepository
|
||||
subscriptionService *SubscriptionService
|
||||
rdb *redis.Client
|
||||
billingCacheService *BillingCacheService
|
||||
}
|
||||
|
||||
// NewRedeemService 创建兑换码服务实例
|
||||
func NewRedeemService(redeemRepo *repository.RedeemCodeRepository, userRepo *repository.UserRepository, subscriptionService *SubscriptionService, rdb *redis.Client) *RedeemService {
|
||||
return &RedeemService{
|
||||
redeemRepo: redeemRepo,
|
||||
userRepo: userRepo,
|
||||
subscriptionService: subscriptionService,
|
||||
rdb: rdb,
|
||||
}
|
||||
}
|
||||
|
||||
// SetBillingCacheService 设置计费缓存服务(用于缓存失效)
|
||||
func (s *RedeemService) SetBillingCacheService(billingCacheService *BillingCacheService) {
|
||||
s.billingCacheService = billingCacheService
|
||||
}
|
||||
|
||||
// GenerateRandomCode 生成随机兑换码
|
||||
func (s *RedeemService) GenerateRandomCode() (string, error) {
|
||||
// 生成16字节随机数据
|
||||
bytes := make([]byte, 16)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
return "", fmt.Errorf("generate random bytes: %w", err)
|
||||
}
|
||||
|
||||
// 转换为十六进制字符串
|
||||
code := hex.EncodeToString(bytes)
|
||||
|
||||
// 格式化为 XXXX-XXXX-XXXX-XXXX 格式
|
||||
parts := []string{
|
||||
strings.ToUpper(code[0:8]),
|
||||
strings.ToUpper(code[8:16]),
|
||||
strings.ToUpper(code[16:24]),
|
||||
strings.ToUpper(code[24:32]),
|
||||
}
|
||||
|
||||
return strings.Join(parts, "-"), nil
|
||||
}
|
||||
|
||||
// GenerateCodes 批量生成兑换码
|
||||
func (s *RedeemService) GenerateCodes(ctx context.Context, req GenerateCodesRequest) ([]model.RedeemCode, error) {
|
||||
if req.Count <= 0 {
|
||||
return nil, errors.New("count must be greater than 0")
|
||||
}
|
||||
|
||||
if req.Value <= 0 {
|
||||
return nil, errors.New("value must be greater than 0")
|
||||
}
|
||||
|
||||
if req.Count > 1000 {
|
||||
return nil, errors.New("cannot generate more than 1000 codes at once")
|
||||
}
|
||||
|
||||
codeType := req.Type
|
||||
if codeType == "" {
|
||||
codeType = model.RedeemTypeBalance
|
||||
}
|
||||
|
||||
codes := make([]model.RedeemCode, 0, req.Count)
|
||||
for i := 0; i < req.Count; i++ {
|
||||
code, err := s.GenerateRandomCode()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate code: %w", err)
|
||||
}
|
||||
|
||||
codes = append(codes, model.RedeemCode{
|
||||
Code: code,
|
||||
Type: codeType,
|
||||
Value: req.Value,
|
||||
Status: model.StatusUnused,
|
||||
})
|
||||
}
|
||||
|
||||
// 批量插入
|
||||
if err := s.redeemRepo.CreateBatch(ctx, codes); err != nil {
|
||||
return nil, fmt.Errorf("create batch codes: %w", err)
|
||||
}
|
||||
|
||||
return codes, nil
|
||||
}
|
||||
|
||||
// checkRedeemRateLimit 检查用户兑换错误次数是否超限
|
||||
func (s *RedeemService) checkRedeemRateLimit(ctx context.Context, userID int64) error {
|
||||
if s.rdb == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("%s%d", redeemRateLimitKeyPrefix, userID)
|
||||
|
||||
count, err := s.rdb.Get(ctx, key).Int()
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
// Redis 出错时不阻止用户操作
|
||||
return nil
|
||||
}
|
||||
|
||||
if count >= redeemMaxErrorsPerHour {
|
||||
return ErrRedeemRateLimited
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// incrementRedeemErrorCount 增加用户兑换错误计数
|
||||
func (s *RedeemService) incrementRedeemErrorCount(ctx context.Context, userID int64) {
|
||||
if s.rdb == nil {
|
||||
return
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("%s%d", redeemRateLimitKeyPrefix, userID)
|
||||
|
||||
pipe := s.rdb.Pipeline()
|
||||
pipe.Incr(ctx, key)
|
||||
pipe.Expire(ctx, key, redeemRateLimitDuration)
|
||||
_, _ = pipe.Exec(ctx)
|
||||
}
|
||||
|
||||
// acquireRedeemLock 尝试获取兑换码的分布式锁
|
||||
// 返回 true 表示获取成功,false 表示锁已被占用
|
||||
func (s *RedeemService) acquireRedeemLock(ctx context.Context, code string) bool {
|
||||
if s.rdb == nil {
|
||||
return true // 无 Redis 时降级为不加锁
|
||||
}
|
||||
|
||||
key := redeemLockKeyPrefix + code
|
||||
ok, err := s.rdb.SetNX(ctx, key, "1", redeemLockDuration).Result()
|
||||
if err != nil {
|
||||
// Redis 出错时不阻止操作,依赖数据库层面的状态检查
|
||||
return true
|
||||
}
|
||||
return ok
|
||||
}
|
||||
|
||||
// releaseRedeemLock 释放兑换码的分布式锁
|
||||
func (s *RedeemService) releaseRedeemLock(ctx context.Context, code string) {
|
||||
if s.rdb == nil {
|
||||
return
|
||||
}
|
||||
|
||||
key := redeemLockKeyPrefix + code
|
||||
s.rdb.Del(ctx, key)
|
||||
}
|
||||
|
||||
// Redeem 使用兑换码
|
||||
func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (*model.RedeemCode, error) {
|
||||
// 检查限流
|
||||
if err := s.checkRedeemRateLimit(ctx, userID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 获取分布式锁,防止同一兑换码并发使用
|
||||
if !s.acquireRedeemLock(ctx, code) {
|
||||
return nil, ErrRedeemCodeLocked
|
||||
}
|
||||
defer s.releaseRedeemLock(ctx, code)
|
||||
|
||||
// 查找兑换码
|
||||
redeemCode, err := s.redeemRepo.GetByCode(ctx, code)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
s.incrementRedeemErrorCount(ctx, userID)
|
||||
return nil, ErrRedeemCodeNotFound
|
||||
}
|
||||
return nil, fmt.Errorf("get redeem code: %w", err)
|
||||
}
|
||||
|
||||
// 检查兑换码状态
|
||||
if !redeemCode.CanUse() {
|
||||
s.incrementRedeemErrorCount(ctx, userID)
|
||||
return nil, ErrRedeemCodeUsed
|
||||
}
|
||||
|
||||
// 验证兑换码类型的前置条件
|
||||
if redeemCode.Type == model.RedeemTypeSubscription && redeemCode.GroupID == nil {
|
||||
return nil, errors.New("invalid subscription redeem code: missing group_id")
|
||||
}
|
||||
|
||||
// 获取用户信息
|
||||
user, err := s.userRepo.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrUserNotFound
|
||||
}
|
||||
return nil, fmt.Errorf("get user: %w", err)
|
||||
}
|
||||
_ = user // 使用变量避免未使用错误
|
||||
|
||||
// 【关键】先标记兑换码为已使用,确保并发安全
|
||||
// 利用数据库乐观锁(WHERE status = 'unused')保证原子性
|
||||
if err := s.redeemRepo.Use(ctx, redeemCode.ID, userID); err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
// 兑换码已被其他请求使用
|
||||
return nil, ErrRedeemCodeUsed
|
||||
}
|
||||
return nil, fmt.Errorf("mark code as used: %w", err)
|
||||
}
|
||||
|
||||
// 执行兑换逻辑(兑换码已被锁定,此时可安全操作)
|
||||
switch redeemCode.Type {
|
||||
case model.RedeemTypeBalance:
|
||||
// 增加用户余额
|
||||
if err := s.userRepo.UpdateBalance(ctx, userID, redeemCode.Value); err != nil {
|
||||
return nil, fmt.Errorf("update user balance: %w", err)
|
||||
}
|
||||
// 失效余额缓存
|
||||
if s.billingCacheService != nil {
|
||||
go func() {
|
||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
s.billingCacheService.InvalidateUserBalance(cacheCtx, userID)
|
||||
}()
|
||||
}
|
||||
|
||||
case model.RedeemTypeConcurrency:
|
||||
// 增加用户并发数
|
||||
if err := s.userRepo.UpdateConcurrency(ctx, userID, int(redeemCode.Value)); err != nil {
|
||||
return nil, fmt.Errorf("update user concurrency: %w", err)
|
||||
}
|
||||
|
||||
case model.RedeemTypeSubscription:
|
||||
validityDays := redeemCode.ValidityDays
|
||||
if validityDays <= 0 {
|
||||
validityDays = 30
|
||||
}
|
||||
_, _, err := s.subscriptionService.AssignOrExtendSubscription(ctx, &AssignSubscriptionInput{
|
||||
UserID: userID,
|
||||
GroupID: *redeemCode.GroupID,
|
||||
ValidityDays: validityDays,
|
||||
AssignedBy: 0, // 系统分配
|
||||
Notes: fmt.Sprintf("通过兑换码 %s 兑换", redeemCode.Code),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("assign or extend subscription: %w", err)
|
||||
}
|
||||
// 失效订阅缓存
|
||||
if s.billingCacheService != nil {
|
||||
groupID := *redeemCode.GroupID
|
||||
go func() {
|
||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
|
||||
}()
|
||||
}
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported redeem type: %s", redeemCode.Type)
|
||||
}
|
||||
|
||||
// 重新获取更新后的兑换码
|
||||
redeemCode, err = s.redeemRepo.GetByID(ctx, redeemCode.ID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get updated redeem code: %w", err)
|
||||
}
|
||||
|
||||
return redeemCode, nil
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取兑换码
|
||||
func (s *RedeemService) GetByID(ctx context.Context, id int64) (*model.RedeemCode, error) {
|
||||
code, err := s.redeemRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrRedeemCodeNotFound
|
||||
}
|
||||
return nil, fmt.Errorf("get redeem code: %w", err)
|
||||
}
|
||||
return code, nil
|
||||
}
|
||||
|
||||
// GetByCode 根据Code获取兑换码
|
||||
func (s *RedeemService) GetByCode(ctx context.Context, code string) (*model.RedeemCode, error) {
|
||||
redeemCode, err := s.redeemRepo.GetByCode(ctx, code)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrRedeemCodeNotFound
|
||||
}
|
||||
return nil, fmt.Errorf("get redeem code: %w", err)
|
||||
}
|
||||
return redeemCode, nil
|
||||
}
|
||||
|
||||
// List 获取兑换码列表(管理员功能)
|
||||
func (s *RedeemService) List(ctx context.Context, params repository.PaginationParams) ([]model.RedeemCode, *repository.PaginationResult, error) {
|
||||
codes, pagination, err := s.redeemRepo.List(ctx, params)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("list redeem codes: %w", err)
|
||||
}
|
||||
return codes, pagination, nil
|
||||
}
|
||||
|
||||
// Delete 删除兑换码(管理员功能)
|
||||
func (s *RedeemService) Delete(ctx context.Context, id int64) error {
|
||||
// 检查兑换码是否存在
|
||||
code, err := s.redeemRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return ErrRedeemCodeNotFound
|
||||
}
|
||||
return fmt.Errorf("get redeem code: %w", err)
|
||||
}
|
||||
|
||||
// 不允许删除已使用的兑换码
|
||||
if code.IsUsed() {
|
||||
return errors.New("cannot delete used redeem code")
|
||||
}
|
||||
|
||||
if err := s.redeemRepo.Delete(ctx, id); err != nil {
|
||||
return fmt.Errorf("delete redeem code: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetStats 获取兑换码统计信息
|
||||
func (s *RedeemService) GetStats(ctx context.Context) (map[string]interface{}, error) {
|
||||
// TODO: 实现统计逻辑
|
||||
// 统计未使用、已使用的兑换码数量
|
||||
// 统计总面值等
|
||||
|
||||
stats := map[string]interface{}{
|
||||
"total_codes": 0,
|
||||
"unused_codes": 0,
|
||||
"used_codes": 0,
|
||||
"total_value": 0.0,
|
||||
}
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
// GetUserHistory 获取用户的兑换历史
|
||||
func (s *RedeemService) GetUserHistory(ctx context.Context, userID int64, limit int) ([]model.RedeemCode, error) {
|
||||
codes, err := s.redeemRepo.ListByUser(ctx, userID, limit)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get user redeem history: %w", err)
|
||||
}
|
||||
return codes, nil
|
||||
}
|
||||
139
backend/internal/service/service.go
Normal file
139
backend/internal/service/service.go
Normal file
@@ -0,0 +1,139 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"sub2api/internal/config"
|
||||
"sub2api/internal/repository"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// Services 服务集合容器
|
||||
type Services struct {
|
||||
Auth *AuthService
|
||||
User *UserService
|
||||
ApiKey *ApiKeyService
|
||||
Group *GroupService
|
||||
Account *AccountService
|
||||
Proxy *ProxyService
|
||||
Redeem *RedeemService
|
||||
Usage *UsageService
|
||||
Pricing *PricingService
|
||||
Billing *BillingService
|
||||
BillingCache *BillingCacheService
|
||||
Admin AdminService
|
||||
Gateway *GatewayService
|
||||
OAuth *OAuthService
|
||||
RateLimit *RateLimitService
|
||||
AccountUsage *AccountUsageService
|
||||
AccountTest *AccountTestService
|
||||
Setting *SettingService
|
||||
Email *EmailService
|
||||
EmailQueue *EmailQueueService
|
||||
Turnstile *TurnstileService
|
||||
Subscription *SubscriptionService
|
||||
Concurrency *ConcurrencyService
|
||||
Identity *IdentityService
|
||||
}
|
||||
|
||||
// NewServices 创建所有服务实例
|
||||
func NewServices(repos *repository.Repositories, rdb *redis.Client, cfg *config.Config) *Services {
|
||||
// 初始化价格服务
|
||||
pricingService := NewPricingService(cfg)
|
||||
if err := pricingService.Initialize(); err != nil {
|
||||
// 价格服务初始化失败不应阻止启动,使用回退价格
|
||||
println("[Service] Warning: Pricing service initialization failed:", err.Error())
|
||||
}
|
||||
|
||||
// 初始化计费服务(依赖价格服务)
|
||||
billingService := NewBillingService(cfg, pricingService)
|
||||
|
||||
// 初始化其他服务
|
||||
authService := NewAuthService(repos.User, cfg)
|
||||
userService := NewUserService(repos.User, cfg)
|
||||
apiKeyService := NewApiKeyService(repos.ApiKey, repos.User, repos.Group, repos.UserSubscription, rdb, cfg)
|
||||
groupService := NewGroupService(repos.Group)
|
||||
accountService := NewAccountService(repos.Account, repos.Group)
|
||||
proxyService := NewProxyService(repos.Proxy)
|
||||
usageService := NewUsageService(repos.UsageLog, repos.User)
|
||||
|
||||
// 初始化订阅服务 (RedeemService 依赖)
|
||||
subscriptionService := NewSubscriptionService(repos)
|
||||
|
||||
// 初始化兑换服务 (依赖订阅服务)
|
||||
redeemService := NewRedeemService(repos.RedeemCode, repos.User, subscriptionService, rdb)
|
||||
|
||||
// 初始化Admin服务
|
||||
adminService := NewAdminService(repos)
|
||||
|
||||
// 初始化OAuth服务(GatewayService依赖)
|
||||
oauthService := NewOAuthService(repos.Proxy)
|
||||
|
||||
// 初始化限流服务
|
||||
rateLimitService := NewRateLimitService(repos, cfg)
|
||||
|
||||
// 初始化计费缓存服务
|
||||
billingCacheService := NewBillingCacheService(rdb, repos.User, repos.UserSubscription)
|
||||
|
||||
// 初始化账号使用量服务
|
||||
accountUsageService := NewAccountUsageService(repos, oauthService)
|
||||
|
||||
// 初始化账号测试服务
|
||||
accountTestService := NewAccountTestService(repos, oauthService)
|
||||
|
||||
// 初始化身份指纹服务
|
||||
identityService := NewIdentityService(rdb)
|
||||
|
||||
// 初始化Gateway服务
|
||||
gatewayService := NewGatewayService(repos, rdb, cfg, oauthService, billingService, rateLimitService, billingCacheService, identityService)
|
||||
|
||||
// 初始化设置服务
|
||||
settingService := NewSettingService(repos.Setting, cfg)
|
||||
emailService := NewEmailService(repos.Setting, rdb)
|
||||
|
||||
// 初始化邮件队列服务
|
||||
emailQueueService := NewEmailQueueService(emailService, 3)
|
||||
|
||||
// 初始化Turnstile服务
|
||||
turnstileService := NewTurnstileService(settingService)
|
||||
|
||||
// 设置Auth服务的依赖(用于注册开关和邮件验证)
|
||||
authService.SetSettingService(settingService)
|
||||
authService.SetEmailService(emailService)
|
||||
authService.SetTurnstileService(turnstileService)
|
||||
authService.SetEmailQueueService(emailQueueService)
|
||||
|
||||
// 初始化并发控制服务
|
||||
concurrencyService := NewConcurrencyService(rdb)
|
||||
|
||||
// 注入计费缓存服务到需要失效缓存的服务
|
||||
redeemService.SetBillingCacheService(billingCacheService)
|
||||
subscriptionService.SetBillingCacheService(billingCacheService)
|
||||
SetAdminServiceBillingCache(adminService, billingCacheService)
|
||||
|
||||
return &Services{
|
||||
Auth: authService,
|
||||
User: userService,
|
||||
ApiKey: apiKeyService,
|
||||
Group: groupService,
|
||||
Account: accountService,
|
||||
Proxy: proxyService,
|
||||
Redeem: redeemService,
|
||||
Usage: usageService,
|
||||
Pricing: pricingService,
|
||||
Billing: billingService,
|
||||
BillingCache: billingCacheService,
|
||||
Admin: adminService,
|
||||
Gateway: gatewayService,
|
||||
OAuth: oauthService,
|
||||
RateLimit: rateLimitService,
|
||||
AccountUsage: accountUsageService,
|
||||
AccountTest: accountTestService,
|
||||
Setting: settingService,
|
||||
Email: emailService,
|
||||
EmailQueue: emailQueueService,
|
||||
Turnstile: turnstileService,
|
||||
Subscription: subscriptionService,
|
||||
Concurrency: concurrencyService,
|
||||
Identity: identityService,
|
||||
}
|
||||
}
|
||||
264
backend/internal/service/setting_service.go
Normal file
264
backend/internal/service/setting_service.go
Normal file
@@ -0,0 +1,264 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"sub2api/internal/config"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/repository"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrRegistrationDisabled = errors.New("registration is currently disabled")
|
||||
)
|
||||
|
||||
// SettingService 系统设置服务
|
||||
type SettingService struct {
|
||||
settingRepo *repository.SettingRepository
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
// NewSettingService 创建系统设置服务实例
|
||||
func NewSettingService(settingRepo *repository.SettingRepository, cfg *config.Config) *SettingService {
|
||||
return &SettingService{
|
||||
settingRepo: settingRepo,
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
// GetAllSettings 获取所有系统设置
|
||||
func (s *SettingService) GetAllSettings(ctx context.Context) (*model.SystemSettings, error) {
|
||||
settings, err := s.settingRepo.GetAll(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get all settings: %w", err)
|
||||
}
|
||||
|
||||
return s.parseSettings(settings), nil
|
||||
}
|
||||
|
||||
// GetPublicSettings 获取公开设置(无需登录)
|
||||
func (s *SettingService) GetPublicSettings(ctx context.Context) (*model.PublicSettings, error) {
|
||||
keys := []string{
|
||||
model.SettingKeyRegistrationEnabled,
|
||||
model.SettingKeyEmailVerifyEnabled,
|
||||
model.SettingKeyTurnstileEnabled,
|
||||
model.SettingKeyTurnstileSiteKey,
|
||||
model.SettingKeySiteName,
|
||||
model.SettingKeySiteLogo,
|
||||
model.SettingKeySiteSubtitle,
|
||||
model.SettingKeyApiBaseUrl,
|
||||
model.SettingKeyContactInfo,
|
||||
}
|
||||
|
||||
settings, err := s.settingRepo.GetMultiple(ctx, keys)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get public settings: %w", err)
|
||||
}
|
||||
|
||||
return &model.PublicSettings{
|
||||
RegistrationEnabled: settings[model.SettingKeyRegistrationEnabled] == "true",
|
||||
EmailVerifyEnabled: settings[model.SettingKeyEmailVerifyEnabled] == "true",
|
||||
TurnstileEnabled: settings[model.SettingKeyTurnstileEnabled] == "true",
|
||||
TurnstileSiteKey: settings[model.SettingKeyTurnstileSiteKey],
|
||||
SiteName: s.getStringOrDefault(settings, model.SettingKeySiteName, "Sub2API"),
|
||||
SiteLogo: settings[model.SettingKeySiteLogo],
|
||||
SiteSubtitle: s.getStringOrDefault(settings, model.SettingKeySiteSubtitle, "Subscription to API Conversion Platform"),
|
||||
ApiBaseUrl: settings[model.SettingKeyApiBaseUrl],
|
||||
ContactInfo: settings[model.SettingKeyContactInfo],
|
||||
}, nil
|
||||
}
|
||||
|
||||
// UpdateSettings 更新系统设置
|
||||
func (s *SettingService) UpdateSettings(ctx context.Context, settings *model.SystemSettings) error {
|
||||
updates := make(map[string]string)
|
||||
|
||||
// 注册设置
|
||||
updates[model.SettingKeyRegistrationEnabled] = strconv.FormatBool(settings.RegistrationEnabled)
|
||||
updates[model.SettingKeyEmailVerifyEnabled] = strconv.FormatBool(settings.EmailVerifyEnabled)
|
||||
|
||||
// 邮件服务设置(只有非空才更新密码)
|
||||
updates[model.SettingKeySmtpHost] = settings.SmtpHost
|
||||
updates[model.SettingKeySmtpPort] = strconv.Itoa(settings.SmtpPort)
|
||||
updates[model.SettingKeySmtpUsername] = settings.SmtpUsername
|
||||
if settings.SmtpPassword != "" {
|
||||
updates[model.SettingKeySmtpPassword] = settings.SmtpPassword
|
||||
}
|
||||
updates[model.SettingKeySmtpFrom] = settings.SmtpFrom
|
||||
updates[model.SettingKeySmtpFromName] = settings.SmtpFromName
|
||||
updates[model.SettingKeySmtpUseTLS] = strconv.FormatBool(settings.SmtpUseTLS)
|
||||
|
||||
// Cloudflare Turnstile 设置(只有非空才更新密钥)
|
||||
updates[model.SettingKeyTurnstileEnabled] = strconv.FormatBool(settings.TurnstileEnabled)
|
||||
updates[model.SettingKeyTurnstileSiteKey] = settings.TurnstileSiteKey
|
||||
if settings.TurnstileSecretKey != "" {
|
||||
updates[model.SettingKeyTurnstileSecretKey] = settings.TurnstileSecretKey
|
||||
}
|
||||
|
||||
// OEM设置
|
||||
updates[model.SettingKeySiteName] = settings.SiteName
|
||||
updates[model.SettingKeySiteLogo] = settings.SiteLogo
|
||||
updates[model.SettingKeySiteSubtitle] = settings.SiteSubtitle
|
||||
updates[model.SettingKeyApiBaseUrl] = settings.ApiBaseUrl
|
||||
updates[model.SettingKeyContactInfo] = settings.ContactInfo
|
||||
|
||||
// 默认配置
|
||||
updates[model.SettingKeyDefaultConcurrency] = strconv.Itoa(settings.DefaultConcurrency)
|
||||
updates[model.SettingKeyDefaultBalance] = strconv.FormatFloat(settings.DefaultBalance, 'f', 8, 64)
|
||||
|
||||
return s.settingRepo.SetMultiple(ctx, updates)
|
||||
}
|
||||
|
||||
// IsRegistrationEnabled 检查是否开放注册
|
||||
func (s *SettingService) IsRegistrationEnabled(ctx context.Context) bool {
|
||||
value, err := s.settingRepo.GetValue(ctx, model.SettingKeyRegistrationEnabled)
|
||||
if err != nil {
|
||||
// 默认开放注册
|
||||
return true
|
||||
}
|
||||
return value == "true"
|
||||
}
|
||||
|
||||
// IsEmailVerifyEnabled 检查是否开启邮件验证
|
||||
func (s *SettingService) IsEmailVerifyEnabled(ctx context.Context) bool {
|
||||
value, err := s.settingRepo.GetValue(ctx, model.SettingKeyEmailVerifyEnabled)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return value == "true"
|
||||
}
|
||||
|
||||
// GetSiteName 获取网站名称
|
||||
func (s *SettingService) GetSiteName(ctx context.Context) string {
|
||||
value, err := s.settingRepo.GetValue(ctx, model.SettingKeySiteName)
|
||||
if err != nil || value == "" {
|
||||
return "Sub2API"
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
// GetDefaultConcurrency 获取默认并发量
|
||||
func (s *SettingService) GetDefaultConcurrency(ctx context.Context) int {
|
||||
value, err := s.settingRepo.GetValue(ctx, model.SettingKeyDefaultConcurrency)
|
||||
if err != nil {
|
||||
return s.cfg.Default.UserConcurrency
|
||||
}
|
||||
if v, err := strconv.Atoi(value); err == nil && v > 0 {
|
||||
return v
|
||||
}
|
||||
return s.cfg.Default.UserConcurrency
|
||||
}
|
||||
|
||||
// GetDefaultBalance 获取默认余额
|
||||
func (s *SettingService) GetDefaultBalance(ctx context.Context) float64 {
|
||||
value, err := s.settingRepo.GetValue(ctx, model.SettingKeyDefaultBalance)
|
||||
if err != nil {
|
||||
return s.cfg.Default.UserBalance
|
||||
}
|
||||
if v, err := strconv.ParseFloat(value, 64); err == nil && v >= 0 {
|
||||
return v
|
||||
}
|
||||
return s.cfg.Default.UserBalance
|
||||
}
|
||||
|
||||
// InitializeDefaultSettings 初始化默认设置
|
||||
func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
|
||||
// 检查是否已有设置
|
||||
_, err := s.settingRepo.GetValue(ctx, model.SettingKeyRegistrationEnabled)
|
||||
if err == nil {
|
||||
// 已有设置,不需要初始化
|
||||
return nil
|
||||
}
|
||||
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return fmt.Errorf("check existing settings: %w", err)
|
||||
}
|
||||
|
||||
// 初始化默认设置
|
||||
defaults := map[string]string{
|
||||
model.SettingKeyRegistrationEnabled: "true",
|
||||
model.SettingKeyEmailVerifyEnabled: "false",
|
||||
model.SettingKeySiteName: "Sub2API",
|
||||
model.SettingKeySiteLogo: "",
|
||||
model.SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency),
|
||||
model.SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64),
|
||||
model.SettingKeySmtpPort: "587",
|
||||
model.SettingKeySmtpUseTLS: "false",
|
||||
}
|
||||
|
||||
return s.settingRepo.SetMultiple(ctx, defaults)
|
||||
}
|
||||
|
||||
// parseSettings 解析设置到结构体
|
||||
func (s *SettingService) parseSettings(settings map[string]string) *model.SystemSettings {
|
||||
result := &model.SystemSettings{
|
||||
RegistrationEnabled: settings[model.SettingKeyRegistrationEnabled] == "true",
|
||||
EmailVerifyEnabled: settings[model.SettingKeyEmailVerifyEnabled] == "true",
|
||||
SmtpHost: settings[model.SettingKeySmtpHost],
|
||||
SmtpUsername: settings[model.SettingKeySmtpUsername],
|
||||
SmtpFrom: settings[model.SettingKeySmtpFrom],
|
||||
SmtpFromName: settings[model.SettingKeySmtpFromName],
|
||||
SmtpUseTLS: settings[model.SettingKeySmtpUseTLS] == "true",
|
||||
TurnstileEnabled: settings[model.SettingKeyTurnstileEnabled] == "true",
|
||||
TurnstileSiteKey: settings[model.SettingKeyTurnstileSiteKey],
|
||||
SiteName: s.getStringOrDefault(settings, model.SettingKeySiteName, "Sub2API"),
|
||||
SiteLogo: settings[model.SettingKeySiteLogo],
|
||||
SiteSubtitle: s.getStringOrDefault(settings, model.SettingKeySiteSubtitle, "Subscription to API Conversion Platform"),
|
||||
ApiBaseUrl: settings[model.SettingKeyApiBaseUrl],
|
||||
ContactInfo: settings[model.SettingKeyContactInfo],
|
||||
}
|
||||
|
||||
// 解析整数类型
|
||||
if port, err := strconv.Atoi(settings[model.SettingKeySmtpPort]); err == nil {
|
||||
result.SmtpPort = port
|
||||
} else {
|
||||
result.SmtpPort = 587
|
||||
}
|
||||
|
||||
if concurrency, err := strconv.Atoi(settings[model.SettingKeyDefaultConcurrency]); err == nil {
|
||||
result.DefaultConcurrency = concurrency
|
||||
} else {
|
||||
result.DefaultConcurrency = s.cfg.Default.UserConcurrency
|
||||
}
|
||||
|
||||
// 解析浮点数类型
|
||||
if balance, err := strconv.ParseFloat(settings[model.SettingKeyDefaultBalance], 64); err == nil {
|
||||
result.DefaultBalance = balance
|
||||
} else {
|
||||
result.DefaultBalance = s.cfg.Default.UserBalance
|
||||
}
|
||||
|
||||
// 敏感信息直接返回,方便测试连接时使用
|
||||
result.SmtpPassword = settings[model.SettingKeySmtpPassword]
|
||||
result.TurnstileSecretKey = settings[model.SettingKeyTurnstileSecretKey]
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// getStringOrDefault 获取字符串值或默认值
|
||||
func (s *SettingService) getStringOrDefault(settings map[string]string, key, defaultValue string) string {
|
||||
if value, ok := settings[key]; ok && value != "" {
|
||||
return value
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
// IsTurnstileEnabled 检查是否启用 Turnstile 验证
|
||||
func (s *SettingService) IsTurnstileEnabled(ctx context.Context) bool {
|
||||
value, err := s.settingRepo.GetValue(ctx, model.SettingKeyTurnstileEnabled)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return value == "true"
|
||||
}
|
||||
|
||||
// GetTurnstileSecretKey 获取 Turnstile Secret Key
|
||||
func (s *SettingService) GetTurnstileSecretKey(ctx context.Context) string {
|
||||
value, err := s.settingRepo.GetValue(ctx, model.SettingKeyTurnstileSecretKey)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return value
|
||||
}
|
||||
575
backend/internal/service/subscription_service.go
Normal file
575
backend/internal/service/subscription_service.go
Normal file
@@ -0,0 +1,575 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/repository"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrSubscriptionNotFound = errors.New("subscription not found")
|
||||
ErrSubscriptionExpired = errors.New("subscription has expired")
|
||||
ErrSubscriptionSuspended = errors.New("subscription is suspended")
|
||||
ErrSubscriptionAlreadyExists = errors.New("subscription already exists for this user and group")
|
||||
ErrGroupNotSubscriptionType = errors.New("group is not a subscription type")
|
||||
ErrDailyLimitExceeded = errors.New("daily usage limit exceeded")
|
||||
ErrWeeklyLimitExceeded = errors.New("weekly usage limit exceeded")
|
||||
ErrMonthlyLimitExceeded = errors.New("monthly usage limit exceeded")
|
||||
)
|
||||
|
||||
// SubscriptionService 订阅服务
|
||||
type SubscriptionService struct {
|
||||
repos *repository.Repositories
|
||||
billingCacheService *BillingCacheService
|
||||
}
|
||||
|
||||
// NewSubscriptionService 创建订阅服务
|
||||
func NewSubscriptionService(repos *repository.Repositories) *SubscriptionService {
|
||||
return &SubscriptionService{repos: repos}
|
||||
}
|
||||
|
||||
// SetBillingCacheService 设置计费缓存服务(用于缓存失效)
|
||||
func (s *SubscriptionService) SetBillingCacheService(billingCacheService *BillingCacheService) {
|
||||
s.billingCacheService = billingCacheService
|
||||
}
|
||||
|
||||
// AssignSubscriptionInput 分配订阅输入
|
||||
type AssignSubscriptionInput struct {
|
||||
UserID int64
|
||||
GroupID int64
|
||||
ValidityDays int
|
||||
AssignedBy int64
|
||||
Notes string
|
||||
}
|
||||
|
||||
// AssignSubscription 分配订阅给用户(不允许重复分配)
|
||||
func (s *SubscriptionService) AssignSubscription(ctx context.Context, input *AssignSubscriptionInput) (*model.UserSubscription, error) {
|
||||
// 检查分组是否存在且为订阅类型
|
||||
group, err := s.repos.Group.GetByID(ctx, input.GroupID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("group not found: %w", err)
|
||||
}
|
||||
if !group.IsSubscriptionType() {
|
||||
return nil, ErrGroupNotSubscriptionType
|
||||
}
|
||||
|
||||
// 检查是否已存在订阅
|
||||
exists, err := s.repos.UserSubscription.ExistsByUserIDAndGroupID(ctx, input.UserID, input.GroupID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if exists {
|
||||
return nil, ErrSubscriptionAlreadyExists
|
||||
}
|
||||
|
||||
sub, err := s.createSubscription(ctx, input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 失效订阅缓存
|
||||
if s.billingCacheService != nil {
|
||||
userID, groupID := input.UserID, input.GroupID
|
||||
go func() {
|
||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
|
||||
}()
|
||||
}
|
||||
|
||||
return sub, nil
|
||||
}
|
||||
|
||||
// AssignOrExtendSubscription 分配或续期订阅(用于兑换码等场景)
|
||||
// 如果用户已有同分组的订阅:
|
||||
// - 未过期:从当前过期时间累加天数
|
||||
// - 已过期:从当前时间开始计算新的过期时间,并激活订阅
|
||||
// 如果没有订阅:创建新订阅
|
||||
func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, input *AssignSubscriptionInput) (*model.UserSubscription, bool, error) {
|
||||
// 检查分组是否存在且为订阅类型
|
||||
group, err := s.repos.Group.GetByID(ctx, input.GroupID)
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("group not found: %w", err)
|
||||
}
|
||||
if !group.IsSubscriptionType() {
|
||||
return nil, false, ErrGroupNotSubscriptionType
|
||||
}
|
||||
|
||||
// 查询是否已有订阅
|
||||
existingSub, err := s.repos.UserSubscription.GetByUserIDAndGroupID(ctx, input.UserID, input.GroupID)
|
||||
if err != nil {
|
||||
// 不存在记录是正常情况,其他错误需要返回
|
||||
existingSub = nil
|
||||
}
|
||||
|
||||
validityDays := input.ValidityDays
|
||||
if validityDays <= 0 {
|
||||
validityDays = 30
|
||||
}
|
||||
|
||||
// 已有订阅,执行续期
|
||||
if existingSub != nil {
|
||||
now := time.Now()
|
||||
var newExpiresAt time.Time
|
||||
|
||||
if existingSub.ExpiresAt.After(now) {
|
||||
// 未过期:从当前过期时间累加
|
||||
newExpiresAt = existingSub.ExpiresAt.AddDate(0, 0, validityDays)
|
||||
} else {
|
||||
// 已过期:从当前时间开始计算
|
||||
newExpiresAt = now.AddDate(0, 0, validityDays)
|
||||
}
|
||||
|
||||
// 更新过期时间
|
||||
if err := s.repos.UserSubscription.ExtendExpiry(ctx, existingSub.ID, newExpiresAt); err != nil {
|
||||
return nil, false, fmt.Errorf("extend subscription: %w", err)
|
||||
}
|
||||
|
||||
// 如果订阅已过期或被暂停,恢复为active状态
|
||||
if existingSub.Status != model.SubscriptionStatusActive {
|
||||
if err := s.repos.UserSubscription.UpdateStatus(ctx, existingSub.ID, model.SubscriptionStatusActive); err != nil {
|
||||
return nil, false, fmt.Errorf("update subscription status: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 追加备注
|
||||
if input.Notes != "" {
|
||||
newNotes := existingSub.Notes
|
||||
if newNotes != "" {
|
||||
newNotes += "\n"
|
||||
}
|
||||
newNotes += input.Notes
|
||||
if err := s.repos.UserSubscription.UpdateNotes(ctx, existingSub.ID, newNotes); err != nil {
|
||||
// 备注更新失败不影响主流程
|
||||
}
|
||||
}
|
||||
|
||||
// 失效订阅缓存
|
||||
if s.billingCacheService != nil {
|
||||
userID, groupID := input.UserID, input.GroupID
|
||||
go func() {
|
||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
|
||||
}()
|
||||
}
|
||||
|
||||
// 返回更新后的订阅
|
||||
sub, err := s.repos.UserSubscription.GetByID(ctx, existingSub.ID)
|
||||
return sub, true, err // true 表示是续期
|
||||
}
|
||||
|
||||
// 没有订阅,创建新订阅
|
||||
sub, err := s.createSubscription(ctx, input)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
// 失效订阅缓存
|
||||
if s.billingCacheService != nil {
|
||||
userID, groupID := input.UserID, input.GroupID
|
||||
go func() {
|
||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
|
||||
}()
|
||||
}
|
||||
|
||||
return sub, false, nil // false 表示是新建
|
||||
}
|
||||
|
||||
// createSubscription 创建新订阅(内部方法)
|
||||
func (s *SubscriptionService) createSubscription(ctx context.Context, input *AssignSubscriptionInput) (*model.UserSubscription, error) {
|
||||
validityDays := input.ValidityDays
|
||||
if validityDays <= 0 {
|
||||
validityDays = 30
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
sub := &model.UserSubscription{
|
||||
UserID: input.UserID,
|
||||
GroupID: input.GroupID,
|
||||
StartsAt: now,
|
||||
ExpiresAt: now.AddDate(0, 0, validityDays),
|
||||
Status: model.SubscriptionStatusActive,
|
||||
AssignedAt: now,
|
||||
Notes: input.Notes,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
// 只有当 AssignedBy > 0 时才设置(0 表示系统分配,如兑换码)
|
||||
if input.AssignedBy > 0 {
|
||||
sub.AssignedBy = &input.AssignedBy
|
||||
}
|
||||
|
||||
if err := s.repos.UserSubscription.Create(ctx, sub); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 重新获取完整订阅信息(包含关联)
|
||||
return s.repos.UserSubscription.GetByID(ctx, sub.ID)
|
||||
}
|
||||
|
||||
// BulkAssignSubscriptionInput 批量分配订阅输入
|
||||
type BulkAssignSubscriptionInput struct {
|
||||
UserIDs []int64
|
||||
GroupID int64
|
||||
ValidityDays int
|
||||
AssignedBy int64
|
||||
Notes string
|
||||
}
|
||||
|
||||
// BulkAssignResult 批量分配结果
|
||||
type BulkAssignResult struct {
|
||||
SuccessCount int
|
||||
FailedCount int
|
||||
Subscriptions []model.UserSubscription
|
||||
Errors []string
|
||||
}
|
||||
|
||||
// BulkAssignSubscription 批量分配订阅
|
||||
func (s *SubscriptionService) BulkAssignSubscription(ctx context.Context, input *BulkAssignSubscriptionInput) (*BulkAssignResult, error) {
|
||||
result := &BulkAssignResult{
|
||||
Subscriptions: make([]model.UserSubscription, 0),
|
||||
Errors: make([]string, 0),
|
||||
}
|
||||
|
||||
for _, userID := range input.UserIDs {
|
||||
sub, err := s.AssignSubscription(ctx, &AssignSubscriptionInput{
|
||||
UserID: userID,
|
||||
GroupID: input.GroupID,
|
||||
ValidityDays: input.ValidityDays,
|
||||
AssignedBy: input.AssignedBy,
|
||||
Notes: input.Notes,
|
||||
})
|
||||
if err != nil {
|
||||
result.FailedCount++
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("user %d: %v", userID, err))
|
||||
} else {
|
||||
result.SuccessCount++
|
||||
result.Subscriptions = append(result.Subscriptions, *sub)
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// RevokeSubscription 撤销订阅
|
||||
func (s *SubscriptionService) RevokeSubscription(ctx context.Context, subscriptionID int64) error {
|
||||
// 先获取订阅信息用于失效缓存
|
||||
sub, err := s.repos.UserSubscription.GetByID(ctx, subscriptionID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := s.repos.UserSubscription.Delete(ctx, subscriptionID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 失效订阅缓存
|
||||
if s.billingCacheService != nil {
|
||||
userID, groupID := sub.UserID, sub.GroupID
|
||||
go func() {
|
||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
|
||||
}()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ExtendSubscription 延长订阅
|
||||
func (s *SubscriptionService) ExtendSubscription(ctx context.Context, subscriptionID int64, days int) (*model.UserSubscription, error) {
|
||||
sub, err := s.repos.UserSubscription.GetByID(ctx, subscriptionID)
|
||||
if err != nil {
|
||||
return nil, ErrSubscriptionNotFound
|
||||
}
|
||||
|
||||
// 计算新的过期时间
|
||||
newExpiresAt := sub.ExpiresAt.AddDate(0, 0, days)
|
||||
if err := s.repos.UserSubscription.ExtendExpiry(ctx, subscriptionID, newExpiresAt); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 如果订阅已过期,恢复为active状态
|
||||
if sub.Status == model.SubscriptionStatusExpired {
|
||||
if err := s.repos.UserSubscription.UpdateStatus(ctx, subscriptionID, model.SubscriptionStatusActive); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// 失效订阅缓存
|
||||
if s.billingCacheService != nil {
|
||||
userID, groupID := sub.UserID, sub.GroupID
|
||||
go func() {
|
||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
|
||||
}()
|
||||
}
|
||||
|
||||
return s.repos.UserSubscription.GetByID(ctx, subscriptionID)
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取订阅
|
||||
func (s *SubscriptionService) GetByID(ctx context.Context, id int64) (*model.UserSubscription, error) {
|
||||
return s.repos.UserSubscription.GetByID(ctx, id)
|
||||
}
|
||||
|
||||
// GetActiveSubscription 获取用户对特定分组的有效订阅
|
||||
func (s *SubscriptionService) GetActiveSubscription(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error) {
|
||||
sub, err := s.repos.UserSubscription.GetActiveByUserIDAndGroupID(ctx, userID, groupID)
|
||||
if err != nil {
|
||||
return nil, ErrSubscriptionNotFound
|
||||
}
|
||||
return sub, nil
|
||||
}
|
||||
|
||||
// ListUserSubscriptions 获取用户的所有订阅
|
||||
func (s *SubscriptionService) ListUserSubscriptions(ctx context.Context, userID int64) ([]model.UserSubscription, error) {
|
||||
return s.repos.UserSubscription.ListByUserID(ctx, userID)
|
||||
}
|
||||
|
||||
// ListActiveUserSubscriptions 获取用户的所有有效订阅
|
||||
func (s *SubscriptionService) ListActiveUserSubscriptions(ctx context.Context, userID int64) ([]model.UserSubscription, error) {
|
||||
return s.repos.UserSubscription.ListActiveByUserID(ctx, userID)
|
||||
}
|
||||
|
||||
// ListGroupSubscriptions 获取分组的所有订阅
|
||||
func (s *SubscriptionService) ListGroupSubscriptions(ctx context.Context, groupID int64, page, pageSize int) ([]model.UserSubscription, *repository.PaginationResult, error) {
|
||||
params := repository.PaginationParams{Page: page, PageSize: pageSize}
|
||||
return s.repos.UserSubscription.ListByGroupID(ctx, groupID, params)
|
||||
}
|
||||
|
||||
// List 获取所有订阅(分页,支持筛选)
|
||||
func (s *SubscriptionService) List(ctx context.Context, page, pageSize int, userID, groupID *int64, status string) ([]model.UserSubscription, *repository.PaginationResult, error) {
|
||||
params := repository.PaginationParams{Page: page, PageSize: pageSize}
|
||||
return s.repos.UserSubscription.List(ctx, params, userID, groupID, status)
|
||||
}
|
||||
|
||||
// CheckAndActivateWindow 检查并激活窗口(首次使用时)
|
||||
func (s *SubscriptionService) CheckAndActivateWindow(ctx context.Context, sub *model.UserSubscription) error {
|
||||
if sub.IsWindowActivated() {
|
||||
return nil
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
return s.repos.UserSubscription.ActivateWindows(ctx, sub.ID, now)
|
||||
}
|
||||
|
||||
// CheckAndResetWindows 检查并重置过期的窗口
|
||||
func (s *SubscriptionService) CheckAndResetWindows(ctx context.Context, sub *model.UserSubscription) error {
|
||||
now := time.Now()
|
||||
|
||||
// 日窗口重置(24小时)
|
||||
if sub.NeedsDailyReset() {
|
||||
if err := s.repos.UserSubscription.ResetDailyUsage(ctx, sub.ID, now); err != nil {
|
||||
return err
|
||||
}
|
||||
sub.DailyWindowStart = &now
|
||||
sub.DailyUsageUSD = 0
|
||||
}
|
||||
|
||||
// 周窗口重置(7天)
|
||||
if sub.NeedsWeeklyReset() {
|
||||
if err := s.repos.UserSubscription.ResetWeeklyUsage(ctx, sub.ID, now); err != nil {
|
||||
return err
|
||||
}
|
||||
sub.WeeklyWindowStart = &now
|
||||
sub.WeeklyUsageUSD = 0
|
||||
}
|
||||
|
||||
// 月窗口重置(30天)
|
||||
if sub.NeedsMonthlyReset() {
|
||||
if err := s.repos.UserSubscription.ResetMonthlyUsage(ctx, sub.ID, now); err != nil {
|
||||
return err
|
||||
}
|
||||
sub.MonthlyWindowStart = &now
|
||||
sub.MonthlyUsageUSD = 0
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CheckUsageLimits 检查使用限额(返回错误如果超限)
|
||||
func (s *SubscriptionService) CheckUsageLimits(ctx context.Context, sub *model.UserSubscription, group *model.Group, additionalCost float64) error {
|
||||
if !sub.CheckDailyLimit(group, additionalCost) {
|
||||
return ErrDailyLimitExceeded
|
||||
}
|
||||
if !sub.CheckWeeklyLimit(group, additionalCost) {
|
||||
return ErrWeeklyLimitExceeded
|
||||
}
|
||||
if !sub.CheckMonthlyLimit(group, additionalCost) {
|
||||
return ErrMonthlyLimitExceeded
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RecordUsage 记录使用量到订阅
|
||||
func (s *SubscriptionService) RecordUsage(ctx context.Context, subscriptionID int64, costUSD float64) error {
|
||||
return s.repos.UserSubscription.IncrementUsage(ctx, subscriptionID, costUSD)
|
||||
}
|
||||
|
||||
// SubscriptionProgress 订阅进度
|
||||
type SubscriptionProgress struct {
|
||||
ID int64 `json:"id"`
|
||||
GroupName string `json:"group_name"`
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
ExpiresInDays int `json:"expires_in_days"`
|
||||
Daily *UsageWindowProgress `json:"daily,omitempty"`
|
||||
Weekly *UsageWindowProgress `json:"weekly,omitempty"`
|
||||
Monthly *UsageWindowProgress `json:"monthly,omitempty"`
|
||||
}
|
||||
|
||||
// UsageWindowProgress 使用窗口进度
|
||||
type UsageWindowProgress struct {
|
||||
LimitUSD float64 `json:"limit_usd"`
|
||||
UsedUSD float64 `json:"used_usd"`
|
||||
RemainingUSD float64 `json:"remaining_usd"`
|
||||
Percentage float64 `json:"percentage"`
|
||||
WindowStart time.Time `json:"window_start"`
|
||||
ResetsAt time.Time `json:"resets_at"`
|
||||
ResetsInSeconds int64 `json:"resets_in_seconds"`
|
||||
}
|
||||
|
||||
// GetSubscriptionProgress 获取订阅使用进度
|
||||
func (s *SubscriptionService) GetSubscriptionProgress(ctx context.Context, subscriptionID int64) (*SubscriptionProgress, error) {
|
||||
sub, err := s.repos.UserSubscription.GetByID(ctx, subscriptionID)
|
||||
if err != nil {
|
||||
return nil, ErrSubscriptionNotFound
|
||||
}
|
||||
|
||||
group := sub.Group
|
||||
if group == nil {
|
||||
group, err = s.repos.Group.GetByID(ctx, sub.GroupID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
progress := &SubscriptionProgress{
|
||||
ID: sub.ID,
|
||||
GroupName: group.Name,
|
||||
ExpiresAt: sub.ExpiresAt,
|
||||
ExpiresInDays: sub.DaysRemaining(),
|
||||
}
|
||||
|
||||
// 日进度
|
||||
if group.HasDailyLimit() && sub.DailyWindowStart != nil {
|
||||
limit := *group.DailyLimitUSD
|
||||
resetsAt := sub.DailyWindowStart.Add(24 * time.Hour)
|
||||
progress.Daily = &UsageWindowProgress{
|
||||
LimitUSD: limit,
|
||||
UsedUSD: sub.DailyUsageUSD,
|
||||
RemainingUSD: limit - sub.DailyUsageUSD,
|
||||
Percentage: (sub.DailyUsageUSD / limit) * 100,
|
||||
WindowStart: *sub.DailyWindowStart,
|
||||
ResetsAt: resetsAt,
|
||||
ResetsInSeconds: int64(time.Until(resetsAt).Seconds()),
|
||||
}
|
||||
if progress.Daily.RemainingUSD < 0 {
|
||||
progress.Daily.RemainingUSD = 0
|
||||
}
|
||||
if progress.Daily.Percentage > 100 {
|
||||
progress.Daily.Percentage = 100
|
||||
}
|
||||
if progress.Daily.ResetsInSeconds < 0 {
|
||||
progress.Daily.ResetsInSeconds = 0
|
||||
}
|
||||
}
|
||||
|
||||
// 周进度
|
||||
if group.HasWeeklyLimit() && sub.WeeklyWindowStart != nil {
|
||||
limit := *group.WeeklyLimitUSD
|
||||
resetsAt := sub.WeeklyWindowStart.Add(7 * 24 * time.Hour)
|
||||
progress.Weekly = &UsageWindowProgress{
|
||||
LimitUSD: limit,
|
||||
UsedUSD: sub.WeeklyUsageUSD,
|
||||
RemainingUSD: limit - sub.WeeklyUsageUSD,
|
||||
Percentage: (sub.WeeklyUsageUSD / limit) * 100,
|
||||
WindowStart: *sub.WeeklyWindowStart,
|
||||
ResetsAt: resetsAt,
|
||||
ResetsInSeconds: int64(time.Until(resetsAt).Seconds()),
|
||||
}
|
||||
if progress.Weekly.RemainingUSD < 0 {
|
||||
progress.Weekly.RemainingUSD = 0
|
||||
}
|
||||
if progress.Weekly.Percentage > 100 {
|
||||
progress.Weekly.Percentage = 100
|
||||
}
|
||||
if progress.Weekly.ResetsInSeconds < 0 {
|
||||
progress.Weekly.ResetsInSeconds = 0
|
||||
}
|
||||
}
|
||||
|
||||
// 月进度
|
||||
if group.HasMonthlyLimit() && sub.MonthlyWindowStart != nil {
|
||||
limit := *group.MonthlyLimitUSD
|
||||
resetsAt := sub.MonthlyWindowStart.Add(30 * 24 * time.Hour)
|
||||
progress.Monthly = &UsageWindowProgress{
|
||||
LimitUSD: limit,
|
||||
UsedUSD: sub.MonthlyUsageUSD,
|
||||
RemainingUSD: limit - sub.MonthlyUsageUSD,
|
||||
Percentage: (sub.MonthlyUsageUSD / limit) * 100,
|
||||
WindowStart: *sub.MonthlyWindowStart,
|
||||
ResetsAt: resetsAt,
|
||||
ResetsInSeconds: int64(time.Until(resetsAt).Seconds()),
|
||||
}
|
||||
if progress.Monthly.RemainingUSD < 0 {
|
||||
progress.Monthly.RemainingUSD = 0
|
||||
}
|
||||
if progress.Monthly.Percentage > 100 {
|
||||
progress.Monthly.Percentage = 100
|
||||
}
|
||||
if progress.Monthly.ResetsInSeconds < 0 {
|
||||
progress.Monthly.ResetsInSeconds = 0
|
||||
}
|
||||
}
|
||||
|
||||
return progress, nil
|
||||
}
|
||||
|
||||
// GetUserSubscriptionsWithProgress 获取用户所有订阅及进度
|
||||
func (s *SubscriptionService) GetUserSubscriptionsWithProgress(ctx context.Context, userID int64) ([]SubscriptionProgress, error) {
|
||||
subs, err := s.repos.UserSubscription.ListActiveByUserID(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
progresses := make([]SubscriptionProgress, 0, len(subs))
|
||||
for _, sub := range subs {
|
||||
progress, err := s.GetSubscriptionProgress(ctx, sub.ID)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
progresses = append(progresses, *progress)
|
||||
}
|
||||
|
||||
return progresses, nil
|
||||
}
|
||||
|
||||
// UpdateExpiredSubscriptions 更新过期订阅状态(定时任务调用)
|
||||
func (s *SubscriptionService) UpdateExpiredSubscriptions(ctx context.Context) (int64, error) {
|
||||
return s.repos.UserSubscription.BatchUpdateExpiredStatus(ctx)
|
||||
}
|
||||
|
||||
// ValidateSubscription 验证订阅是否有效
|
||||
func (s *SubscriptionService) ValidateSubscription(ctx context.Context, sub *model.UserSubscription) error {
|
||||
if sub.Status == model.SubscriptionStatusExpired {
|
||||
return ErrSubscriptionExpired
|
||||
}
|
||||
if sub.Status == model.SubscriptionStatusSuspended {
|
||||
return ErrSubscriptionSuspended
|
||||
}
|
||||
if sub.IsExpired() {
|
||||
// 更新状态
|
||||
_ = s.repos.UserSubscription.UpdateStatus(ctx, sub.ID, model.SubscriptionStatusExpired)
|
||||
return ErrSubscriptionExpired
|
||||
}
|
||||
return nil
|
||||
}
|
||||
111
backend/internal/service/turnstile_service.go
Normal file
111
backend/internal/service/turnstile_service.go
Normal file
@@ -0,0 +1,111 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrTurnstileVerificationFailed = errors.New("turnstile verification failed")
|
||||
ErrTurnstileNotConfigured = errors.New("turnstile not configured")
|
||||
)
|
||||
|
||||
const turnstileVerifyURL = "https://challenges.cloudflare.com/turnstile/v0/siteverify"
|
||||
|
||||
// TurnstileService Turnstile 验证服务
|
||||
type TurnstileService struct {
|
||||
settingService *SettingService
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// TurnstileVerifyResponse Cloudflare Turnstile 验证响应
|
||||
type TurnstileVerifyResponse struct {
|
||||
Success bool `json:"success"`
|
||||
ChallengeTS string `json:"challenge_ts"`
|
||||
Hostname string `json:"hostname"`
|
||||
ErrorCodes []string `json:"error-codes"`
|
||||
Action string `json:"action"`
|
||||
CData string `json:"cdata"`
|
||||
}
|
||||
|
||||
// NewTurnstileService 创建 Turnstile 服务实例
|
||||
func NewTurnstileService(settingService *SettingService) *TurnstileService {
|
||||
return &TurnstileService{
|
||||
settingService: settingService,
|
||||
httpClient: &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// VerifyToken 验证 Turnstile token
|
||||
func (s *TurnstileService) VerifyToken(ctx context.Context, token string, remoteIP string) error {
|
||||
// 检查是否启用 Turnstile
|
||||
if !s.settingService.IsTurnstileEnabled(ctx) {
|
||||
log.Println("[Turnstile] Disabled, skipping verification")
|
||||
return nil
|
||||
}
|
||||
|
||||
// 获取 Secret Key
|
||||
secretKey := s.settingService.GetTurnstileSecretKey(ctx)
|
||||
if secretKey == "" {
|
||||
log.Println("[Turnstile] Secret key not configured")
|
||||
return ErrTurnstileNotConfigured
|
||||
}
|
||||
|
||||
// 如果 token 为空,返回错误
|
||||
if token == "" {
|
||||
log.Println("[Turnstile] Token is empty")
|
||||
return ErrTurnstileVerificationFailed
|
||||
}
|
||||
|
||||
// 构建请求
|
||||
formData := url.Values{}
|
||||
formData.Set("secret", secretKey)
|
||||
formData.Set("response", token)
|
||||
if remoteIP != "" {
|
||||
formData.Set("remoteip", remoteIP)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, turnstileVerifyURL, strings.NewReader(formData.Encode()))
|
||||
if err != nil {
|
||||
return fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
// 发送请求
|
||||
log.Printf("[Turnstile] Verifying token for IP: %s", remoteIP)
|
||||
resp, err := s.httpClient.Do(req)
|
||||
if err != nil {
|
||||
log.Printf("[Turnstile] Request failed: %v", err)
|
||||
return fmt.Errorf("send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 解析响应
|
||||
var result TurnstileVerifyResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
log.Printf("[Turnstile] Failed to decode response: %v", err)
|
||||
return fmt.Errorf("decode response: %w", err)
|
||||
}
|
||||
|
||||
if !result.Success {
|
||||
log.Printf("[Turnstile] Verification failed, error codes: %v", result.ErrorCodes)
|
||||
return ErrTurnstileVerificationFailed
|
||||
}
|
||||
|
||||
log.Println("[Turnstile] Verification successful")
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsEnabled 检查 Turnstile 是否启用
|
||||
func (s *TurnstileService) IsEnabled(ctx context.Context) bool {
|
||||
return s.settingService.IsTurnstileEnabled(ctx)
|
||||
}
|
||||
621
backend/internal/service/update_service.go
Normal file
621
backend/internal/service/update_service.go
Normal file
@@ -0,0 +1,621 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"archive/tar"
|
||||
"bufio"
|
||||
"compress/gzip"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const (
|
||||
updateCacheKey = "update_check_cache"
|
||||
updateCacheTTL = 1200 // 20 minutes
|
||||
githubRepo = "Wei-Shaw/sub2api"
|
||||
|
||||
// Security: allowed download domains for updates
|
||||
allowedDownloadHost = "github.com"
|
||||
allowedAssetHost = "objects.githubusercontent.com"
|
||||
|
||||
// Security: max download size (500MB)
|
||||
maxDownloadSize = 500 * 1024 * 1024
|
||||
)
|
||||
|
||||
// UpdateService handles software updates
|
||||
type UpdateService struct {
|
||||
rdb *redis.Client
|
||||
currentVersion string
|
||||
buildType string // "source" for manual builds, "release" for CI builds
|
||||
}
|
||||
|
||||
// NewUpdateService creates a new UpdateService
|
||||
func NewUpdateService(rdb *redis.Client, version, buildType string) *UpdateService {
|
||||
return &UpdateService{
|
||||
rdb: rdb,
|
||||
currentVersion: version,
|
||||
buildType: buildType,
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateInfo contains update information
|
||||
type UpdateInfo struct {
|
||||
CurrentVersion string `json:"current_version"`
|
||||
LatestVersion string `json:"latest_version"`
|
||||
HasUpdate bool `json:"has_update"`
|
||||
ReleaseInfo *ReleaseInfo `json:"release_info,omitempty"`
|
||||
Cached bool `json:"cached"`
|
||||
Warning string `json:"warning,omitempty"`
|
||||
BuildType string `json:"build_type"` // "source" or "release"
|
||||
}
|
||||
|
||||
// ReleaseInfo contains GitHub release details
|
||||
type ReleaseInfo struct {
|
||||
Name string `json:"name"`
|
||||
Body string `json:"body"`
|
||||
PublishedAt string `json:"published_at"`
|
||||
HtmlURL string `json:"html_url"`
|
||||
Assets []Asset `json:"assets,omitempty"`
|
||||
}
|
||||
|
||||
// Asset represents a release asset
|
||||
type Asset struct {
|
||||
Name string `json:"name"`
|
||||
DownloadURL string `json:"download_url"`
|
||||
Size int64 `json:"size"`
|
||||
}
|
||||
|
||||
// GitHubRelease represents GitHub API response
|
||||
type GitHubRelease struct {
|
||||
TagName string `json:"tag_name"`
|
||||
Name string `json:"name"`
|
||||
Body string `json:"body"`
|
||||
PublishedAt string `json:"published_at"`
|
||||
HtmlUrl string `json:"html_url"`
|
||||
Assets []GitHubAsset `json:"assets"`
|
||||
}
|
||||
|
||||
type GitHubAsset struct {
|
||||
Name string `json:"name"`
|
||||
BrowserDownloadUrl string `json:"browser_download_url"`
|
||||
Size int64 `json:"size"`
|
||||
}
|
||||
|
||||
// CheckUpdate checks for available updates
|
||||
func (s *UpdateService) CheckUpdate(ctx context.Context, force bool) (*UpdateInfo, error) {
|
||||
// Try cache first
|
||||
if !force {
|
||||
if cached, err := s.getFromCache(ctx); err == nil && cached != nil {
|
||||
return cached, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Fetch from GitHub
|
||||
info, err := s.fetchLatestRelease(ctx)
|
||||
if err != nil {
|
||||
// Return cached on error
|
||||
if cached, cacheErr := s.getFromCache(ctx); cacheErr == nil && cached != nil {
|
||||
cached.Warning = "Using cached data: " + err.Error()
|
||||
return cached, nil
|
||||
}
|
||||
return &UpdateInfo{
|
||||
CurrentVersion: s.currentVersion,
|
||||
LatestVersion: s.currentVersion,
|
||||
HasUpdate: false,
|
||||
Warning: err.Error(),
|
||||
BuildType: s.buildType,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Cache result
|
||||
s.saveToCache(ctx, info)
|
||||
return info, nil
|
||||
}
|
||||
|
||||
// PerformUpdate downloads and applies the update
|
||||
func (s *UpdateService) PerformUpdate(ctx context.Context) error {
|
||||
info, err := s.CheckUpdate(ctx, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !info.HasUpdate {
|
||||
return fmt.Errorf("no update available")
|
||||
}
|
||||
|
||||
// Find matching archive and checksum for current platform
|
||||
archiveName := s.getArchiveName()
|
||||
var downloadURL string
|
||||
var checksumURL string
|
||||
|
||||
for _, asset := range info.ReleaseInfo.Assets {
|
||||
if strings.Contains(asset.Name, archiveName) && !strings.HasSuffix(asset.Name, ".txt") {
|
||||
downloadURL = asset.DownloadURL
|
||||
}
|
||||
if asset.Name == "checksums.txt" {
|
||||
checksumURL = asset.DownloadURL
|
||||
}
|
||||
}
|
||||
|
||||
if downloadURL == "" {
|
||||
return fmt.Errorf("no compatible release found for %s/%s", runtime.GOOS, runtime.GOARCH)
|
||||
}
|
||||
|
||||
// SECURITY: Validate download URL is from trusted domain
|
||||
if err := validateDownloadURL(downloadURL); err != nil {
|
||||
return fmt.Errorf("invalid download URL: %w", err)
|
||||
}
|
||||
if checksumURL != "" {
|
||||
if err := validateDownloadURL(checksumURL); err != nil {
|
||||
return fmt.Errorf("invalid checksum URL: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Get current executable path
|
||||
exePath, err := os.Executable()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get executable path: %w", err)
|
||||
}
|
||||
exePath, err = filepath.EvalSymlinks(exePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to resolve symlinks: %w", err)
|
||||
}
|
||||
|
||||
// Create temp directory for extraction
|
||||
tempDir, err := os.MkdirTemp("", "sub2api-update-*")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create temp dir: %w", err)
|
||||
}
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
// Download archive
|
||||
archivePath := filepath.Join(tempDir, filepath.Base(downloadURL))
|
||||
if err := s.downloadFile(ctx, downloadURL, archivePath); err != nil {
|
||||
return fmt.Errorf("download failed: %w", err)
|
||||
}
|
||||
|
||||
// Verify checksum if available
|
||||
if checksumURL != "" {
|
||||
if err := s.verifyChecksum(ctx, archivePath, checksumURL); err != nil {
|
||||
return fmt.Errorf("checksum verification failed: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Extract binary from archive
|
||||
newBinaryPath := filepath.Join(tempDir, "sub2api")
|
||||
if err := s.extractBinary(archivePath, newBinaryPath); err != nil {
|
||||
return fmt.Errorf("extraction failed: %w", err)
|
||||
}
|
||||
|
||||
// Backup current binary
|
||||
backupFile := exePath + ".backup"
|
||||
if err := os.Rename(exePath, backupFile); err != nil {
|
||||
return fmt.Errorf("backup failed: %w", err)
|
||||
}
|
||||
|
||||
// Replace with new binary
|
||||
if err := copyFile(newBinaryPath, exePath); err != nil {
|
||||
os.Rename(backupFile, exePath)
|
||||
return fmt.Errorf("replace failed: %w", err)
|
||||
}
|
||||
|
||||
// Make executable
|
||||
if err := os.Chmod(exePath, 0755); err != nil {
|
||||
return fmt.Errorf("chmod failed: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Rollback restores the previous version
|
||||
func (s *UpdateService) Rollback() error {
|
||||
exePath, err := os.Executable()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get executable path: %w", err)
|
||||
}
|
||||
exePath, err = filepath.EvalSymlinks(exePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to resolve symlinks: %w", err)
|
||||
}
|
||||
|
||||
backupFile := exePath + ".backup"
|
||||
if _, err := os.Stat(backupFile); os.IsNotExist(err) {
|
||||
return fmt.Errorf("no backup found")
|
||||
}
|
||||
|
||||
// Replace current with backup
|
||||
if err := os.Rename(backupFile, exePath); err != nil {
|
||||
return fmt.Errorf("rollback failed: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RestartService triggers a service restart via systemd
|
||||
func (s *UpdateService) RestartService() error {
|
||||
if runtime.GOOS != "linux" {
|
||||
return fmt.Errorf("systemd restart only available on Linux")
|
||||
}
|
||||
|
||||
// Try direct systemctl first (works if running as root or with proper permissions)
|
||||
cmd := exec.Command("systemctl", "restart", "sub2api")
|
||||
if err := cmd.Run(); err != nil {
|
||||
// Try with sudo (requires NOPASSWD sudoers entry)
|
||||
sudoCmd := exec.Command("sudo", "systemctl", "restart", "sub2api")
|
||||
if sudoErr := sudoCmd.Run(); sudoErr != nil {
|
||||
return fmt.Errorf("systemctl restart failed: %w (sudo also failed: %v)", err, sudoErr)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *UpdateService) fetchLatestRelease(ctx context.Context) (*UpdateInfo, error) {
|
||||
url := fmt.Sprintf("https://api.github.com/repos/%s/releases/latest", githubRepo)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Accept", "application/vnd.github.v3+json")
|
||||
req.Header.Set("User-Agent", "Sub2API-Updater")
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode == http.StatusNotFound {
|
||||
return &UpdateInfo{
|
||||
CurrentVersion: s.currentVersion,
|
||||
LatestVersion: s.currentVersion,
|
||||
HasUpdate: false,
|
||||
Warning: "No releases found",
|
||||
BuildType: s.buildType,
|
||||
}, nil
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("GitHub API returned %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var release GitHubRelease
|
||||
if err := json.NewDecoder(resp.Body).Decode(&release); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
latestVersion := strings.TrimPrefix(release.TagName, "v")
|
||||
|
||||
assets := make([]Asset, len(release.Assets))
|
||||
for i, a := range release.Assets {
|
||||
assets[i] = Asset{
|
||||
Name: a.Name,
|
||||
DownloadURL: a.BrowserDownloadUrl,
|
||||
Size: a.Size,
|
||||
}
|
||||
}
|
||||
|
||||
return &UpdateInfo{
|
||||
CurrentVersion: s.currentVersion,
|
||||
LatestVersion: latestVersion,
|
||||
HasUpdate: compareVersions(s.currentVersion, latestVersion) < 0,
|
||||
ReleaseInfo: &ReleaseInfo{
|
||||
Name: release.Name,
|
||||
Body: release.Body,
|
||||
PublishedAt: release.PublishedAt,
|
||||
HtmlURL: release.HtmlUrl,
|
||||
Assets: assets,
|
||||
},
|
||||
Cached: false,
|
||||
BuildType: s.buildType,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *UpdateService) downloadFile(ctx context.Context, downloadURL, dest string) error {
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", downloadURL, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
client := &http.Client{Timeout: 10 * time.Minute}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("download returned %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// SECURITY: Check Content-Length if available
|
||||
if resp.ContentLength > maxDownloadSize {
|
||||
return fmt.Errorf("file too large: %d bytes (max %d)", resp.ContentLength, maxDownloadSize)
|
||||
}
|
||||
|
||||
out, err := os.Create(dest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer out.Close()
|
||||
|
||||
// SECURITY: Use LimitReader to enforce max download size even if Content-Length is missing/wrong
|
||||
limited := io.LimitReader(resp.Body, maxDownloadSize+1)
|
||||
written, err := io.Copy(out, limited)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Check if we hit the limit (downloaded more than maxDownloadSize)
|
||||
if written > maxDownloadSize {
|
||||
os.Remove(dest) // Clean up partial file
|
||||
return fmt.Errorf("download exceeded maximum size of %d bytes", maxDownloadSize)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *UpdateService) getArchiveName() string {
|
||||
osName := runtime.GOOS
|
||||
arch := runtime.GOARCH
|
||||
return fmt.Sprintf("%s_%s", osName, arch)
|
||||
}
|
||||
|
||||
// validateDownloadURL checks if the URL is from an allowed domain
|
||||
// SECURITY: This prevents SSRF and ensures downloads only come from trusted GitHub domains
|
||||
func validateDownloadURL(rawURL string) error {
|
||||
parsedURL, err := url.Parse(rawURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid URL: %w", err)
|
||||
}
|
||||
|
||||
// Must be HTTPS
|
||||
if parsedURL.Scheme != "https" {
|
||||
return fmt.Errorf("only HTTPS URLs are allowed")
|
||||
}
|
||||
|
||||
// Check against allowed hosts
|
||||
host := parsedURL.Host
|
||||
// GitHub release URLs can be from github.com or objects.githubusercontent.com
|
||||
if host != allowedDownloadHost &&
|
||||
!strings.HasSuffix(host, "."+allowedDownloadHost) &&
|
||||
host != allowedAssetHost &&
|
||||
!strings.HasSuffix(host, "."+allowedAssetHost) {
|
||||
return fmt.Errorf("download from untrusted host: %s", host)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *UpdateService) verifyChecksum(ctx context.Context, filePath, checksumURL string) error {
|
||||
// Download checksums file
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", checksumURL, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("failed to download checksums: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// Calculate file hash
|
||||
f, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
h := sha256.New()
|
||||
if _, err := io.Copy(h, f); err != nil {
|
||||
return err
|
||||
}
|
||||
actualHash := hex.EncodeToString(h.Sum(nil))
|
||||
|
||||
// Find expected hash in checksums file
|
||||
fileName := filepath.Base(filePath)
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
parts := strings.Fields(line)
|
||||
if len(parts) == 2 && parts[1] == fileName {
|
||||
if parts[0] == actualHash {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("checksum mismatch: expected %s, got %s", parts[0], actualHash)
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("checksum not found for %s", fileName)
|
||||
}
|
||||
|
||||
func (s *UpdateService) extractBinary(archivePath, destPath string) error {
|
||||
f, err := os.Open(archivePath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
var reader io.Reader = f
|
||||
|
||||
// Handle gzip compression
|
||||
if strings.HasSuffix(archivePath, ".gz") || strings.HasSuffix(archivePath, ".tar.gz") || strings.HasSuffix(archivePath, ".tgz") {
|
||||
gzr, err := gzip.NewReader(f)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer gzr.Close()
|
||||
reader = gzr
|
||||
}
|
||||
|
||||
// Handle tar archive
|
||||
if strings.Contains(archivePath, ".tar") {
|
||||
tr := tar.NewReader(reader)
|
||||
for {
|
||||
hdr, err := tr.Next()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// SECURITY: Prevent Zip Slip / Path Traversal attack
|
||||
// Only allow files with safe base names, no directory traversal
|
||||
baseName := filepath.Base(hdr.Name)
|
||||
|
||||
// Check for path traversal attempts
|
||||
if strings.Contains(hdr.Name, "..") {
|
||||
return fmt.Errorf("path traversal attempt detected: %s", hdr.Name)
|
||||
}
|
||||
|
||||
// Validate the entry is a regular file
|
||||
if hdr.Typeflag != tar.TypeReg {
|
||||
continue // Skip directories and special files
|
||||
}
|
||||
|
||||
// Only extract the specific binary we need
|
||||
if baseName == "sub2api" || baseName == "sub2api.exe" {
|
||||
// Additional security: limit file size (max 500MB)
|
||||
const maxBinarySize = 500 * 1024 * 1024
|
||||
if hdr.Size > maxBinarySize {
|
||||
return fmt.Errorf("binary too large: %d bytes (max %d)", hdr.Size, maxBinarySize)
|
||||
}
|
||||
|
||||
out, err := os.Create(destPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Use LimitReader to prevent decompression bombs
|
||||
limited := io.LimitReader(tr, maxBinarySize)
|
||||
if _, err := io.Copy(out, limited); err != nil {
|
||||
out.Close()
|
||||
return err
|
||||
}
|
||||
out.Close()
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("binary not found in archive")
|
||||
}
|
||||
|
||||
// Direct copy for non-tar files (with size limit)
|
||||
const maxBinarySize = 500 * 1024 * 1024
|
||||
out, err := os.Create(destPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer out.Close()
|
||||
|
||||
limited := io.LimitReader(reader, maxBinarySize)
|
||||
_, err = io.Copy(out, limited)
|
||||
return err
|
||||
}
|
||||
|
||||
func copyFile(src, dst string) error {
|
||||
in, err := os.Open(src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer in.Close()
|
||||
|
||||
out, err := os.Create(dst)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer out.Close()
|
||||
|
||||
_, err = io.Copy(out, in)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *UpdateService) getFromCache(ctx context.Context) (*UpdateInfo, error) {
|
||||
data, err := s.rdb.Get(ctx, updateCacheKey).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var cached struct {
|
||||
Latest string `json:"latest"`
|
||||
ReleaseInfo *ReleaseInfo `json:"release_info"`
|
||||
Timestamp int64 `json:"timestamp"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(data), &cached); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if time.Now().Unix()-cached.Timestamp > updateCacheTTL {
|
||||
return nil, fmt.Errorf("cache expired")
|
||||
}
|
||||
|
||||
return &UpdateInfo{
|
||||
CurrentVersion: s.currentVersion,
|
||||
LatestVersion: cached.Latest,
|
||||
HasUpdate: compareVersions(s.currentVersion, cached.Latest) < 0,
|
||||
ReleaseInfo: cached.ReleaseInfo,
|
||||
Cached: true,
|
||||
BuildType: s.buildType,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *UpdateService) saveToCache(ctx context.Context, info *UpdateInfo) {
|
||||
cacheData := struct {
|
||||
Latest string `json:"latest"`
|
||||
ReleaseInfo *ReleaseInfo `json:"release_info"`
|
||||
Timestamp int64 `json:"timestamp"`
|
||||
}{
|
||||
Latest: info.LatestVersion,
|
||||
ReleaseInfo: info.ReleaseInfo,
|
||||
Timestamp: time.Now().Unix(),
|
||||
}
|
||||
|
||||
data, _ := json.Marshal(cacheData)
|
||||
s.rdb.Set(ctx, updateCacheKey, data, time.Duration(updateCacheTTL)*time.Second)
|
||||
}
|
||||
|
||||
// compareVersions compares two semantic versions
|
||||
func compareVersions(current, latest string) int {
|
||||
currentParts := parseVersion(current)
|
||||
latestParts := parseVersion(latest)
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
if currentParts[i] < latestParts[i] {
|
||||
return -1
|
||||
}
|
||||
if currentParts[i] > latestParts[i] {
|
||||
return 1
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func parseVersion(v string) [3]int {
|
||||
v = strings.TrimPrefix(v, "v")
|
||||
parts := strings.Split(v, ".")
|
||||
result := [3]int{0, 0, 0}
|
||||
for i := 0; i < len(parts) && i < 3; i++ {
|
||||
fmt.Sscanf(parts[i], "%d", &result[i])
|
||||
}
|
||||
return result
|
||||
}
|
||||
283
backend/internal/service/usage_service.go
Normal file
283
backend/internal/service/usage_service.go
Normal file
@@ -0,0 +1,283 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/repository"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrUsageLogNotFound = errors.New("usage log not found")
|
||||
)
|
||||
|
||||
// CreateUsageLogRequest 创建使用日志请求
|
||||
type CreateUsageLogRequest struct {
|
||||
UserID int64 `json:"user_id"`
|
||||
ApiKeyID int64 `json:"api_key_id"`
|
||||
AccountID int64 `json:"account_id"`
|
||||
RequestID string `json:"request_id"`
|
||||
Model string `json:"model"`
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
CacheCreationTokens int `json:"cache_creation_tokens"`
|
||||
CacheReadTokens int `json:"cache_read_tokens"`
|
||||
CacheCreation5mTokens int `json:"cache_creation_5m_tokens"`
|
||||
CacheCreation1hTokens int `json:"cache_creation_1h_tokens"`
|
||||
InputCost float64 `json:"input_cost"`
|
||||
OutputCost float64 `json:"output_cost"`
|
||||
CacheCreationCost float64 `json:"cache_creation_cost"`
|
||||
CacheReadCost float64 `json:"cache_read_cost"`
|
||||
TotalCost float64 `json:"total_cost"`
|
||||
ActualCost float64 `json:"actual_cost"`
|
||||
RateMultiplier float64 `json:"rate_multiplier"`
|
||||
Stream bool `json:"stream"`
|
||||
DurationMs *int `json:"duration_ms"`
|
||||
}
|
||||
|
||||
// UsageStats 使用统计
|
||||
type UsageStats struct {
|
||||
TotalRequests int64 `json:"total_requests"`
|
||||
TotalInputTokens int64 `json:"total_input_tokens"`
|
||||
TotalOutputTokens int64 `json:"total_output_tokens"`
|
||||
TotalCacheTokens int64 `json:"total_cache_tokens"`
|
||||
TotalTokens int64 `json:"total_tokens"`
|
||||
TotalCost float64 `json:"total_cost"`
|
||||
TotalActualCost float64 `json:"total_actual_cost"`
|
||||
AverageDurationMs float64 `json:"average_duration_ms"`
|
||||
}
|
||||
|
||||
// UsageService 使用统计服务
|
||||
type UsageService struct {
|
||||
usageRepo *repository.UsageLogRepository
|
||||
userRepo *repository.UserRepository
|
||||
}
|
||||
|
||||
// NewUsageService 创建使用统计服务实例
|
||||
func NewUsageService(usageRepo *repository.UsageLogRepository, userRepo *repository.UserRepository) *UsageService {
|
||||
return &UsageService{
|
||||
usageRepo: usageRepo,
|
||||
userRepo: userRepo,
|
||||
}
|
||||
}
|
||||
|
||||
// Create 创建使用日志
|
||||
func (s *UsageService) Create(ctx context.Context, req CreateUsageLogRequest) (*model.UsageLog, error) {
|
||||
// 验证用户存在
|
||||
_, err := s.userRepo.GetByID(ctx, req.UserID)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrUserNotFound
|
||||
}
|
||||
return nil, fmt.Errorf("get user: %w", err)
|
||||
}
|
||||
|
||||
// 创建使用日志
|
||||
usageLog := &model.UsageLog{
|
||||
UserID: req.UserID,
|
||||
ApiKeyID: req.ApiKeyID,
|
||||
AccountID: req.AccountID,
|
||||
RequestID: req.RequestID,
|
||||
Model: req.Model,
|
||||
InputTokens: req.InputTokens,
|
||||
OutputTokens: req.OutputTokens,
|
||||
CacheCreationTokens: req.CacheCreationTokens,
|
||||
CacheReadTokens: req.CacheReadTokens,
|
||||
CacheCreation5mTokens: req.CacheCreation5mTokens,
|
||||
CacheCreation1hTokens: req.CacheCreation1hTokens,
|
||||
InputCost: req.InputCost,
|
||||
OutputCost: req.OutputCost,
|
||||
CacheCreationCost: req.CacheCreationCost,
|
||||
CacheReadCost: req.CacheReadCost,
|
||||
TotalCost: req.TotalCost,
|
||||
ActualCost: req.ActualCost,
|
||||
RateMultiplier: req.RateMultiplier,
|
||||
Stream: req.Stream,
|
||||
DurationMs: req.DurationMs,
|
||||
}
|
||||
|
||||
if err := s.usageRepo.Create(ctx, usageLog); err != nil {
|
||||
return nil, fmt.Errorf("create usage log: %w", err)
|
||||
}
|
||||
|
||||
// 扣除用户余额
|
||||
if req.ActualCost > 0 {
|
||||
if err := s.userRepo.UpdateBalance(ctx, req.UserID, -req.ActualCost); err != nil {
|
||||
return nil, fmt.Errorf("update user balance: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return usageLog, nil
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取使用日志
|
||||
func (s *UsageService) GetByID(ctx context.Context, id int64) (*model.UsageLog, error) {
|
||||
log, err := s.usageRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrUsageLogNotFound
|
||||
}
|
||||
return nil, fmt.Errorf("get usage log: %w", err)
|
||||
}
|
||||
return log, nil
|
||||
}
|
||||
|
||||
// ListByUser 获取用户的使用日志列表
|
||||
func (s *UsageService) ListByUser(ctx context.Context, userID int64, params repository.PaginationParams) ([]model.UsageLog, *repository.PaginationResult, error) {
|
||||
logs, pagination, err := s.usageRepo.ListByUser(ctx, userID, params)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("list usage logs: %w", err)
|
||||
}
|
||||
return logs, pagination, nil
|
||||
}
|
||||
|
||||
// ListByApiKey 获取API Key的使用日志列表
|
||||
func (s *UsageService) ListByApiKey(ctx context.Context, apiKeyID int64, params repository.PaginationParams) ([]model.UsageLog, *repository.PaginationResult, error) {
|
||||
logs, pagination, err := s.usageRepo.ListByApiKey(ctx, apiKeyID, params)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("list usage logs: %w", err)
|
||||
}
|
||||
return logs, pagination, nil
|
||||
}
|
||||
|
||||
// ListByAccount 获取账号的使用日志列表
|
||||
func (s *UsageService) ListByAccount(ctx context.Context, accountID int64, params repository.PaginationParams) ([]model.UsageLog, *repository.PaginationResult, error) {
|
||||
logs, pagination, err := s.usageRepo.ListByAccount(ctx, accountID, params)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("list usage logs: %w", err)
|
||||
}
|
||||
return logs, pagination, nil
|
||||
}
|
||||
|
||||
// GetStatsByUser 获取用户的使用统计
|
||||
func (s *UsageService) GetStatsByUser(ctx context.Context, userID int64, startTime, endTime time.Time) (*UsageStats, error) {
|
||||
logs, _, err := s.usageRepo.ListByUserAndTimeRange(ctx, userID, startTime, endTime)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list usage logs: %w", err)
|
||||
}
|
||||
|
||||
return s.calculateStats(logs), nil
|
||||
}
|
||||
|
||||
// GetStatsByApiKey 获取API Key的使用统计
|
||||
func (s *UsageService) GetStatsByApiKey(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*UsageStats, error) {
|
||||
logs, _, err := s.usageRepo.ListByApiKeyAndTimeRange(ctx, apiKeyID, startTime, endTime)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list usage logs: %w", err)
|
||||
}
|
||||
|
||||
return s.calculateStats(logs), nil
|
||||
}
|
||||
|
||||
// GetStatsByAccount 获取账号的使用统计
|
||||
func (s *UsageService) GetStatsByAccount(ctx context.Context, accountID int64, startTime, endTime time.Time) (*UsageStats, error) {
|
||||
logs, _, err := s.usageRepo.ListByAccountAndTimeRange(ctx, accountID, startTime, endTime)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list usage logs: %w", err)
|
||||
}
|
||||
|
||||
return s.calculateStats(logs), nil
|
||||
}
|
||||
|
||||
// GetStatsByModel 获取模型的使用统计
|
||||
func (s *UsageService) GetStatsByModel(ctx context.Context, modelName string, startTime, endTime time.Time) (*UsageStats, error) {
|
||||
logs, _, err := s.usageRepo.ListByModelAndTimeRange(ctx, modelName, startTime, endTime)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list usage logs: %w", err)
|
||||
}
|
||||
|
||||
return s.calculateStats(logs), nil
|
||||
}
|
||||
|
||||
// GetDailyStats 获取每日使用统计(最近N天)
|
||||
func (s *UsageService) GetDailyStats(ctx context.Context, userID int64, days int) ([]map[string]interface{}, error) {
|
||||
endTime := time.Now()
|
||||
startTime := endTime.AddDate(0, 0, -days)
|
||||
|
||||
logs, _, err := s.usageRepo.ListByUserAndTimeRange(ctx, userID, startTime, endTime)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list usage logs: %w", err)
|
||||
}
|
||||
|
||||
// 按日期分组统计
|
||||
dailyStats := make(map[string]*UsageStats)
|
||||
for _, log := range logs {
|
||||
dateKey := log.CreatedAt.Format("2006-01-02")
|
||||
if _, exists := dailyStats[dateKey]; !exists {
|
||||
dailyStats[dateKey] = &UsageStats{}
|
||||
}
|
||||
|
||||
stats := dailyStats[dateKey]
|
||||
stats.TotalRequests++
|
||||
stats.TotalInputTokens += int64(log.InputTokens)
|
||||
stats.TotalOutputTokens += int64(log.OutputTokens)
|
||||
stats.TotalCacheTokens += int64(log.CacheCreationTokens + log.CacheReadTokens)
|
||||
stats.TotalTokens += int64(log.TotalTokens())
|
||||
stats.TotalCost += log.TotalCost
|
||||
stats.TotalActualCost += log.ActualCost
|
||||
|
||||
if log.DurationMs != nil {
|
||||
stats.AverageDurationMs += float64(*log.DurationMs)
|
||||
}
|
||||
}
|
||||
|
||||
// 计算平均值并转换为数组
|
||||
result := make([]map[string]interface{}, 0, len(dailyStats))
|
||||
for date, stats := range dailyStats {
|
||||
if stats.TotalRequests > 0 {
|
||||
stats.AverageDurationMs /= float64(stats.TotalRequests)
|
||||
}
|
||||
|
||||
result = append(result, map[string]interface{}{
|
||||
"date": date,
|
||||
"total_requests": stats.TotalRequests,
|
||||
"total_input_tokens": stats.TotalInputTokens,
|
||||
"total_output_tokens": stats.TotalOutputTokens,
|
||||
"total_cache_tokens": stats.TotalCacheTokens,
|
||||
"total_tokens": stats.TotalTokens,
|
||||
"total_cost": stats.TotalCost,
|
||||
"total_actual_cost": stats.TotalActualCost,
|
||||
"average_duration_ms": stats.AverageDurationMs,
|
||||
})
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// calculateStats 计算统计数据
|
||||
func (s *UsageService) calculateStats(logs []model.UsageLog) *UsageStats {
|
||||
stats := &UsageStats{}
|
||||
|
||||
for _, log := range logs {
|
||||
stats.TotalRequests++
|
||||
stats.TotalInputTokens += int64(log.InputTokens)
|
||||
stats.TotalOutputTokens += int64(log.OutputTokens)
|
||||
stats.TotalCacheTokens += int64(log.CacheCreationTokens + log.CacheReadTokens)
|
||||
stats.TotalTokens += int64(log.TotalTokens())
|
||||
stats.TotalCost += log.TotalCost
|
||||
stats.TotalActualCost += log.ActualCost
|
||||
|
||||
if log.DurationMs != nil {
|
||||
stats.AverageDurationMs += float64(*log.DurationMs)
|
||||
}
|
||||
}
|
||||
|
||||
// 计算平均持续时间
|
||||
if stats.TotalRequests > 0 {
|
||||
stats.AverageDurationMs /= float64(stats.TotalRequests)
|
||||
}
|
||||
|
||||
return stats
|
||||
}
|
||||
|
||||
// Delete 删除使用日志(管理员功能,谨慎使用)
|
||||
func (s *UsageService) Delete(ctx context.Context, id int64) error {
|
||||
if err := s.usageRepo.Delete(ctx, id); err != nil {
|
||||
return fmt.Errorf("delete usage log: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
177
backend/internal/service/user_service.go
Normal file
177
backend/internal/service/user_service.go
Normal file
@@ -0,0 +1,177 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sub2api/internal/config"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/repository"
|
||||
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrUserNotFound = errors.New("user not found")
|
||||
ErrPasswordIncorrect = errors.New("current password is incorrect")
|
||||
ErrInsufficientPerms = errors.New("insufficient permissions")
|
||||
)
|
||||
|
||||
// UpdateProfileRequest 更新用户资料请求
|
||||
type UpdateProfileRequest struct {
|
||||
Email *string `json:"email"`
|
||||
Concurrency *int `json:"concurrency"`
|
||||
}
|
||||
|
||||
// ChangePasswordRequest 修改密码请求
|
||||
type ChangePasswordRequest struct {
|
||||
CurrentPassword string `json:"current_password"`
|
||||
NewPassword string `json:"new_password"`
|
||||
}
|
||||
|
||||
// UserService 用户服务
|
||||
type UserService struct {
|
||||
userRepo *repository.UserRepository
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
// NewUserService 创建用户服务实例
|
||||
func NewUserService(userRepo *repository.UserRepository, cfg *config.Config) *UserService {
|
||||
return &UserService{
|
||||
userRepo: userRepo,
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
// GetProfile 获取用户资料
|
||||
func (s *UserService) GetProfile(ctx context.Context, userID int64) (*model.User, error) {
|
||||
user, err := s.userRepo.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrUserNotFound
|
||||
}
|
||||
return nil, fmt.Errorf("get user: %w", err)
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// UpdateProfile 更新用户资料
|
||||
func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req UpdateProfileRequest) (*model.User, error) {
|
||||
user, err := s.userRepo.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrUserNotFound
|
||||
}
|
||||
return nil, fmt.Errorf("get user: %w", err)
|
||||
}
|
||||
|
||||
// 更新字段
|
||||
if req.Email != nil {
|
||||
// 检查新邮箱是否已被使用
|
||||
exists, err := s.userRepo.ExistsByEmail(ctx, *req.Email)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("check email exists: %w", err)
|
||||
}
|
||||
if exists && *req.Email != user.Email {
|
||||
return nil, ErrEmailExists
|
||||
}
|
||||
user.Email = *req.Email
|
||||
}
|
||||
|
||||
if req.Concurrency != nil {
|
||||
user.Concurrency = *req.Concurrency
|
||||
}
|
||||
|
||||
if err := s.userRepo.Update(ctx, user); err != nil {
|
||||
return nil, fmt.Errorf("update user: %w", err)
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// ChangePassword 修改密码
|
||||
func (s *UserService) ChangePassword(ctx context.Context, userID int64, req ChangePasswordRequest) error {
|
||||
user, err := s.userRepo.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return ErrUserNotFound
|
||||
}
|
||||
return fmt.Errorf("get user: %w", err)
|
||||
}
|
||||
|
||||
// 验证当前密码
|
||||
if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(req.CurrentPassword)); err != nil {
|
||||
return ErrPasswordIncorrect
|
||||
}
|
||||
|
||||
// 生成新密码哈希
|
||||
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(req.NewPassword), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return fmt.Errorf("hash password: %w", err)
|
||||
}
|
||||
|
||||
user.PasswordHash = string(hashedPassword)
|
||||
|
||||
if err := s.userRepo.Update(ctx, user); err != nil {
|
||||
return fmt.Errorf("update user: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取用户(管理员功能)
|
||||
func (s *UserService) GetByID(ctx context.Context, id int64) (*model.User, error) {
|
||||
user, err := s.userRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrUserNotFound
|
||||
}
|
||||
return nil, fmt.Errorf("get user: %w", err)
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// List 获取用户列表(管理员功能)
|
||||
func (s *UserService) List(ctx context.Context, params repository.PaginationParams) ([]model.User, *repository.PaginationResult, error) {
|
||||
users, pagination, err := s.userRepo.List(ctx, params)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("list users: %w", err)
|
||||
}
|
||||
return users, pagination, nil
|
||||
}
|
||||
|
||||
// UpdateBalance 更新用户余额(管理员功能)
|
||||
func (s *UserService) UpdateBalance(ctx context.Context, userID int64, amount float64) error {
|
||||
if err := s.userRepo.UpdateBalance(ctx, userID, amount); err != nil {
|
||||
return fmt.Errorf("update balance: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateStatus 更新用户状态(管理员功能)
|
||||
func (s *UserService) UpdateStatus(ctx context.Context, userID int64, status string) error {
|
||||
user, err := s.userRepo.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return ErrUserNotFound
|
||||
}
|
||||
return fmt.Errorf("get user: %w", err)
|
||||
}
|
||||
|
||||
user.Status = status
|
||||
|
||||
if err := s.userRepo.Update(ctx, user); err != nil {
|
||||
return fmt.Errorf("update user: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete 删除用户(管理员功能)
|
||||
func (s *UserService) Delete(ctx context.Context, userID int64) error {
|
||||
if err := s.userRepo.Delete(ctx, userID); err != nil {
|
||||
return fmt.Errorf("delete user: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
294
backend/internal/setup/cli.go
Normal file
294
backend/internal/setup/cli.go
Normal file
@@ -0,0 +1,294 @@
|
||||
package setup
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"net/mail"
|
||||
"os"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/term"
|
||||
)
|
||||
|
||||
// CLI input validation functions (matching Web API validation)
|
||||
func cliValidateHostname(host string) bool {
|
||||
validHost := regexp.MustCompile(`^[a-zA-Z0-9.\-:]+$`)
|
||||
return validHost.MatchString(host) && len(host) <= 253
|
||||
}
|
||||
|
||||
func cliValidateDBName(name string) bool {
|
||||
validName := regexp.MustCompile(`^[a-zA-Z][a-zA-Z0-9_]*$`)
|
||||
return validName.MatchString(name) && len(name) <= 63
|
||||
}
|
||||
|
||||
func cliValidateUsername(name string) bool {
|
||||
validName := regexp.MustCompile(`^[a-zA-Z0-9_]+$`)
|
||||
return validName.MatchString(name) && len(name) <= 63
|
||||
}
|
||||
|
||||
func cliValidateEmail(email string) bool {
|
||||
_, err := mail.ParseAddress(email)
|
||||
return err == nil && len(email) <= 254
|
||||
}
|
||||
|
||||
func cliValidatePort(port int) bool {
|
||||
return port > 0 && port <= 65535
|
||||
}
|
||||
|
||||
func cliValidateSSLMode(mode string) bool {
|
||||
validModes := map[string]bool{
|
||||
"disable": true, "require": true, "verify-ca": true, "verify-full": true,
|
||||
}
|
||||
return validModes[mode]
|
||||
}
|
||||
|
||||
// RunCLI runs the CLI setup wizard
|
||||
func RunCLI() error {
|
||||
reader := bufio.NewReader(os.Stdin)
|
||||
|
||||
fmt.Println()
|
||||
fmt.Println("╔═══════════════════════════════════════════╗")
|
||||
fmt.Println("║ Sub2API Installation Wizard ║")
|
||||
fmt.Println("╚═══════════════════════════════════════════╝")
|
||||
fmt.Println()
|
||||
|
||||
cfg := &SetupConfig{
|
||||
Server: ServerConfig{
|
||||
Host: "0.0.0.0",
|
||||
Port: 8080,
|
||||
Mode: "release",
|
||||
},
|
||||
JWT: JWTConfig{
|
||||
ExpireHour: 24,
|
||||
},
|
||||
}
|
||||
|
||||
// Database configuration with validation
|
||||
fmt.Println("── Database Configuration ──")
|
||||
|
||||
for {
|
||||
cfg.Database.Host = promptString(reader, "PostgreSQL Host", "localhost")
|
||||
if cliValidateHostname(cfg.Database.Host) {
|
||||
break
|
||||
}
|
||||
fmt.Println(" Invalid hostname format. Use alphanumeric, dots, hyphens only.")
|
||||
}
|
||||
|
||||
for {
|
||||
cfg.Database.Port = promptInt(reader, "PostgreSQL Port", 5432)
|
||||
if cliValidatePort(cfg.Database.Port) {
|
||||
break
|
||||
}
|
||||
fmt.Println(" Invalid port. Must be between 1 and 65535.")
|
||||
}
|
||||
|
||||
for {
|
||||
cfg.Database.User = promptString(reader, "PostgreSQL User", "postgres")
|
||||
if cliValidateUsername(cfg.Database.User) {
|
||||
break
|
||||
}
|
||||
fmt.Println(" Invalid username. Use alphanumeric and underscores only.")
|
||||
}
|
||||
|
||||
cfg.Database.Password = promptPassword("PostgreSQL Password")
|
||||
|
||||
for {
|
||||
cfg.Database.DBName = promptString(reader, "Database Name", "sub2api")
|
||||
if cliValidateDBName(cfg.Database.DBName) {
|
||||
break
|
||||
}
|
||||
fmt.Println(" Invalid database name. Start with letter, use alphanumeric and underscores.")
|
||||
}
|
||||
|
||||
for {
|
||||
cfg.Database.SSLMode = promptString(reader, "SSL Mode", "disable")
|
||||
if cliValidateSSLMode(cfg.Database.SSLMode) {
|
||||
break
|
||||
}
|
||||
fmt.Println(" Invalid SSL mode. Use: disable, require, verify-ca, or verify-full.")
|
||||
}
|
||||
|
||||
fmt.Println()
|
||||
fmt.Print("Testing database connection... ")
|
||||
if err := TestDatabaseConnection(&cfg.Database); err != nil {
|
||||
fmt.Println("FAILED")
|
||||
return fmt.Errorf("database connection failed: %w", err)
|
||||
}
|
||||
fmt.Println("OK")
|
||||
|
||||
// Redis configuration with validation
|
||||
fmt.Println()
|
||||
fmt.Println("── Redis Configuration ──")
|
||||
|
||||
for {
|
||||
cfg.Redis.Host = promptString(reader, "Redis Host", "localhost")
|
||||
if cliValidateHostname(cfg.Redis.Host) {
|
||||
break
|
||||
}
|
||||
fmt.Println(" Invalid hostname format. Use alphanumeric, dots, hyphens only.")
|
||||
}
|
||||
|
||||
for {
|
||||
cfg.Redis.Port = promptInt(reader, "Redis Port", 6379)
|
||||
if cliValidatePort(cfg.Redis.Port) {
|
||||
break
|
||||
}
|
||||
fmt.Println(" Invalid port. Must be between 1 and 65535.")
|
||||
}
|
||||
|
||||
cfg.Redis.Password = promptPassword("Redis Password (optional)")
|
||||
|
||||
for {
|
||||
cfg.Redis.DB = promptInt(reader, "Redis DB", 0)
|
||||
if cfg.Redis.DB >= 0 && cfg.Redis.DB <= 15 {
|
||||
break
|
||||
}
|
||||
fmt.Println(" Invalid Redis DB. Must be between 0 and 15.")
|
||||
}
|
||||
|
||||
fmt.Println()
|
||||
fmt.Print("Testing Redis connection... ")
|
||||
if err := TestRedisConnection(&cfg.Redis); err != nil {
|
||||
fmt.Println("FAILED")
|
||||
return fmt.Errorf("redis connection failed: %w", err)
|
||||
}
|
||||
fmt.Println("OK")
|
||||
|
||||
// Admin configuration with validation
|
||||
fmt.Println()
|
||||
fmt.Println("── Admin Account ──")
|
||||
|
||||
for {
|
||||
cfg.Admin.Email = promptString(reader, "Admin Email", "admin@example.com")
|
||||
if cliValidateEmail(cfg.Admin.Email) {
|
||||
break
|
||||
}
|
||||
fmt.Println(" Invalid email format.")
|
||||
}
|
||||
|
||||
for {
|
||||
cfg.Admin.Password = promptPassword("Admin Password")
|
||||
// SECURITY: Match Web API requirement of 8 characters minimum
|
||||
if len(cfg.Admin.Password) < 8 {
|
||||
fmt.Println(" Password must be at least 8 characters")
|
||||
continue
|
||||
}
|
||||
if len(cfg.Admin.Password) > 128 {
|
||||
fmt.Println(" Password must be at most 128 characters")
|
||||
continue
|
||||
}
|
||||
confirm := promptPassword("Confirm Password")
|
||||
if cfg.Admin.Password != confirm {
|
||||
fmt.Println(" Passwords do not match")
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
// Server configuration with validation
|
||||
fmt.Println()
|
||||
fmt.Println("── Server Configuration ──")
|
||||
|
||||
for {
|
||||
cfg.Server.Port = promptInt(reader, "Server Port", 8080)
|
||||
if cliValidatePort(cfg.Server.Port) {
|
||||
break
|
||||
}
|
||||
fmt.Println(" Invalid port. Must be between 1 and 65535.")
|
||||
}
|
||||
|
||||
// Confirm and install
|
||||
fmt.Println()
|
||||
fmt.Println("── Configuration Summary ──")
|
||||
fmt.Printf("Database: %s@%s:%d/%s\n", cfg.Database.User, cfg.Database.Host, cfg.Database.Port, cfg.Database.DBName)
|
||||
fmt.Printf("Redis: %s:%d\n", cfg.Redis.Host, cfg.Redis.Port)
|
||||
fmt.Printf("Admin: %s\n", cfg.Admin.Email)
|
||||
fmt.Printf("Server: :%d\n", cfg.Server.Port)
|
||||
fmt.Println()
|
||||
|
||||
if !promptConfirm(reader, "Proceed with installation?") {
|
||||
fmt.Println("Installation cancelled")
|
||||
return nil
|
||||
}
|
||||
|
||||
fmt.Println()
|
||||
fmt.Print("Installing... ")
|
||||
if err := Install(cfg); err != nil {
|
||||
fmt.Println("FAILED")
|
||||
return err
|
||||
}
|
||||
fmt.Println("OK")
|
||||
|
||||
fmt.Println()
|
||||
fmt.Println("╔═══════════════════════════════════════════╗")
|
||||
fmt.Println("║ Installation Complete! ║")
|
||||
fmt.Println("╚═══════════════════════════════════════════╝")
|
||||
fmt.Println()
|
||||
fmt.Println("Start the server with:")
|
||||
fmt.Println(" ./sub2api")
|
||||
fmt.Println()
|
||||
fmt.Printf("Admin panel: http://localhost:%d\n", cfg.Server.Port)
|
||||
fmt.Println()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func promptString(reader *bufio.Reader, prompt, defaultVal string) string {
|
||||
if defaultVal != "" {
|
||||
fmt.Printf(" %s [%s]: ", prompt, defaultVal)
|
||||
} else {
|
||||
fmt.Printf(" %s: ", prompt)
|
||||
}
|
||||
|
||||
input, _ := reader.ReadString('\n')
|
||||
input = strings.TrimSpace(input)
|
||||
|
||||
if input == "" {
|
||||
return defaultVal
|
||||
}
|
||||
return input
|
||||
}
|
||||
|
||||
func promptInt(reader *bufio.Reader, prompt string, defaultVal int) int {
|
||||
fmt.Printf(" %s [%d]: ", prompt, defaultVal)
|
||||
|
||||
input, _ := reader.ReadString('\n')
|
||||
input = strings.TrimSpace(input)
|
||||
|
||||
if input == "" {
|
||||
return defaultVal
|
||||
}
|
||||
|
||||
val, err := strconv.Atoi(input)
|
||||
if err != nil {
|
||||
return defaultVal
|
||||
}
|
||||
return val
|
||||
}
|
||||
|
||||
func promptPassword(prompt string) string {
|
||||
fmt.Printf(" %s: ", prompt)
|
||||
|
||||
// Try to read password without echo
|
||||
if term.IsTerminal(int(os.Stdin.Fd())) {
|
||||
password, err := term.ReadPassword(int(os.Stdin.Fd()))
|
||||
fmt.Println()
|
||||
if err == nil {
|
||||
return string(password)
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to regular input
|
||||
reader := bufio.NewReader(os.Stdin)
|
||||
input, _ := reader.ReadString('\n')
|
||||
return strings.TrimSpace(input)
|
||||
}
|
||||
|
||||
func promptConfirm(reader *bufio.Reader, prompt string) bool {
|
||||
fmt.Printf("%s [y/N]: ", prompt)
|
||||
input, _ := reader.ReadString('\n')
|
||||
input = strings.TrimSpace(strings.ToLower(input))
|
||||
return input == "y" || input == "yes"
|
||||
}
|
||||
344
backend/internal/setup/handler.go
Normal file
344
backend/internal/setup/handler.go
Normal file
@@ -0,0 +1,344 @@
|
||||
package setup
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/mail"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"sub2api/internal/pkg/response"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// installMutex prevents concurrent installation attempts (TOCTOU protection)
|
||||
var installMutex sync.Mutex
|
||||
|
||||
// RegisterRoutes registers setup wizard routes
|
||||
func RegisterRoutes(r *gin.Engine) {
|
||||
setup := r.Group("/setup")
|
||||
{
|
||||
// Status endpoint is always accessible (read-only)
|
||||
setup.GET("/status", getStatus)
|
||||
|
||||
// All modification endpoints are protected by setupGuard
|
||||
protected := setup.Group("")
|
||||
protected.Use(setupGuard())
|
||||
{
|
||||
protected.POST("/test-db", testDatabase)
|
||||
protected.POST("/test-redis", testRedis)
|
||||
protected.POST("/install", install)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SetupStatus represents the current setup state
|
||||
type SetupStatus struct {
|
||||
NeedsSetup bool `json:"needs_setup"`
|
||||
Step string `json:"step"`
|
||||
}
|
||||
|
||||
// getStatus returns the current setup status
|
||||
func getStatus(c *gin.Context) {
|
||||
response.Success(c, SetupStatus{
|
||||
NeedsSetup: NeedsSetup(),
|
||||
Step: "welcome",
|
||||
})
|
||||
}
|
||||
|
||||
// setupGuard middleware ensures setup endpoints are only accessible during setup mode
|
||||
func setupGuard() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if !NeedsSetup() {
|
||||
response.Error(c, http.StatusForbidden, "Setup is not allowed: system is already installed")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// validateHostname checks if a hostname/IP is safe (no injection characters)
|
||||
func validateHostname(host string) bool {
|
||||
// Allow only alphanumeric, dots, hyphens, and colons (for IPv6)
|
||||
validHost := regexp.MustCompile(`^[a-zA-Z0-9.\-:]+$`)
|
||||
return validHost.MatchString(host) && len(host) <= 253
|
||||
}
|
||||
|
||||
// validateDBName checks if database name is safe
|
||||
func validateDBName(name string) bool {
|
||||
// Allow only alphanumeric and underscores, starting with letter
|
||||
validName := regexp.MustCompile(`^[a-zA-Z][a-zA-Z0-9_]*$`)
|
||||
return validName.MatchString(name) && len(name) <= 63
|
||||
}
|
||||
|
||||
// validateUsername checks if username is safe
|
||||
func validateUsername(name string) bool {
|
||||
// Allow only alphanumeric and underscores
|
||||
validName := regexp.MustCompile(`^[a-zA-Z0-9_]+$`)
|
||||
return validName.MatchString(name) && len(name) <= 63
|
||||
}
|
||||
|
||||
// validateEmail checks if email format is valid
|
||||
func validateEmail(email string) bool {
|
||||
_, err := mail.ParseAddress(email)
|
||||
return err == nil && len(email) <= 254
|
||||
}
|
||||
|
||||
// validatePassword checks password strength
|
||||
func validatePassword(password string) error {
|
||||
if len(password) < 8 {
|
||||
return fmt.Errorf("password must be at least 8 characters")
|
||||
}
|
||||
if len(password) > 128 {
|
||||
return fmt.Errorf("password must be at most 128 characters")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// validatePort checks if port is in valid range
|
||||
func validatePort(port int) bool {
|
||||
return port > 0 && port <= 65535
|
||||
}
|
||||
|
||||
// validateSSLMode checks if SSL mode is valid
|
||||
func validateSSLMode(mode string) bool {
|
||||
validModes := map[string]bool{
|
||||
"disable": true, "require": true, "verify-ca": true, "verify-full": true,
|
||||
}
|
||||
return validModes[mode]
|
||||
}
|
||||
|
||||
// TestDatabaseRequest represents database test request
|
||||
type TestDatabaseRequest struct {
|
||||
Host string `json:"host" binding:"required"`
|
||||
Port int `json:"port" binding:"required"`
|
||||
User string `json:"user" binding:"required"`
|
||||
Password string `json:"password"`
|
||||
DBName string `json:"dbname" binding:"required"`
|
||||
SSLMode string `json:"sslmode"`
|
||||
}
|
||||
|
||||
// testDatabase tests database connection
|
||||
func testDatabase(c *gin.Context) {
|
||||
var req TestDatabaseRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.Error(c, http.StatusBadRequest, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Security: Validate all inputs to prevent injection attacks
|
||||
if !validateHostname(req.Host) {
|
||||
response.Error(c, http.StatusBadRequest, "Invalid hostname format")
|
||||
return
|
||||
}
|
||||
if !validatePort(req.Port) {
|
||||
response.Error(c, http.StatusBadRequest, "Invalid port number")
|
||||
return
|
||||
}
|
||||
if !validateUsername(req.User) {
|
||||
response.Error(c, http.StatusBadRequest, "Invalid username format")
|
||||
return
|
||||
}
|
||||
if !validateDBName(req.DBName) {
|
||||
response.Error(c, http.StatusBadRequest, "Invalid database name format")
|
||||
return
|
||||
}
|
||||
|
||||
if req.SSLMode == "" {
|
||||
req.SSLMode = "disable"
|
||||
}
|
||||
if !validateSSLMode(req.SSLMode) {
|
||||
response.Error(c, http.StatusBadRequest, "Invalid SSL mode")
|
||||
return
|
||||
}
|
||||
|
||||
cfg := &DatabaseConfig{
|
||||
Host: req.Host,
|
||||
Port: req.Port,
|
||||
User: req.User,
|
||||
Password: req.Password,
|
||||
DBName: req.DBName,
|
||||
SSLMode: req.SSLMode,
|
||||
}
|
||||
|
||||
if err := TestDatabaseConnection(cfg); err != nil {
|
||||
response.Error(c, http.StatusBadRequest, "Connection failed: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"message": "Connection successful"})
|
||||
}
|
||||
|
||||
// TestRedisRequest represents Redis test request
|
||||
type TestRedisRequest struct {
|
||||
Host string `json:"host" binding:"required"`
|
||||
Port int `json:"port" binding:"required"`
|
||||
Password string `json:"password"`
|
||||
DB int `json:"db"`
|
||||
}
|
||||
|
||||
// testRedis tests Redis connection
|
||||
func testRedis(c *gin.Context) {
|
||||
var req TestRedisRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.Error(c, http.StatusBadRequest, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Security: Validate inputs
|
||||
if !validateHostname(req.Host) {
|
||||
response.Error(c, http.StatusBadRequest, "Invalid hostname format")
|
||||
return
|
||||
}
|
||||
if !validatePort(req.Port) {
|
||||
response.Error(c, http.StatusBadRequest, "Invalid port number")
|
||||
return
|
||||
}
|
||||
if req.DB < 0 || req.DB > 15 {
|
||||
response.Error(c, http.StatusBadRequest, "Invalid Redis database number (0-15)")
|
||||
return
|
||||
}
|
||||
|
||||
cfg := &RedisConfig{
|
||||
Host: req.Host,
|
||||
Port: req.Port,
|
||||
Password: req.Password,
|
||||
DB: req.DB,
|
||||
}
|
||||
|
||||
if err := TestRedisConnection(cfg); err != nil {
|
||||
response.Error(c, http.StatusBadRequest, "Connection failed: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"message": "Connection successful"})
|
||||
}
|
||||
|
||||
// InstallRequest represents installation request
|
||||
type InstallRequest struct {
|
||||
Database DatabaseConfig `json:"database" binding:"required"`
|
||||
Redis RedisConfig `json:"redis" binding:"required"`
|
||||
Admin AdminConfig `json:"admin" binding:"required"`
|
||||
Server ServerConfig `json:"server"`
|
||||
}
|
||||
|
||||
// install performs the installation
|
||||
func install(c *gin.Context) {
|
||||
// TOCTOU Protection: Acquire mutex to prevent concurrent installation
|
||||
installMutex.Lock()
|
||||
defer installMutex.Unlock()
|
||||
|
||||
// Double-check after acquiring lock
|
||||
if !NeedsSetup() {
|
||||
response.Error(c, http.StatusForbidden, "Setup is not allowed: system is already installed")
|
||||
return
|
||||
}
|
||||
|
||||
var req InstallRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.Error(c, http.StatusBadRequest, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// ========== COMPREHENSIVE INPUT VALIDATION ==========
|
||||
// Database validation
|
||||
if !validateHostname(req.Database.Host) {
|
||||
response.Error(c, http.StatusBadRequest, "Invalid database hostname")
|
||||
return
|
||||
}
|
||||
if !validatePort(req.Database.Port) {
|
||||
response.Error(c, http.StatusBadRequest, "Invalid database port")
|
||||
return
|
||||
}
|
||||
if !validateUsername(req.Database.User) {
|
||||
response.Error(c, http.StatusBadRequest, "Invalid database username")
|
||||
return
|
||||
}
|
||||
if !validateDBName(req.Database.DBName) {
|
||||
response.Error(c, http.StatusBadRequest, "Invalid database name")
|
||||
return
|
||||
}
|
||||
|
||||
// Redis validation
|
||||
if !validateHostname(req.Redis.Host) {
|
||||
response.Error(c, http.StatusBadRequest, "Invalid Redis hostname")
|
||||
return
|
||||
}
|
||||
if !validatePort(req.Redis.Port) {
|
||||
response.Error(c, http.StatusBadRequest, "Invalid Redis port")
|
||||
return
|
||||
}
|
||||
if req.Redis.DB < 0 || req.Redis.DB > 15 {
|
||||
response.Error(c, http.StatusBadRequest, "Invalid Redis database number")
|
||||
return
|
||||
}
|
||||
|
||||
// Admin validation
|
||||
if !validateEmail(req.Admin.Email) {
|
||||
response.Error(c, http.StatusBadRequest, "Invalid admin email format")
|
||||
return
|
||||
}
|
||||
if err := validatePassword(req.Admin.Password); err != nil {
|
||||
response.Error(c, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Server validation
|
||||
if req.Server.Port != 0 && !validatePort(req.Server.Port) {
|
||||
response.Error(c, http.StatusBadRequest, "Invalid server port")
|
||||
return
|
||||
}
|
||||
|
||||
// ========== SET DEFAULTS ==========
|
||||
if req.Database.SSLMode == "" {
|
||||
req.Database.SSLMode = "disable"
|
||||
}
|
||||
if !validateSSLMode(req.Database.SSLMode) {
|
||||
response.Error(c, http.StatusBadRequest, "Invalid SSL mode")
|
||||
return
|
||||
}
|
||||
if req.Server.Host == "" {
|
||||
req.Server.Host = "0.0.0.0"
|
||||
}
|
||||
if req.Server.Port == 0 {
|
||||
req.Server.Port = 8080
|
||||
}
|
||||
if req.Server.Mode == "" {
|
||||
req.Server.Mode = "release"
|
||||
}
|
||||
// Validate server mode
|
||||
if req.Server.Mode != "release" && req.Server.Mode != "debug" {
|
||||
response.Error(c, http.StatusBadRequest, "Invalid server mode (must be 'release' or 'debug')")
|
||||
return
|
||||
}
|
||||
|
||||
// Trim whitespace from string inputs
|
||||
req.Admin.Email = strings.TrimSpace(req.Admin.Email)
|
||||
req.Database.Host = strings.TrimSpace(req.Database.Host)
|
||||
req.Database.User = strings.TrimSpace(req.Database.User)
|
||||
req.Database.DBName = strings.TrimSpace(req.Database.DBName)
|
||||
req.Redis.Host = strings.TrimSpace(req.Redis.Host)
|
||||
|
||||
cfg := &SetupConfig{
|
||||
Database: req.Database,
|
||||
Redis: req.Redis,
|
||||
Admin: req.Admin,
|
||||
Server: req.Server,
|
||||
JWT: JWTConfig{
|
||||
ExpireHour: 24,
|
||||
},
|
||||
}
|
||||
|
||||
if err := Install(cfg); err != nil {
|
||||
response.Error(c, http.StatusInternalServerError, "Installation failed: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"message": "Installation completed successfully",
|
||||
"restart": true,
|
||||
})
|
||||
}
|
||||
564
backend/internal/setup/setup.go
Normal file
564
backend/internal/setup/setup.go
Normal file
@@ -0,0 +1,564 @@
|
||||
package setup
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/gorm"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// Config paths
|
||||
const (
|
||||
ConfigFile = "config.yaml"
|
||||
EnvFile = ".env"
|
||||
)
|
||||
|
||||
// SetupConfig holds the setup configuration
|
||||
type SetupConfig struct {
|
||||
Database DatabaseConfig `json:"database" yaml:"database"`
|
||||
Redis RedisConfig `json:"redis" yaml:"redis"`
|
||||
Admin AdminConfig `json:"admin" yaml:"-"` // Not stored in config file
|
||||
Server ServerConfig `json:"server" yaml:"server"`
|
||||
JWT JWTConfig `json:"jwt" yaml:"jwt"`
|
||||
Timezone string `json:"timezone" yaml:"timezone"` // e.g. "Asia/Shanghai", "UTC"
|
||||
}
|
||||
|
||||
type DatabaseConfig struct {
|
||||
Host string `json:"host" yaml:"host"`
|
||||
Port int `json:"port" yaml:"port"`
|
||||
User string `json:"user" yaml:"user"`
|
||||
Password string `json:"password" yaml:"password"`
|
||||
DBName string `json:"dbname" yaml:"dbname"`
|
||||
SSLMode string `json:"sslmode" yaml:"sslmode"`
|
||||
}
|
||||
|
||||
type RedisConfig struct {
|
||||
Host string `json:"host" yaml:"host"`
|
||||
Port int `json:"port" yaml:"port"`
|
||||
Password string `json:"password" yaml:"password"`
|
||||
DB int `json:"db" yaml:"db"`
|
||||
}
|
||||
|
||||
type AdminConfig struct {
|
||||
Email string `json:"email"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
type ServerConfig struct {
|
||||
Host string `json:"host" yaml:"host"`
|
||||
Port int `json:"port" yaml:"port"`
|
||||
Mode string `json:"mode" yaml:"mode"`
|
||||
}
|
||||
|
||||
type JWTConfig struct {
|
||||
Secret string `json:"secret" yaml:"secret"`
|
||||
ExpireHour int `json:"expire_hour" yaml:"expire_hour"`
|
||||
}
|
||||
|
||||
// NeedsSetup checks if the system needs initial setup
|
||||
// Uses multiple checks to prevent attackers from forcing re-setup by deleting config
|
||||
func NeedsSetup() bool {
|
||||
// Check 1: Config file must not exist
|
||||
if _, err := os.Stat(ConfigFile); !os.IsNotExist(err) {
|
||||
return false // Config exists, no setup needed
|
||||
}
|
||||
|
||||
// Check 2: Installation lock file (harder to bypass)
|
||||
lockFile := ".installed"
|
||||
if _, err := os.Stat(lockFile); !os.IsNotExist(err) {
|
||||
return false // Lock file exists, already installed
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// TestDatabaseConnection tests the database connection
|
||||
func TestDatabaseConnection(cfg *DatabaseConfig) error {
|
||||
dsn := fmt.Sprintf(
|
||||
"host=%s port=%d user=%s password=%s dbname=%s sslmode=%s",
|
||||
cfg.Host, cfg.Port, cfg.User, cfg.Password, cfg.DBName, cfg.SSLMode,
|
||||
)
|
||||
|
||||
db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect: %w", err)
|
||||
}
|
||||
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get db instance: %w", err)
|
||||
}
|
||||
defer sqlDB.Close()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := sqlDB.PingContext(ctx); err != nil {
|
||||
return fmt.Errorf("ping failed: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// TestRedisConnection tests the Redis connection
|
||||
func TestRedisConnection(cfg *RedisConfig) error {
|
||||
rdb := redis.NewClient(&redis.Options{
|
||||
Addr: fmt.Sprintf("%s:%d", cfg.Host, cfg.Port),
|
||||
Password: cfg.Password,
|
||||
DB: cfg.DB,
|
||||
})
|
||||
defer rdb.Close()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := rdb.Ping(ctx).Err(); err != nil {
|
||||
return fmt.Errorf("ping failed: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Install performs the installation with the given configuration
|
||||
func Install(cfg *SetupConfig) error {
|
||||
// Security check: prevent re-installation if already installed
|
||||
if !NeedsSetup() {
|
||||
return fmt.Errorf("system is already installed, re-installation is not allowed")
|
||||
}
|
||||
|
||||
// Generate JWT secret if not provided
|
||||
if cfg.JWT.Secret == "" {
|
||||
cfg.JWT.Secret = generateSecret(32)
|
||||
}
|
||||
|
||||
// Test connections
|
||||
if err := TestDatabaseConnection(&cfg.Database); err != nil {
|
||||
return fmt.Errorf("database connection failed: %w", err)
|
||||
}
|
||||
|
||||
if err := TestRedisConnection(&cfg.Redis); err != nil {
|
||||
return fmt.Errorf("redis connection failed: %w", err)
|
||||
}
|
||||
|
||||
// Initialize database
|
||||
if err := initializeDatabase(cfg); err != nil {
|
||||
return fmt.Errorf("database initialization failed: %w", err)
|
||||
}
|
||||
|
||||
// Create admin user
|
||||
if err := createAdminUser(cfg); err != nil {
|
||||
return fmt.Errorf("admin user creation failed: %w", err)
|
||||
}
|
||||
|
||||
// Write config file
|
||||
if err := writeConfigFile(cfg); err != nil {
|
||||
return fmt.Errorf("config file creation failed: %w", err)
|
||||
}
|
||||
|
||||
// Create installation lock file to prevent re-setup attacks
|
||||
if err := createInstallLock(); err != nil {
|
||||
return fmt.Errorf("failed to create install lock: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// createInstallLock creates a lock file to prevent re-installation attacks
|
||||
func createInstallLock() error {
|
||||
lockFile := ".installed"
|
||||
content := fmt.Sprintf("installed_at=%s\n", time.Now().UTC().Format(time.RFC3339))
|
||||
return os.WriteFile(lockFile, []byte(content), 0400) // Read-only for owner
|
||||
}
|
||||
|
||||
func initializeDatabase(cfg *SetupConfig) error {
|
||||
dsn := fmt.Sprintf(
|
||||
"host=%s port=%d user=%s password=%s dbname=%s sslmode=%s",
|
||||
cfg.Database.Host, cfg.Database.Port, cfg.Database.User,
|
||||
cfg.Database.Password, cfg.Database.DBName, cfg.Database.SSLMode,
|
||||
)
|
||||
|
||||
db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer sqlDB.Close()
|
||||
|
||||
// Run auto-migration for all models
|
||||
return db.AutoMigrate(
|
||||
&User{},
|
||||
&Group{},
|
||||
&APIKey{},
|
||||
&Account{},
|
||||
&Proxy{},
|
||||
&RedeemCode{},
|
||||
&UsageLog{},
|
||||
&UserSubscription{},
|
||||
&Setting{},
|
||||
)
|
||||
}
|
||||
|
||||
func createAdminUser(cfg *SetupConfig) error {
|
||||
dsn := fmt.Sprintf(
|
||||
"host=%s port=%d user=%s password=%s dbname=%s sslmode=%s",
|
||||
cfg.Database.Host, cfg.Database.Port, cfg.Database.User,
|
||||
cfg.Database.Password, cfg.Database.DBName, cfg.Database.SSLMode,
|
||||
)
|
||||
|
||||
db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer sqlDB.Close()
|
||||
|
||||
// Check if admin already exists
|
||||
var count int64
|
||||
db.Model(&User{}).Where("role = ?", "admin").Count(&count)
|
||||
if count > 0 {
|
||||
return nil // Admin already exists
|
||||
}
|
||||
|
||||
// Hash password
|
||||
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(cfg.Admin.Password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Create admin user
|
||||
admin := &User{
|
||||
Email: cfg.Admin.Email,
|
||||
PasswordHash: string(hashedPassword),
|
||||
Role: "admin",
|
||||
Status: "active",
|
||||
Balance: 0,
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
|
||||
return db.Create(admin).Error
|
||||
}
|
||||
|
||||
func writeConfigFile(cfg *SetupConfig) error {
|
||||
// Ensure timezone has a default value
|
||||
tz := cfg.Timezone
|
||||
if tz == "" {
|
||||
tz = "Asia/Shanghai"
|
||||
}
|
||||
|
||||
// Prepare config for YAML (exclude sensitive data and admin config)
|
||||
yamlConfig := struct {
|
||||
Server ServerConfig `yaml:"server"`
|
||||
Database DatabaseConfig `yaml:"database"`
|
||||
Redis RedisConfig `yaml:"redis"`
|
||||
JWT struct {
|
||||
Secret string `yaml:"secret"`
|
||||
ExpireHour int `yaml:"expire_hour"`
|
||||
} `yaml:"jwt"`
|
||||
Default struct {
|
||||
GroupID uint `yaml:"group_id"`
|
||||
} `yaml:"default"`
|
||||
RateLimit struct {
|
||||
RequestsPerMinute int `yaml:"requests_per_minute"`
|
||||
BurstSize int `yaml:"burst_size"`
|
||||
} `yaml:"rate_limit"`
|
||||
Timezone string `yaml:"timezone"`
|
||||
}{
|
||||
Server: cfg.Server,
|
||||
Database: cfg.Database,
|
||||
Redis: cfg.Redis,
|
||||
JWT: struct {
|
||||
Secret string `yaml:"secret"`
|
||||
ExpireHour int `yaml:"expire_hour"`
|
||||
}{
|
||||
Secret: cfg.JWT.Secret,
|
||||
ExpireHour: cfg.JWT.ExpireHour,
|
||||
},
|
||||
Default: struct {
|
||||
GroupID uint `yaml:"group_id"`
|
||||
}{
|
||||
GroupID: 1,
|
||||
},
|
||||
RateLimit: struct {
|
||||
RequestsPerMinute int `yaml:"requests_per_minute"`
|
||||
BurstSize int `yaml:"burst_size"`
|
||||
}{
|
||||
RequestsPerMinute: 60,
|
||||
BurstSize: 10,
|
||||
},
|
||||
Timezone: tz,
|
||||
}
|
||||
|
||||
data, err := yaml.Marshal(&yamlConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return os.WriteFile(ConfigFile, data, 0600)
|
||||
}
|
||||
|
||||
func generateSecret(length int) string {
|
||||
bytes := make([]byte, length)
|
||||
rand.Read(bytes)
|
||||
return hex.EncodeToString(bytes)
|
||||
}
|
||||
|
||||
// Minimal model definitions for migration (to avoid circular import)
|
||||
type User struct {
|
||||
ID uint `gorm:"primaryKey"`
|
||||
Email string `gorm:"uniqueIndex;not null"`
|
||||
PasswordHash string `gorm:"not null"`
|
||||
Role string `gorm:"default:user"`
|
||||
Status string `gorm:"default:active"`
|
||||
Balance float64 `gorm:"default:0"`
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
type Group struct {
|
||||
ID uint `gorm:"primaryKey"`
|
||||
Name string `gorm:"uniqueIndex;not null"`
|
||||
Description string `gorm:"type:text"`
|
||||
RateMultiplier float64 `gorm:"default:1.0"`
|
||||
IsExclusive bool `gorm:"default:false"`
|
||||
Priority int `gorm:"default:0"`
|
||||
Status string `gorm:"default:active"`
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
type APIKey struct {
|
||||
ID uint `gorm:"primaryKey"`
|
||||
UserID uint `gorm:"index;not null"`
|
||||
Key string `gorm:"uniqueIndex;not null"`
|
||||
Name string
|
||||
GroupID *uint
|
||||
Status string `gorm:"default:active"`
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
type Account struct {
|
||||
ID uint `gorm:"primaryKey"`
|
||||
Platform string `gorm:"not null"`
|
||||
Type string `gorm:"not null"`
|
||||
Credentials string `gorm:"type:text"`
|
||||
Status string `gorm:"default:active"`
|
||||
Priority int `gorm:"default:0"`
|
||||
ProxyID *uint
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
type Proxy struct {
|
||||
ID uint `gorm:"primaryKey"`
|
||||
Name string `gorm:"not null"`
|
||||
Protocol string `gorm:"not null"`
|
||||
Host string `gorm:"not null"`
|
||||
Port int `gorm:"not null"`
|
||||
Username string
|
||||
Password string
|
||||
Status string `gorm:"default:active"`
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
type RedeemCode struct {
|
||||
ID uint `gorm:"primaryKey"`
|
||||
Code string `gorm:"uniqueIndex;not null"`
|
||||
Value float64 `gorm:"not null"`
|
||||
Status string `gorm:"default:unused"`
|
||||
UsedBy *uint
|
||||
UsedAt *time.Time
|
||||
ExpiresAt *time.Time
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
type UsageLog struct {
|
||||
ID uint `gorm:"primaryKey"`
|
||||
UserID uint `gorm:"index"`
|
||||
APIKeyID uint `gorm:"index"`
|
||||
AccountID *uint `gorm:"index"`
|
||||
Model string `gorm:"index"`
|
||||
InputTokens int
|
||||
OutputTokens int
|
||||
Cost float64
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
type UserSubscription struct {
|
||||
ID uint `gorm:"primaryKey"`
|
||||
UserID uint `gorm:"index;not null"`
|
||||
GroupID uint `gorm:"index;not null"`
|
||||
Quota int64
|
||||
Used int64 `gorm:"default:0"`
|
||||
Status string
|
||||
ExpiresAt *time.Time
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
type Setting struct {
|
||||
ID uint `gorm:"primaryKey"`
|
||||
Key string `gorm:"uniqueIndex;not null"`
|
||||
Value string `gorm:"type:text"`
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
func (User) TableName() string { return "users" }
|
||||
func (Group) TableName() string { return "groups" }
|
||||
func (APIKey) TableName() string { return "api_keys" }
|
||||
func (Account) TableName() string { return "accounts" }
|
||||
func (Proxy) TableName() string { return "proxies" }
|
||||
func (RedeemCode) TableName() string { return "redeem_codes" }
|
||||
func (UsageLog) TableName() string { return "usage_logs" }
|
||||
func (UserSubscription) TableName() string { return "user_subscriptions" }
|
||||
func (Setting) TableName() string { return "settings" }
|
||||
|
||||
// =============================================================================
|
||||
// Auto Setup for Docker Deployment
|
||||
// =============================================================================
|
||||
|
||||
// AutoSetupEnabled checks if auto setup is enabled via environment variable
|
||||
func AutoSetupEnabled() bool {
|
||||
val := os.Getenv("AUTO_SETUP")
|
||||
return val == "true" || val == "1" || val == "yes"
|
||||
}
|
||||
|
||||
// getEnvOrDefault gets environment variable or returns default value
|
||||
func getEnvOrDefault(key, defaultValue string) string {
|
||||
if val := os.Getenv(key); val != "" {
|
||||
return val
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
// getEnvIntOrDefault gets environment variable as int or returns default value
|
||||
func getEnvIntOrDefault(key string, defaultValue int) int {
|
||||
if val := os.Getenv(key); val != "" {
|
||||
if i, err := strconv.Atoi(val); err == nil {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
// AutoSetupFromEnv performs automatic setup using environment variables
|
||||
// This is designed for Docker deployment where all config is passed via env vars
|
||||
func AutoSetupFromEnv() error {
|
||||
log.Println("Auto setup enabled, configuring from environment variables...")
|
||||
|
||||
// Get timezone from TZ or TIMEZONE env var (TZ is standard for Docker)
|
||||
tz := getEnvOrDefault("TZ", "")
|
||||
if tz == "" {
|
||||
tz = getEnvOrDefault("TIMEZONE", "Asia/Shanghai")
|
||||
}
|
||||
|
||||
// Build config from environment variables
|
||||
cfg := &SetupConfig{
|
||||
Database: DatabaseConfig{
|
||||
Host: getEnvOrDefault("DATABASE_HOST", "localhost"),
|
||||
Port: getEnvIntOrDefault("DATABASE_PORT", 5432),
|
||||
User: getEnvOrDefault("DATABASE_USER", "postgres"),
|
||||
Password: getEnvOrDefault("DATABASE_PASSWORD", ""),
|
||||
DBName: getEnvOrDefault("DATABASE_DBNAME", "sub2api"),
|
||||
SSLMode: getEnvOrDefault("DATABASE_SSLMODE", "disable"),
|
||||
},
|
||||
Redis: RedisConfig{
|
||||
Host: getEnvOrDefault("REDIS_HOST", "localhost"),
|
||||
Port: getEnvIntOrDefault("REDIS_PORT", 6379),
|
||||
Password: getEnvOrDefault("REDIS_PASSWORD", ""),
|
||||
DB: getEnvIntOrDefault("REDIS_DB", 0),
|
||||
},
|
||||
Admin: AdminConfig{
|
||||
Email: getEnvOrDefault("ADMIN_EMAIL", "admin@sub2api.local"),
|
||||
Password: getEnvOrDefault("ADMIN_PASSWORD", ""),
|
||||
},
|
||||
Server: ServerConfig{
|
||||
Host: getEnvOrDefault("SERVER_HOST", "0.0.0.0"),
|
||||
Port: getEnvIntOrDefault("SERVER_PORT", 8080),
|
||||
Mode: getEnvOrDefault("SERVER_MODE", "release"),
|
||||
},
|
||||
JWT: JWTConfig{
|
||||
Secret: getEnvOrDefault("JWT_SECRET", ""),
|
||||
ExpireHour: getEnvIntOrDefault("JWT_EXPIRE_HOUR", 24),
|
||||
},
|
||||
Timezone: tz,
|
||||
}
|
||||
|
||||
// Generate JWT secret if not provided
|
||||
if cfg.JWT.Secret == "" {
|
||||
cfg.JWT.Secret = generateSecret(32)
|
||||
log.Println("Generated JWT secret automatically")
|
||||
}
|
||||
|
||||
// Generate admin password if not provided
|
||||
if cfg.Admin.Password == "" {
|
||||
cfg.Admin.Password = generateSecret(16)
|
||||
log.Printf("Generated admin password: %s", cfg.Admin.Password)
|
||||
log.Println("IMPORTANT: Save this password! It will not be shown again.")
|
||||
}
|
||||
|
||||
// Test database connection
|
||||
log.Println("Testing database connection...")
|
||||
if err := TestDatabaseConnection(&cfg.Database); err != nil {
|
||||
return fmt.Errorf("database connection failed: %w", err)
|
||||
}
|
||||
log.Println("Database connection successful")
|
||||
|
||||
// Test Redis connection
|
||||
log.Println("Testing Redis connection...")
|
||||
if err := TestRedisConnection(&cfg.Redis); err != nil {
|
||||
return fmt.Errorf("redis connection failed: %w", err)
|
||||
}
|
||||
log.Println("Redis connection successful")
|
||||
|
||||
// Initialize database
|
||||
log.Println("Initializing database...")
|
||||
if err := initializeDatabase(cfg); err != nil {
|
||||
return fmt.Errorf("database initialization failed: %w", err)
|
||||
}
|
||||
log.Println("Database initialized successfully")
|
||||
|
||||
// Create admin user
|
||||
log.Println("Creating admin user...")
|
||||
if err := createAdminUser(cfg); err != nil {
|
||||
return fmt.Errorf("admin user creation failed: %w", err)
|
||||
}
|
||||
log.Printf("Admin user created: %s", cfg.Admin.Email)
|
||||
|
||||
// Write config file
|
||||
log.Println("Writing configuration file...")
|
||||
if err := writeConfigFile(cfg); err != nil {
|
||||
return fmt.Errorf("config file creation failed: %w", err)
|
||||
}
|
||||
log.Println("Configuration file created")
|
||||
|
||||
// Create installation lock file
|
||||
if err := createInstallLock(); err != nil {
|
||||
return fmt.Errorf("failed to create install lock: %w", err)
|
||||
}
|
||||
log.Println("Installation lock created")
|
||||
|
||||
log.Println("Auto setup completed successfully!")
|
||||
return nil
|
||||
}
|
||||
79
backend/internal/web/embed.go
Normal file
79
backend/internal/web/embed.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package web
|
||||
|
||||
import (
|
||||
"embed"
|
||||
"io"
|
||||
"io/fs"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
//go:embed dist/*
|
||||
var frontendFS embed.FS
|
||||
|
||||
// ServeEmbeddedFrontend returns a Gin handler that serves embedded frontend assets
|
||||
// and handles SPA routing by falling back to index.html for non-API routes.
|
||||
func ServeEmbeddedFrontend() gin.HandlerFunc {
|
||||
distFS, err := fs.Sub(frontendFS, "dist")
|
||||
if err != nil {
|
||||
panic("failed to get dist subdirectory: " + err.Error())
|
||||
}
|
||||
fileServer := http.FileServer(http.FS(distFS))
|
||||
|
||||
return func(c *gin.Context) {
|
||||
path := c.Request.URL.Path
|
||||
|
||||
// Skip API and gateway routes
|
||||
if strings.HasPrefix(path, "/api/") ||
|
||||
strings.HasPrefix(path, "/v1/") ||
|
||||
strings.HasPrefix(path, "/setup/") ||
|
||||
path == "/health" {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// Try to serve static file
|
||||
cleanPath := strings.TrimPrefix(path, "/")
|
||||
if cleanPath == "" {
|
||||
cleanPath = "index.html"
|
||||
}
|
||||
|
||||
if file, err := distFS.Open(cleanPath); err == nil {
|
||||
file.Close()
|
||||
fileServer.ServeHTTP(c.Writer, c.Request)
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// SPA fallback: serve index.html for all other routes
|
||||
serveIndexHTML(c, distFS)
|
||||
}
|
||||
}
|
||||
|
||||
func serveIndexHTML(c *gin.Context, fsys fs.FS) {
|
||||
file, err := fsys.Open("index.html")
|
||||
if err != nil {
|
||||
c.String(http.StatusNotFound, "Frontend not found")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
content, err := io.ReadAll(file)
|
||||
if err != nil {
|
||||
c.String(http.StatusInternalServerError, "Failed to read index.html")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
c.Data(http.StatusOK, "text/html; charset=utf-8", content)
|
||||
c.Abort()
|
||||
}
|
||||
|
||||
// HasEmbeddedFrontend checks if frontend assets are embedded
|
||||
func HasEmbeddedFrontend() bool {
|
||||
_, err := frontendFS.ReadFile("dist/index.html")
|
||||
return err == nil
|
||||
}
|
||||
183
backend/migrations/001_init.sql
Normal file
183
backend/migrations/001_init.sql
Normal file
@@ -0,0 +1,183 @@
|
||||
-- Sub2API 初始化数据库迁移脚本
|
||||
-- PostgreSQL 15+
|
||||
|
||||
-- 1. proxies 代理IP表(无外键依赖)
|
||||
CREATE TABLE IF NOT EXISTS proxies (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
name VARCHAR(100) NOT NULL,
|
||||
protocol VARCHAR(20) NOT NULL, -- http/https/socks5
|
||||
host VARCHAR(255) NOT NULL,
|
||||
port INT NOT NULL,
|
||||
username VARCHAR(100),
|
||||
password VARCHAR(100),
|
||||
status VARCHAR(20) NOT NULL DEFAULT 'active', -- active/disabled
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
deleted_at TIMESTAMPTZ
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_proxies_status ON proxies(status);
|
||||
CREATE INDEX IF NOT EXISTS idx_proxies_deleted_at ON proxies(deleted_at);
|
||||
|
||||
-- 2. groups 分组表(无外键依赖)
|
||||
CREATE TABLE IF NOT EXISTS groups (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
name VARCHAR(100) NOT NULL UNIQUE,
|
||||
description TEXT,
|
||||
rate_multiplier DECIMAL(10, 4) NOT NULL DEFAULT 1.0, -- 费率倍率
|
||||
is_exclusive BOOLEAN NOT NULL DEFAULT FALSE, -- 是否专属分组
|
||||
status VARCHAR(20) NOT NULL DEFAULT 'active', -- active/disabled
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
deleted_at TIMESTAMPTZ
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_groups_name ON groups(name);
|
||||
CREATE INDEX IF NOT EXISTS idx_groups_status ON groups(status);
|
||||
CREATE INDEX IF NOT EXISTS idx_groups_is_exclusive ON groups(is_exclusive);
|
||||
CREATE INDEX IF NOT EXISTS idx_groups_deleted_at ON groups(deleted_at);
|
||||
|
||||
-- 3. users 用户表(无外键依赖)
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
email VARCHAR(255) NOT NULL UNIQUE,
|
||||
password_hash VARCHAR(255) NOT NULL,
|
||||
role VARCHAR(20) NOT NULL DEFAULT 'user', -- admin/user
|
||||
balance DECIMAL(20, 8) NOT NULL DEFAULT 0, -- 余额(可为负数)
|
||||
concurrency INT NOT NULL DEFAULT 5, -- 并发数限制
|
||||
status VARCHAR(20) NOT NULL DEFAULT 'active', -- active/disabled
|
||||
allowed_groups BIGINT[] DEFAULT NULL, -- 允许绑定的分组ID列表
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
deleted_at TIMESTAMPTZ
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_users_email ON users(email);
|
||||
CREATE INDEX IF NOT EXISTS idx_users_status ON users(status);
|
||||
CREATE INDEX IF NOT EXISTS idx_users_deleted_at ON users(deleted_at);
|
||||
|
||||
-- 4. accounts 上游账号表(依赖proxies)
|
||||
CREATE TABLE IF NOT EXISTS accounts (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
name VARCHAR(100) NOT NULL,
|
||||
platform VARCHAR(50) NOT NULL, -- anthropic/openai/gemini
|
||||
type VARCHAR(20) NOT NULL, -- oauth/apikey
|
||||
credentials JSONB NOT NULL DEFAULT '{}', -- 凭证信息(加密存储)
|
||||
extra JSONB NOT NULL DEFAULT '{}', -- 扩展信息
|
||||
proxy_id BIGINT REFERENCES proxies(id) ON DELETE SET NULL,
|
||||
concurrency INT NOT NULL DEFAULT 3, -- 账号并发限制
|
||||
priority INT NOT NULL DEFAULT 50, -- 调度优先级(1-100,越小越高)
|
||||
status VARCHAR(20) NOT NULL DEFAULT 'active', -- active/disabled/error
|
||||
error_message TEXT,
|
||||
last_used_at TIMESTAMPTZ,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
deleted_at TIMESTAMPTZ
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_accounts_platform ON accounts(platform);
|
||||
CREATE INDEX IF NOT EXISTS idx_accounts_type ON accounts(type);
|
||||
CREATE INDEX IF NOT EXISTS idx_accounts_status ON accounts(status);
|
||||
CREATE INDEX IF NOT EXISTS idx_accounts_proxy_id ON accounts(proxy_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_accounts_priority ON accounts(priority);
|
||||
CREATE INDEX IF NOT EXISTS idx_accounts_last_used_at ON accounts(last_used_at);
|
||||
CREATE INDEX IF NOT EXISTS idx_accounts_deleted_at ON accounts(deleted_at);
|
||||
|
||||
-- 5. api_keys API密钥表(依赖users, groups)
|
||||
CREATE TABLE IF NOT EXISTS api_keys (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||
key VARCHAR(64) NOT NULL UNIQUE, -- sk-xxx格式
|
||||
name VARCHAR(100) NOT NULL,
|
||||
group_id BIGINT REFERENCES groups(id) ON DELETE SET NULL,
|
||||
status VARCHAR(20) NOT NULL DEFAULT 'active', -- active/disabled
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
deleted_at TIMESTAMPTZ
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_api_keys_key ON api_keys(key);
|
||||
CREATE INDEX IF NOT EXISTS idx_api_keys_user_id ON api_keys(user_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_api_keys_group_id ON api_keys(group_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_api_keys_status ON api_keys(status);
|
||||
CREATE INDEX IF NOT EXISTS idx_api_keys_deleted_at ON api_keys(deleted_at);
|
||||
|
||||
-- 6. account_groups 账号-分组关联表(依赖accounts, groups)
|
||||
CREATE TABLE IF NOT EXISTS account_groups (
|
||||
account_id BIGINT NOT NULL REFERENCES accounts(id) ON DELETE CASCADE,
|
||||
group_id BIGINT NOT NULL REFERENCES groups(id) ON DELETE CASCADE,
|
||||
priority INT NOT NULL DEFAULT 50, -- 分组内优先级
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
PRIMARY KEY (account_id, group_id)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_account_groups_group_id ON account_groups(group_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_account_groups_priority ON account_groups(priority);
|
||||
|
||||
-- 7. redeem_codes 卡密表(依赖users)
|
||||
CREATE TABLE IF NOT EXISTS redeem_codes (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
code VARCHAR(32) NOT NULL UNIQUE, -- 兑换码
|
||||
type VARCHAR(20) NOT NULL DEFAULT 'balance', -- balance
|
||||
value DECIMAL(20, 8) NOT NULL, -- 面值(USD)
|
||||
status VARCHAR(20) NOT NULL DEFAULT 'unused', -- unused/used
|
||||
used_by BIGINT REFERENCES users(id) ON DELETE SET NULL,
|
||||
used_at TIMESTAMPTZ,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_redeem_codes_code ON redeem_codes(code);
|
||||
CREATE INDEX IF NOT EXISTS idx_redeem_codes_status ON redeem_codes(status);
|
||||
CREATE INDEX IF NOT EXISTS idx_redeem_codes_used_by ON redeem_codes(used_by);
|
||||
|
||||
-- 8. usage_logs 使用记录表(依赖users, api_keys, accounts)
|
||||
CREATE TABLE IF NOT EXISTS usage_logs (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||
api_key_id BIGINT NOT NULL REFERENCES api_keys(id) ON DELETE CASCADE,
|
||||
account_id BIGINT NOT NULL REFERENCES accounts(id) ON DELETE CASCADE,
|
||||
request_id VARCHAR(64),
|
||||
model VARCHAR(100) NOT NULL,
|
||||
|
||||
-- Token使用量(4类)
|
||||
input_tokens INT NOT NULL DEFAULT 0,
|
||||
output_tokens INT NOT NULL DEFAULT 0,
|
||||
cache_creation_tokens INT NOT NULL DEFAULT 0,
|
||||
cache_read_tokens INT NOT NULL DEFAULT 0,
|
||||
|
||||
-- 详细的缓存创建分类
|
||||
cache_creation_5m_tokens INT NOT NULL DEFAULT 0,
|
||||
cache_creation_1h_tokens INT NOT NULL DEFAULT 0,
|
||||
|
||||
-- 费用(USD)
|
||||
input_cost DECIMAL(20, 10) NOT NULL DEFAULT 0,
|
||||
output_cost DECIMAL(20, 10) NOT NULL DEFAULT 0,
|
||||
cache_creation_cost DECIMAL(20, 10) NOT NULL DEFAULT 0,
|
||||
cache_read_cost DECIMAL(20, 10) NOT NULL DEFAULT 0,
|
||||
total_cost DECIMAL(20, 10) NOT NULL DEFAULT 0, -- 原始总费用
|
||||
actual_cost DECIMAL(20, 10) NOT NULL DEFAULT 0, -- 实际扣除费用
|
||||
|
||||
-- 元数据
|
||||
stream BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
duration_ms INT,
|
||||
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_usage_logs_user_id ON usage_logs(user_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_usage_logs_api_key_id ON usage_logs(api_key_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_usage_logs_account_id ON usage_logs(account_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_usage_logs_model ON usage_logs(model);
|
||||
CREATE INDEX IF NOT EXISTS idx_usage_logs_created_at ON usage_logs(created_at);
|
||||
CREATE INDEX IF NOT EXISTS idx_usage_logs_user_created ON usage_logs(user_id, created_at);
|
||||
|
||||
-- 插入默认管理员用户
|
||||
-- 密码: admin123 (bcrypt hash)
|
||||
INSERT INTO users (email, password_hash, role, balance, concurrency, status)
|
||||
VALUES ('admin@sub2api.com', '$2a$10$N9qo8uLOickgx2ZMRZoMye.IjJbDdJeCo0U2bBPJj9lS/5LqD.C.C', 'admin', 0, 10, 'active')
|
||||
ON CONFLICT (email) DO NOTHING;
|
||||
|
||||
-- 插入默认分组
|
||||
INSERT INTO groups (name, description, rate_multiplier, is_exclusive, status)
|
||||
VALUES ('default', '默认分组', 1.0, false, 'active')
|
||||
ON CONFLICT (name) DO NOTHING;
|
||||
33
backend/migrations/002_account_type_migration.sql
Normal file
33
backend/migrations/002_account_type_migration.sql
Normal file
@@ -0,0 +1,33 @@
|
||||
-- Sub2API 账号类型迁移脚本
|
||||
-- 将 'official' 类型账号迁移为 'oauth' 或 'setup-token'
|
||||
-- 根据 credentials->>'scope' 字段判断:
|
||||
-- - 包含 'user:profile' 的是 'oauth' 类型
|
||||
-- - 只有 'user:inference' 的是 'setup-token' 类型
|
||||
|
||||
-- 1. 将包含 profile scope 的 official 账号迁移为 oauth
|
||||
UPDATE accounts
|
||||
SET type = 'oauth',
|
||||
updated_at = NOW()
|
||||
WHERE type = 'official'
|
||||
AND credentials->>'scope' LIKE '%user:profile%';
|
||||
|
||||
-- 2. 将只有 inference scope 的 official 账号迁移为 setup-token
|
||||
UPDATE accounts
|
||||
SET type = 'setup-token',
|
||||
updated_at = NOW()
|
||||
WHERE type = 'official'
|
||||
AND (
|
||||
credentials->>'scope' = 'user:inference'
|
||||
OR credentials->>'scope' NOT LIKE '%user:profile%'
|
||||
);
|
||||
|
||||
-- 3. 处理没有 scope 字段的旧账号(默认为 oauth)
|
||||
UPDATE accounts
|
||||
SET type = 'oauth',
|
||||
updated_at = NOW()
|
||||
WHERE type = 'official'
|
||||
AND (credentials->>'scope' IS NULL OR credentials->>'scope' = '');
|
||||
|
||||
-- 4. 验证迁移结果(查询是否还有 official 类型账号)
|
||||
-- SELECT COUNT(*) FROM accounts WHERE type = 'official';
|
||||
-- 如果结果为 0,说明迁移成功
|
||||
65
backend/migrations/003_subscription.sql
Normal file
65
backend/migrations/003_subscription.sql
Normal file
@@ -0,0 +1,65 @@
|
||||
-- Sub2API 订阅功能迁移脚本
|
||||
-- 添加订阅分组和用户订阅功能
|
||||
|
||||
-- 1. 扩展 groups 表添加订阅相关字段
|
||||
ALTER TABLE groups ADD COLUMN IF NOT EXISTS platform VARCHAR(50) NOT NULL DEFAULT 'anthropic';
|
||||
ALTER TABLE groups ADD COLUMN IF NOT EXISTS subscription_type VARCHAR(20) NOT NULL DEFAULT 'standard';
|
||||
ALTER TABLE groups ADD COLUMN IF NOT EXISTS daily_limit_usd DECIMAL(20, 8) DEFAULT NULL;
|
||||
ALTER TABLE groups ADD COLUMN IF NOT EXISTS weekly_limit_usd DECIMAL(20, 8) DEFAULT NULL;
|
||||
ALTER TABLE groups ADD COLUMN IF NOT EXISTS monthly_limit_usd DECIMAL(20, 8) DEFAULT NULL;
|
||||
ALTER TABLE groups ADD COLUMN IF NOT EXISTS default_validity_days INT NOT NULL DEFAULT 30;
|
||||
|
||||
-- 添加索引
|
||||
CREATE INDEX IF NOT EXISTS idx_groups_platform ON groups(platform);
|
||||
CREATE INDEX IF NOT EXISTS idx_groups_subscription_type ON groups(subscription_type);
|
||||
|
||||
-- 2. 创建 user_subscriptions 用户订阅表
|
||||
CREATE TABLE IF NOT EXISTS user_subscriptions (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||
group_id BIGINT NOT NULL REFERENCES groups(id) ON DELETE CASCADE,
|
||||
|
||||
-- 订阅有效期
|
||||
starts_at TIMESTAMPTZ NOT NULL,
|
||||
expires_at TIMESTAMPTZ NOT NULL,
|
||||
status VARCHAR(20) NOT NULL DEFAULT 'active', -- active/expired/suspended
|
||||
|
||||
-- 滑动窗口起始时间(NULL=未激活)
|
||||
daily_window_start TIMESTAMPTZ,
|
||||
weekly_window_start TIMESTAMPTZ,
|
||||
monthly_window_start TIMESTAMPTZ,
|
||||
|
||||
-- 当前窗口已用额度(USD,基于 total_cost 计算)
|
||||
daily_usage_usd DECIMAL(20, 10) NOT NULL DEFAULT 0,
|
||||
weekly_usage_usd DECIMAL(20, 10) NOT NULL DEFAULT 0,
|
||||
monthly_usage_usd DECIMAL(20, 10) NOT NULL DEFAULT 0,
|
||||
|
||||
-- 管理员分配信息
|
||||
assigned_by BIGINT REFERENCES users(id) ON DELETE SET NULL,
|
||||
assigned_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
notes TEXT,
|
||||
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
|
||||
-- 唯一约束:每个用户对每个分组只能有一个订阅
|
||||
UNIQUE(user_id, group_id)
|
||||
);
|
||||
|
||||
-- user_subscriptions 索引
|
||||
CREATE INDEX IF NOT EXISTS idx_user_subscriptions_user_id ON user_subscriptions(user_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_user_subscriptions_group_id ON user_subscriptions(group_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_user_subscriptions_status ON user_subscriptions(status);
|
||||
CREATE INDEX IF NOT EXISTS idx_user_subscriptions_expires_at ON user_subscriptions(expires_at);
|
||||
CREATE INDEX IF NOT EXISTS idx_user_subscriptions_assigned_by ON user_subscriptions(assigned_by);
|
||||
|
||||
-- 3. 扩展 usage_logs 表添加分组和订阅关联
|
||||
ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS group_id BIGINT REFERENCES groups(id) ON DELETE SET NULL;
|
||||
ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS subscription_id BIGINT REFERENCES user_subscriptions(id) ON DELETE SET NULL;
|
||||
ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS rate_multiplier DECIMAL(10, 4) NOT NULL DEFAULT 1;
|
||||
ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS first_token_ms INT;
|
||||
|
||||
-- usage_logs 新索引
|
||||
CREATE INDEX IF NOT EXISTS idx_usage_logs_group_id ON usage_logs(group_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_usage_logs_subscription_id ON usage_logs(subscription_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_usage_logs_sub_created ON usage_logs(subscription_id, created_at);
|
||||
37
backend/resources/model-pricing/README.md
Normal file
37
backend/resources/model-pricing/README.md
Normal file
@@ -0,0 +1,37 @@
|
||||
# Model Pricing Data
|
||||
|
||||
This directory contains a local copy of the mirrored model pricing data as a fallback mechanism.
|
||||
|
||||
## Source
|
||||
The original file is maintained by the LiteLLM project and mirrored into the `price-mirror` branch of this repository via GitHub Actions:
|
||||
- Mirror branch (configurable via `PRICE_MIRROR_REPO`): https://raw.githubusercontent.com/<your-repo>/price-mirror/model_prices_and_context_window.json
|
||||
- Upstream source: https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json
|
||||
|
||||
## Purpose
|
||||
This local copy serves as a fallback when the remote file cannot be downloaded due to:
|
||||
- Network restrictions
|
||||
- Firewall rules
|
||||
- DNS resolution issues
|
||||
- GitHub being blocked in certain regions
|
||||
- Docker container network limitations
|
||||
|
||||
## Update Process
|
||||
The pricingService will:
|
||||
1. First attempt to download the latest version from GitHub
|
||||
2. If download fails, use this local copy as fallback
|
||||
3. Log a warning when using the fallback file
|
||||
|
||||
## Manual Update
|
||||
To manually update this file with the latest pricing data (if automation is unavailable):
|
||||
```bash
|
||||
curl -s https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json -o model_prices_and_context_window.json
|
||||
```
|
||||
|
||||
## File Format
|
||||
The file contains JSON data with model pricing information including:
|
||||
- Model names and identifiers
|
||||
- Input/output token costs
|
||||
- Context window sizes
|
||||
- Model capabilities
|
||||
|
||||
Last updated: 2025-08-10
|
||||
31356
backend/resources/model-pricing/model_prices_and_context_window.json
Normal file
31356
backend/resources/model-pricing/model_prices_and_context_window.json
Normal file
File diff suppressed because it is too large
Load Diff
55
deploy/.env.example
Normal file
55
deploy/.env.example
Normal file
@@ -0,0 +1,55 @@
|
||||
# =============================================================================
|
||||
# Sub2API Docker Environment Configuration
|
||||
# =============================================================================
|
||||
# Copy this file to .env and modify as needed:
|
||||
# cp .env.example .env
|
||||
# nano .env
|
||||
#
|
||||
# Then start with: docker-compose up -d
|
||||
# =============================================================================
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Server Configuration
|
||||
# -----------------------------------------------------------------------------
|
||||
# Bind address for host port mapping
|
||||
BIND_HOST=0.0.0.0
|
||||
|
||||
# Server port (exposed on host)
|
||||
SERVER_PORT=8080
|
||||
|
||||
# Server mode: release or debug
|
||||
SERVER_MODE=release
|
||||
|
||||
# Timezone
|
||||
TZ=Asia/Shanghai
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# PostgreSQL Configuration (REQUIRED)
|
||||
# -----------------------------------------------------------------------------
|
||||
POSTGRES_USER=sub2api
|
||||
POSTGRES_PASSWORD=change_this_secure_password
|
||||
POSTGRES_DB=sub2api
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Redis Configuration
|
||||
# -----------------------------------------------------------------------------
|
||||
# Leave empty for no password (default for local development)
|
||||
REDIS_PASSWORD=
|
||||
REDIS_DB=0
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Admin Account
|
||||
# -----------------------------------------------------------------------------
|
||||
# Email for the admin account
|
||||
ADMIN_EMAIL=admin@sub2api.local
|
||||
|
||||
# Password for admin account
|
||||
# Leave empty to auto-generate (will be shown in logs on first run)
|
||||
ADMIN_PASSWORD=
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# JWT Configuration
|
||||
# -----------------------------------------------------------------------------
|
||||
# Leave empty to auto-generate (recommended)
|
||||
JWT_SECRET=
|
||||
JWT_EXPIRE_HOUR=24
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user