First commit

This commit is contained in:
shaw
2025-12-18 13:50:39 +08:00
parent 569f4882e5
commit 642842c29e
218 changed files with 86902 additions and 0 deletions

74
.dockerignore Normal file
View File

@@ -0,0 +1,74 @@
# =============================================================================
# Docker Ignore File for Sub2API
# =============================================================================
# Git
.git
.gitignore
.gitattributes
# Documentation
*.md
!deploy/DOCKER.md
docs/
# IDE
.idea/
.vscode/
*.swp
*.swo
# OS files
.DS_Store
Thumbs.db
# Build artifacts
dist/
build/
# Node modules (will be installed in container)
frontend/node_modules/
node_modules/
# Go build cache (will be built in container)
backend/vendor/
# Test files
*_test.go
**/*.test.js
coverage/
.nyc_output/
# Environment files
.env
.env.*
!.env.example
# Local config
config.yaml
config.local.yaml
# Logs
*.log
logs/
# Temporary files
tmp/
temp/
*.tmp
# Deploy files (not needed in image)
deploy/install.sh
deploy/sub2api.service
deploy/sub2api-sudoers
# GoReleaser
.goreleaser.yaml
# GitHub
.github/
# Claude files
.claude/
issues/
CLAUDE.md

178
.github/workflows/release.yml vendored Normal file
View File

@@ -0,0 +1,178 @@
name: Release
on:
push:
tags:
- 'v*'
permissions:
contents: write
packages: write
jobs:
# Update VERSION file with tag version
update-version:
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Update VERSION file
run: |
VERSION=${GITHUB_REF#refs/tags/v}
echo "$VERSION" > backend/cmd/server/VERSION
echo "Updated VERSION file to: $VERSION"
- name: Upload VERSION artifact
uses: actions/upload-artifact@v4
with:
name: version-file
path: backend/cmd/server/VERSION
retention-days: 1
build-frontend:
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Setup Node.js
uses: actions/setup-node@v4
with:
node-version: '20'
cache: 'npm'
cache-dependency-path: frontend/package-lock.json
- name: Install dependencies
run: npm ci
working-directory: frontend
- name: Build frontend
run: npm run build
working-directory: frontend
- name: Upload frontend artifact
uses: actions/upload-artifact@v4
with:
name: frontend-dist
path: backend/internal/web/dist/
retention-days: 1
release:
needs: [update-version, build-frontend]
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Download VERSION artifact
uses: actions/download-artifact@v4
with:
name: version-file
path: backend/cmd/server/
- name: Download frontend artifact
uses: actions/download-artifact@v4
with:
name: frontend-dist
path: backend/internal/web/dist/
- name: Setup Go
uses: actions/setup-go@v5
with:
go-version: '1.24'
cache-dependency-path: backend/go.sum
- name: Run GoReleaser
uses: goreleaser/goreleaser-action@v6
with:
version: '~> v2'
args: release --clean
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
# ===========================================================================
# Docker Build and Push
# ===========================================================================
docker:
needs: [update-version, build-frontend]
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Download VERSION artifact
uses: actions/download-artifact@v4
with:
name: version-file
path: backend/cmd/server/
- name: Download frontend artifact
uses: actions/download-artifact@v4
with:
name: frontend-dist
path: backend/internal/web/dist/
# Extract version from tag
- name: Extract version
id: version
run: |
VERSION=${GITHUB_REF#refs/tags/v}
echo "version=$VERSION" >> $GITHUB_OUTPUT
echo "Version: $VERSION"
# Set up Docker Buildx for multi-platform builds
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
# Login to DockerHub
- name: Login to DockerHub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
# Extract metadata for Docker
- name: Extract Docker metadata
id: meta
uses: docker/metadata-action@v5
with:
images: |
weishaw/sub2api
tags: |
type=semver,pattern={{version}}
type=semver,pattern={{major}}.{{minor}}
type=semver,pattern={{major}}
type=raw,value=latest,enable={{is_default_branch}}
# Build and push Docker image
- name: Build and push Docker image
uses: docker/build-push-action@v5
with:
context: .
file: ./Dockerfile
platforms: linux/amd64,linux/arm64
push: true
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}
build-args: |
VERSION=${{ steps.version.outputs.version }}
COMMIT=${{ github.sha }}
DATE=${{ github.event.head_commit.timestamp }}
cache-from: type=gha
cache-to: type=gha,mode=max
# Update DockerHub description (optional)
- name: Update DockerHub description
uses: peter-evans/dockerhub-description@v4
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
repository: weishaw/sub2api
short-description: "Sub2API - AI API Gateway Platform"
readme-filepath: ./deploy/DOCKER.md

93
.gitignore vendored Normal file
View File

@@ -0,0 +1,93 @@
docs/claude-relay-service/
# ===================
# Go 后端
# ===================
# 二进制文件
*.exe
*.exe~
*.dll
*.so
*.dylib
backend/bin/
backend/server
backend/sub2api
# 测试覆盖率
*.out
coverage.html
# 依赖(使用 go mod
vendor/
# ===================
# Node.js / Vue 前端
# ===================
node_modules/
frontend/node_modules/
frontend/dist/
*.local
# 日志
npm-debug.log*
yarn-debug.log*
yarn-error.log*
pnpm-debug.log*
# ===================
# 环境配置
# ===================
.env
.env.local
.env.*.local
*.env
!.env.example
# ===================
# IDE / 编辑器
# ===================
.idea/
.vscode/
*.swp
*.swo
*~
.project
.settings/
.classpath
# ===================
# 操作系统
# ===================
.DS_Store
Thumbs.db
Desktop.ini
# ===================
# 临时文件
# ===================
tmp/
temp/
*.tmp
*.temp
*.log
*.bak
# ===================
# 构建产物
# ===================
dist/
build/
release/
# 后端嵌入的前端构建产物
backend/internal/web/dist/
# 后端运行时缓存数据
backend/data/
# ===================
# 其他
# ===================
tests
CLAUDE.md
.claude

85
.goreleaser.yaml Normal file
View File

@@ -0,0 +1,85 @@
version: 2
project_name: sub2api
before:
hooks:
- go mod tidy -C backend
builds:
- id: sub2api
dir: backend
main: ./cmd/server
binary: sub2api
env:
- CGO_ENABLED=0
goos:
- linux
- windows
- darwin
goarch:
- amd64
- arm64
ignore:
- goos: windows
goarch: arm64
ldflags:
- -s -w
- -X main.Commit={{.Commit}}
- -X main.Date={{.Date}}
- -X main.BuildType=release
archives:
- id: default
format: tar.gz
name_template: >-
{{ .ProjectName }}_{{ .Version }}_{{ .Os }}_{{ .Arch }}
format_overrides:
- goos: windows
format: zip
files:
- LICENSE*
- README*
- deploy/*
checksum:
name_template: 'checksums.txt'
algorithm: sha256
changelog:
# 禁用自动 changelog完全使用 tag 消息
disable: true
release:
github:
owner: Wei-Shaw
name: sub2api
draft: false
prerelease: auto
name_template: "v{{.Version}}"
# 完全使用 tag 消息作为 release 内容
header: |
## Sub2API {{.Version}}
> AI API Gateway Platform - 将 AI 订阅配额分发和管理
{{ .TagBody }}
footer: |
---
## 📥 Installation
**One-line install (Linux):**
```bash
curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/install.sh | sudo bash
```
**Manual download:**
Download the appropriate archive for your platform from the assets below.
## 📚 Documentation
- [GitHub Repository](https://github.com/Wei-Shaw/sub2api)
- [Installation Guide](https://github.com/Wei-Shaw/sub2api/blob/main/deploy/README.md)

96
Dockerfile Normal file
View File

@@ -0,0 +1,96 @@
# =============================================================================
# Sub2API Multi-Stage Dockerfile
# =============================================================================
# Stage 1: Build frontend
# Stage 2: Build Go backend with embedded frontend
# Stage 3: Final minimal image
# =============================================================================
# -----------------------------------------------------------------------------
# Stage 1: Frontend Builder
# -----------------------------------------------------------------------------
FROM node:20-alpine AS frontend-builder
WORKDIR /app/frontend
# Install dependencies first (better caching)
COPY frontend/package*.json ./
RUN npm ci
# Copy frontend source and build
COPY frontend/ ./
RUN npm run build
# -----------------------------------------------------------------------------
# Stage 2: Backend Builder
# -----------------------------------------------------------------------------
FROM golang:1.24-alpine AS backend-builder
# Build arguments for version info (set by CI)
ARG VERSION=docker
ARG COMMIT=docker
ARG DATE
# Install build dependencies
RUN apk add --no-cache git ca-certificates tzdata
WORKDIR /app/backend
# Copy go mod files first (better caching)
COPY backend/go.mod backend/go.sum ./
RUN go mod download
# Copy frontend dist from previous stage
COPY --from=frontend-builder /app/frontend/../backend/internal/web/dist ./internal/web/dist
# Copy backend source
COPY backend/ ./
# Build the binary (BuildType=release for CI builds)
RUN CGO_ENABLED=0 GOOS=linux go build \
-ldflags="-s -w -X main.Commit=${COMMIT} -X main.Date=${DATE:-$(date -u +%Y-%m-%dT%H:%M:%SZ)} -X main.BuildType=release" \
-o /app/sub2api \
./cmd/server
# -----------------------------------------------------------------------------
# Stage 3: Final Runtime Image
# -----------------------------------------------------------------------------
FROM alpine:3.19
# Labels
LABEL maintainer="Wei-Shaw <github.com/Wei-Shaw>"
LABEL description="Sub2API - AI API Gateway Platform"
LABEL org.opencontainers.image.source="https://github.com/Wei-Shaw/sub2api"
# Install runtime dependencies
RUN apk add --no-cache \
ca-certificates \
tzdata \
curl \
&& rm -rf /var/cache/apk/*
# Create non-root user
RUN addgroup -g 1000 sub2api && \
adduser -u 1000 -G sub2api -s /bin/sh -D sub2api
# Set working directory
WORKDIR /app
# Copy binary from builder
COPY --from=backend-builder /app/sub2api /app/sub2api
# Create data directory
RUN mkdir -p /app/data && chown -R sub2api:sub2api /app
# Switch to non-root user
USER sub2api
# Expose port (can be overridden by SERVER_PORT env var)
EXPOSE 8080
# Health check
HEALTHCHECK --interval=30s --timeout=10s --start-period=10s --retries=3 \
CMD curl -f http://localhost:${SERVER_PORT:-8080}/health || exit 1
# Run the application
ENTRYPOINT ["/app/sub2api"]

318
README.md Normal file
View File

@@ -0,0 +1,318 @@
# Sub2API
<div align="center">
[![Go](https://img.shields.io/badge/Go-1.21+-00ADD8.svg)](https://golang.org/)
[![Vue](https://img.shields.io/badge/Vue-3.4+-4FC08D.svg)](https://vuejs.org/)
[![PostgreSQL](https://img.shields.io/badge/PostgreSQL-15+-336791.svg)](https://www.postgresql.org/)
[![Redis](https://img.shields.io/badge/Redis-7+-DC382D.svg)](https://redis.io/)
[![Docker](https://img.shields.io/badge/Docker-Ready-2496ED.svg)](https://www.docker.com/)
**AI API Gateway Platform for Subscription Quota Distribution**
English | [中文](README_CN.md)
</div>
---
## Overview
Sub2API is an AI API gateway platform designed to distribute and manage API quotas from AI product subscriptions (like Claude Code $200/month). Users can access upstream AI services through platform-generated API Keys, while the platform handles authentication, billing, load balancing, and request forwarding.
## Features
- **Multi-Account Management** - Support multiple upstream account types (OAuth, API Key)
- **API Key Distribution** - Generate and manage API Keys for users
- **Precise Billing** - Token-level usage tracking and cost calculation
- **Smart Scheduling** - Intelligent account selection with sticky sessions
- **Concurrency Control** - Per-user and per-account concurrency limits
- **Rate Limiting** - Configurable request and token rate limits
- **Admin Dashboard** - Web interface for monitoring and management
## Tech Stack
| Component | Technology |
|-----------|------------|
| Backend | Go 1.21+, Gin, GORM |
| Frontend | Vue 3.4+, Vite 5+, TailwindCSS |
| Database | PostgreSQL 15+ |
| Cache/Queue | Redis 7+ |
---
## Deployment
### Method 1: Script Installation (Recommended)
One-click installation script that downloads pre-built binaries from GitHub Releases.
#### Prerequisites
- Linux server (amd64 or arm64)
- PostgreSQL 15+ (installed and running)
- Redis 7+ (installed and running)
- Root privileges
#### Installation Steps
```bash
# Download and run the installation script
curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/install.sh | sudo bash
```
The script will:
1. Detect your system architecture
2. Download the latest release
3. Install binary to `/opt/sub2api`
4. Create systemd service
5. Configure system user and permissions
#### Post-Installation
```bash
# 1. Start the service
sudo systemctl start sub2api
# 2. Enable auto-start on boot
sudo systemctl enable sub2api
# 3. Open Setup Wizard in browser
# http://YOUR_SERVER_IP:8080
```
The Setup Wizard will guide you through:
- Database configuration
- Redis configuration
- Admin account creation
#### Upgrade
You can upgrade directly from the **Admin Dashboard** by clicking the **Check for Updates** button in the top-left corner.
The web interface will:
- Check for new versions automatically
- Download and apply updates with one click
- Support rollback if needed
#### Useful Commands
```bash
# Check status
sudo systemctl status sub2api
# View logs
sudo journalctl -u sub2api -f
# Restart service
sudo systemctl restart sub2api
# Uninstall
curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/install.sh | sudo bash -s uninstall
```
---
### Method 2: Docker Compose
Deploy with Docker Compose, including PostgreSQL and Redis containers.
#### Prerequisites
- Docker 20.10+
- Docker Compose v2+
#### Installation Steps
```bash
# 1. Clone the repository
git clone https://github.com/Wei-Shaw/sub2api.git
cd sub2api
# 2. Enter the deploy directory
cd deploy
# 3. Copy environment configuration
cp .env.example .env
# 4. Edit configuration (set your passwords)
nano .env
```
**Required configuration in `.env`:**
```bash
# PostgreSQL password (REQUIRED - change this!)
POSTGRES_PASSWORD=your_secure_password_here
# Optional: Admin account
ADMIN_EMAIL=admin@example.com
ADMIN_PASSWORD=your_admin_password
# Optional: Custom port
SERVER_PORT=8080
```
```bash
# 5. Start all services
docker-compose up -d
# 6. Check status
docker-compose ps
# 7. View logs
docker-compose logs -f sub2api
```
#### Access
Open `http://YOUR_SERVER_IP:8080` in your browser.
#### Upgrade
```bash
# Pull latest image and recreate container
docker-compose pull
docker-compose up -d
```
#### Useful Commands
```bash
# Stop all services
docker-compose down
# Restart
docker-compose restart
# View all logs
docker-compose logs -f
```
---
### Method 3: Build from Source
Build and run from source code for development or customization.
#### Prerequisites
- Go 1.21+
- Node.js 18+
- PostgreSQL 15+
- Redis 7+
#### Build Steps
```bash
# 1. Clone the repository
git clone https://github.com/Wei-Shaw/sub2api.git
cd sub2api
# 2. Build backend
cd backend
go build -o sub2api ./cmd/server
# 3. Build frontend
cd ../frontend
npm install
npm run build
# 4. Copy frontend build to backend (for embedding)
cp -r dist ../backend/internal/web/
# 5. Create configuration file
cd ../backend
cp ../deploy/config.example.yaml ./config.yaml
# 6. Edit configuration
nano config.yaml
```
**Key configuration in `config.yaml`:**
```yaml
server:
host: "0.0.0.0"
port: 8080
mode: "release"
database:
host: "localhost"
port: 5432
user: "postgres"
password: "your_password"
dbname: "sub2api"
redis:
host: "localhost"
port: 6379
password: ""
jwt:
secret: "change-this-to-a-secure-random-string"
expire_hour: 24
default:
admin_email: "admin@example.com"
admin_password: "admin123"
```
```bash
# 7. Run the application
./sub2api
```
#### Development Mode
```bash
# Backend (with hot reload)
cd backend
go run ./cmd/server
# Frontend (with hot reload)
cd frontend
npm run dev
```
---
## Project Structure
```
sub2api/
├── backend/ # Go backend service
│ ├── cmd/server/ # Application entry
│ ├── internal/ # Internal modules
│ │ ├── config/ # Configuration
│ │ ├── model/ # Data models
│ │ ├── service/ # Business logic
│ │ ├── handler/ # HTTP handlers
│ │ └── gateway/ # API gateway core
│ └── resources/ # Static resources
├── frontend/ # Vue 3 frontend
│ └── src/
│ ├── api/ # API calls
│ ├── stores/ # State management
│ ├── views/ # Page components
│ └── components/ # Reusable components
└── deploy/ # Deployment files
├── docker-compose.yml # Docker Compose configuration
├── .env.example # Environment variables for Docker Compose
├── config.example.yaml # Full config file for binary deployment
└── install.sh # One-click installation script
```
## License
MIT License
---
<div align="center">
**If you find this project useful, please give it a star!**
</div>

318
README_CN.md Normal file
View File

@@ -0,0 +1,318 @@
# Sub2API
<div align="center">
[![Go](https://img.shields.io/badge/Go-1.21+-00ADD8.svg)](https://golang.org/)
[![Vue](https://img.shields.io/badge/Vue-3.4+-4FC08D.svg)](https://vuejs.org/)
[![PostgreSQL](https://img.shields.io/badge/PostgreSQL-15+-336791.svg)](https://www.postgresql.org/)
[![Redis](https://img.shields.io/badge/Redis-7+-DC382D.svg)](https://redis.io/)
[![Docker](https://img.shields.io/badge/Docker-Ready-2496ED.svg)](https://www.docker.com/)
**AI API 网关平台 - 订阅配额分发管理**
[English](README.md) | 中文
</div>
---
## 项目概述
Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅(如 Claude Code $200/月)的 API 配额。用户通过平台生成的 API Key 调用上游 AI 服务,平台负责鉴权、计费、负载均衡和请求转发。
## 核心功能
- **多账号管理** - 支持多种上游账号类型OAuth、API Key
- **API Key 分发** - 为用户生成和管理 API Key
- **精确计费** - Token 级别的用量追踪和成本计算
- **智能调度** - 智能账号选择,支持粘性会话
- **并发控制** - 用户级和账号级并发限制
- **速率限制** - 可配置的请求和 Token 速率限制
- **管理后台** - Web 界面进行监控和管理
## 技术栈
| 组件 | 技术 |
|------|------|
| 后端 | Go 1.21+, Gin, GORM |
| 前端 | Vue 3.4+, Vite 5+, TailwindCSS |
| 数据库 | PostgreSQL 15+ |
| 缓存/队列 | Redis 7+ |
---
## 部署方式
### 方式一:脚本安装(推荐)
一键安装脚本,自动从 GitHub Releases 下载预编译的二进制文件。
#### 前置条件
- Linux 服务器amd64 或 arm64
- PostgreSQL 15+(已安装并运行)
- Redis 7+(已安装并运行)
- Root 权限
#### 安装步骤
```bash
# 下载并运行安装脚本
curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/install.sh | sudo bash
```
脚本会自动:
1. 检测系统架构
2. 下载最新版本
3. 安装二进制文件到 `/opt/sub2api`
4. 创建 systemd 服务
5. 配置系统用户和权限
#### 安装后配置
```bash
# 1. 启动服务
sudo systemctl start sub2api
# 2. 设置开机自启
sudo systemctl enable sub2api
# 3. 在浏览器中打开设置向导
# http://你的服务器IP:8080
```
设置向导将引导你完成:
- 数据库配置
- Redis 配置
- 管理员账号创建
#### 升级
可以直接在 **管理后台** 左上角点击 **检测更新** 按钮进行在线升级。
网页升级功能支持:
- 自动检测新版本
- 一键下载并应用更新
- 支持回滚
#### 常用命令
```bash
# 查看状态
sudo systemctl status sub2api
# 查看日志
sudo journalctl -u sub2api -f
# 重启服务
sudo systemctl restart sub2api
# 卸载
curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/install.sh | sudo bash -s uninstall
```
---
### 方式二Docker Compose
使用 Docker Compose 部署,包含 PostgreSQL 和 Redis 容器。
#### 前置条件
- Docker 20.10+
- Docker Compose v2+
#### 安装步骤
```bash
# 1. 克隆仓库
git clone https://github.com/Wei-Shaw/sub2api.git
cd sub2api
# 2. 进入 deploy 目录
cd deploy
# 3. 复制环境配置文件
cp .env.example .env
# 4. 编辑配置(设置密码等)
nano .env
```
**`.env` 必须配置项:**
```bash
# PostgreSQL 密码(必须修改!)
POSTGRES_PASSWORD=your_secure_password_here
# 可选:管理员账号
ADMIN_EMAIL=admin@example.com
ADMIN_PASSWORD=your_admin_password
# 可选:自定义端口
SERVER_PORT=8080
```
```bash
# 5. 启动所有服务
docker-compose up -d
# 6. 查看状态
docker-compose ps
# 7. 查看日志
docker-compose logs -f sub2api
```
#### 访问
在浏览器中打开 `http://你的服务器IP:8080`
#### 升级
```bash
# 拉取最新镜像并重建容器
docker-compose pull
docker-compose up -d
```
#### 常用命令
```bash
# 停止所有服务
docker-compose down
# 重启
docker-compose restart
# 查看所有日志
docker-compose logs -f
```
---
### 方式三:源码编译
从源码编译安装,适合开发或定制需求。
#### 前置条件
- Go 1.21+
- Node.js 18+
- PostgreSQL 15+
- Redis 7+
#### 编译步骤
```bash
# 1. 克隆仓库
git clone https://github.com/Wei-Shaw/sub2api.git
cd sub2api
# 2. 编译后端
cd backend
go build -o sub2api ./cmd/server
# 3. 编译前端
cd ../frontend
npm install
npm run build
# 4. 复制前端构建产物到后端(用于嵌入)
cp -r dist ../backend/internal/web/
# 5. 创建配置文件
cd ../backend
cp ../deploy/config.example.yaml ./config.yaml
# 6. 编辑配置
nano config.yaml
```
**`config.yaml` 关键配置:**
```yaml
server:
host: "0.0.0.0"
port: 8080
mode: "release"
database:
host: "localhost"
port: 5432
user: "postgres"
password: "your_password"
dbname: "sub2api"
redis:
host: "localhost"
port: 6379
password: ""
jwt:
secret: "change-this-to-a-secure-random-string"
expire_hour: 24
default:
admin_email: "admin@example.com"
admin_password: "admin123"
```
```bash
# 7. 运行应用
./sub2api
```
#### 开发模式
```bash
# 后端(支持热重载)
cd backend
go run ./cmd/server
# 前端(支持热重载)
cd frontend
npm run dev
```
---
## 项目结构
```
sub2api/
├── backend/ # Go 后端服务
│ ├── cmd/server/ # 应用入口
│ ├── internal/ # 内部模块
│ │ ├── config/ # 配置管理
│ │ ├── model/ # 数据模型
│ │ ├── service/ # 业务逻辑
│ │ ├── handler/ # HTTP 处理器
│ │ └── gateway/ # API 网关核心
│ └── resources/ # 静态资源
├── frontend/ # Vue 3 前端
│ └── src/
│ ├── api/ # API 调用
│ ├── stores/ # 状态管理
│ ├── views/ # 页面组件
│ └── components/ # 通用组件
└── deploy/ # 部署文件
├── docker-compose.yml # Docker Compose 配置
├── .env.example # Docker Compose 环境变量
├── config.example.yaml # 二进制部署完整配置文件
└── install.sh # 一键安装脚本
```
## 许可证
MIT License
---
<div align="center">
**如果觉得有用,请给个 Star 支持一下!**
</div>

24
backend/Dockerfile Normal file
View File

@@ -0,0 +1,24 @@
FROM golang:1.21-alpine
WORKDIR /app
# 安装必要的工具
RUN apk add --no-cache git
# 复制go.mod和go.sum
COPY go.mod go.sum ./
# 下载依赖
RUN go mod download
# 复制源代码
COPY . .
# 构建应用
RUN go build -o main cmd/server/main.go
# 暴露端口
EXPOSE 8080
# 运行应用
CMD ["./main"]

View File

@@ -0,0 +1 @@
0.1.1

470
backend/cmd/server/main.go Normal file
View File

@@ -0,0 +1,470 @@
package main
import (
"context"
_ "embed"
"flag"
"log"
"net/http"
"os"
"os/signal"
"strings"
"syscall"
"time"
"sub2api/internal/config"
"sub2api/internal/handler"
"sub2api/internal/middleware"
"sub2api/internal/model"
"sub2api/internal/pkg/timezone"
"sub2api/internal/repository"
"sub2api/internal/service"
"sub2api/internal/setup"
"sub2api/internal/web"
"github.com/gin-gonic/gin"
"github.com/redis/go-redis/v9"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
//go:embed VERSION
var embeddedVersion string
// Build-time variables (can be set by ldflags)
var (
Version = ""
Commit = "unknown"
Date = "unknown"
BuildType = "source" // "source" for manual builds, "release" for CI builds (set by ldflags)
)
func init() {
// Read version from embedded VERSION file
Version = strings.TrimSpace(embeddedVersion)
if Version == "" {
Version = "0.0.0-dev"
}
}
func main() {
// Parse command line flags
setupMode := flag.Bool("setup", false, "Run setup wizard in CLI mode")
showVersion := flag.Bool("version", false, "Show version information")
flag.Parse()
if *showVersion {
log.Printf("Sub2API %s (commit: %s, built: %s)\n", Version, Commit, Date)
return
}
// CLI setup mode
if *setupMode {
if err := setup.RunCLI(); err != nil {
log.Fatalf("Setup failed: %v", err)
}
return
}
// Check if setup is needed
if setup.NeedsSetup() {
// Check if auto-setup is enabled (for Docker deployment)
if setup.AutoSetupEnabled() {
log.Println("Auto setup mode enabled...")
if err := setup.AutoSetupFromEnv(); err != nil {
log.Fatalf("Auto setup failed: %v", err)
}
// Continue to main server after auto-setup
} else {
log.Println("First run detected, starting setup wizard...")
runSetupServer()
return
}
}
// Normal server mode
runMainServer()
}
func runSetupServer() {
r := gin.New()
r.Use(gin.Recovery())
r.Use(middleware.CORS())
// Register setup routes
setup.RegisterRoutes(r)
// Serve embedded frontend if available
if web.HasEmbeddedFrontend() {
r.Use(web.ServeEmbeddedFrontend())
}
addr := ":8080"
log.Printf("Setup wizard available at http://localhost%s", addr)
log.Println("Complete the setup wizard to configure Sub2API")
if err := r.Run(addr); err != nil {
log.Fatalf("Failed to start setup server: %v", err)
}
}
func runMainServer() {
// 加载配置
cfg, err := config.Load()
if err != nil {
log.Fatalf("Failed to load config: %v", err)
}
// 初始化时区(类似 PHP 的 date_default_timezone_set
if err := timezone.Init(cfg.Timezone); err != nil {
log.Fatalf("Failed to initialize timezone: %v", err)
}
// 初始化数据库
db, err := initDB(cfg)
if err != nil {
log.Fatalf("Failed to connect to database: %v", err)
}
// 初始化Redis
rdb := initRedis(cfg)
// 初始化Repository
repos := repository.NewRepositories(db)
// 初始化Service
services := service.NewServices(repos, rdb, cfg)
// 初始化Handler
buildInfo := handler.BuildInfo{
Version: Version,
BuildType: BuildType,
}
handlers := handler.NewHandlers(services, repos, rdb, buildInfo)
// 设置Gin模式
if cfg.Server.Mode == "release" {
gin.SetMode(gin.ReleaseMode)
}
// 创建路由
r := gin.New()
r.Use(gin.Recovery())
r.Use(middleware.Logger())
r.Use(middleware.CORS())
// 注册路由
registerRoutes(r, handlers, services, repos)
// Serve embedded frontend if available
if web.HasEmbeddedFrontend() {
r.Use(web.ServeEmbeddedFrontend())
}
// 启动服务器
srv := &http.Server{
Addr: cfg.Server.Address(),
Handler: r,
// ReadHeaderTimeout: 读取请求头的超时时间,防止慢速请求头攻击
ReadHeaderTimeout: time.Duration(cfg.Server.ReadHeaderTimeout) * time.Second,
// IdleTimeout: 空闲连接超时时间,释放不活跃的连接资源
IdleTimeout: time.Duration(cfg.Server.IdleTimeout) * time.Second,
// 注意:不设置 WriteTimeout因为流式响应可能持续十几分钟
// 不设置 ReadTimeout因为大请求体可能需要较长时间读取
}
// 优雅关闭
go func() {
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
log.Fatalf("Failed to start server: %v", err)
}
}()
log.Printf("Server started on %s", cfg.Server.Address())
// 等待中断信号
quit := make(chan os.Signal, 1)
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
<-quit
log.Println("Shutting down server...")
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := srv.Shutdown(ctx); err != nil {
log.Fatalf("Server forced to shutdown: %v", err)
}
log.Println("Server exited")
}
func initDB(cfg *config.Config) (*gorm.DB, error) {
gormConfig := &gorm.Config{}
if cfg.Server.Mode == "debug" {
gormConfig.Logger = logger.Default.LogMode(logger.Info)
}
// 使用带时区的 DSN 连接数据库
db, err := gorm.Open(postgres.Open(cfg.Database.DSNWithTimezone(cfg.Timezone)), gormConfig)
if err != nil {
return nil, err
}
// 自动迁移(开发环境)
if cfg.Server.Mode == "debug" {
if err := model.AutoMigrate(db); err != nil {
return nil, err
}
}
return db, nil
}
func initRedis(cfg *config.Config) *redis.Client {
return redis.NewClient(&redis.Options{
Addr: cfg.Redis.Address(),
Password: cfg.Redis.Password,
DB: cfg.Redis.DB,
})
}
func registerRoutes(r *gin.Engine, h *handler.Handlers, s *service.Services, repos *repository.Repositories) {
// 健康检查
r.GET("/health", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"status": "ok"})
})
// API v1
v1 := r.Group("/api/v1")
{
// 公开接口
auth := v1.Group("/auth")
{
auth.POST("/register", h.Auth.Register)
auth.POST("/login", h.Auth.Login)
auth.POST("/send-verify-code", h.Auth.SendVerifyCode)
}
// 公开设置(无需认证)
settings := v1.Group("/settings")
{
settings.GET("/public", h.Setting.GetPublicSettings)
}
// 需要认证的接口
authenticated := v1.Group("")
authenticated.Use(middleware.JWTAuth(s.Auth, repos.User))
{
// 当前用户信息
authenticated.GET("/auth/me", h.Auth.GetCurrentUser)
// 用户接口
user := authenticated.Group("/user")
{
user.GET("/profile", h.User.GetProfile)
user.PUT("/password", h.User.ChangePassword)
}
// API Key管理
keys := authenticated.Group("/keys")
{
keys.GET("", h.APIKey.List)
keys.GET("/:id", h.APIKey.GetByID)
keys.POST("", h.APIKey.Create)
keys.PUT("/:id", h.APIKey.Update)
keys.DELETE("/:id", h.APIKey.Delete)
}
// 用户可用分组(非管理员接口)
groups := authenticated.Group("/groups")
{
groups.GET("/available", h.APIKey.GetAvailableGroups)
}
// 使用记录
usage := authenticated.Group("/usage")
{
usage.GET("", h.Usage.List)
usage.GET("/:id", h.Usage.GetByID)
usage.GET("/stats", h.Usage.Stats)
// User dashboard endpoints
usage.GET("/dashboard/stats", h.Usage.DashboardStats)
usage.GET("/dashboard/trend", h.Usage.DashboardTrend)
usage.GET("/dashboard/models", h.Usage.DashboardModels)
usage.POST("/dashboard/api-keys-usage", h.Usage.DashboardApiKeysUsage)
}
// 卡密兑换
redeem := authenticated.Group("/redeem")
{
redeem.POST("", h.Redeem.Redeem)
redeem.GET("/history", h.Redeem.GetHistory)
}
// 用户订阅
subscriptions := authenticated.Group("/subscriptions")
{
subscriptions.GET("", h.Subscription.List)
subscriptions.GET("/active", h.Subscription.GetActive)
subscriptions.GET("/progress", h.Subscription.GetProgress)
subscriptions.GET("/summary", h.Subscription.GetSummary)
}
}
// 管理员接口
admin := v1.Group("/admin")
admin.Use(middleware.JWTAuth(s.Auth, repos.User), middleware.AdminOnly())
{
// 仪表盘
dashboard := admin.Group("/dashboard")
{
dashboard.GET("/stats", h.Admin.Dashboard.GetStats)
dashboard.GET("/realtime", h.Admin.Dashboard.GetRealtimeMetrics)
dashboard.GET("/trend", h.Admin.Dashboard.GetUsageTrend)
dashboard.GET("/models", h.Admin.Dashboard.GetModelStats)
dashboard.GET("/api-keys-trend", h.Admin.Dashboard.GetApiKeyUsageTrend)
dashboard.GET("/users-trend", h.Admin.Dashboard.GetUserUsageTrend)
dashboard.POST("/users-usage", h.Admin.Dashboard.GetBatchUsersUsage)
dashboard.POST("/api-keys-usage", h.Admin.Dashboard.GetBatchApiKeysUsage)
}
// 用户管理
users := admin.Group("/users")
{
users.GET("", h.Admin.User.List)
users.GET("/:id", h.Admin.User.GetByID)
users.POST("", h.Admin.User.Create)
users.PUT("/:id", h.Admin.User.Update)
users.DELETE("/:id", h.Admin.User.Delete)
users.POST("/:id/balance", h.Admin.User.UpdateBalance)
users.GET("/:id/api-keys", h.Admin.User.GetUserAPIKeys)
users.GET("/:id/usage", h.Admin.User.GetUserUsage)
}
// 分组管理
groups := admin.Group("/groups")
{
groups.GET("", h.Admin.Group.List)
groups.GET("/all", h.Admin.Group.GetAll)
groups.GET("/:id", h.Admin.Group.GetByID)
groups.POST("", h.Admin.Group.Create)
groups.PUT("/:id", h.Admin.Group.Update)
groups.DELETE("/:id", h.Admin.Group.Delete)
groups.GET("/:id/stats", h.Admin.Group.GetStats)
groups.GET("/:id/api-keys", h.Admin.Group.GetGroupAPIKeys)
}
// 账号管理
accounts := admin.Group("/accounts")
{
accounts.GET("", h.Admin.Account.List)
accounts.GET("/:id", h.Admin.Account.GetByID)
accounts.POST("", h.Admin.Account.Create)
accounts.PUT("/:id", h.Admin.Account.Update)
accounts.DELETE("/:id", h.Admin.Account.Delete)
accounts.POST("/:id/test", h.Admin.Account.Test)
accounts.POST("/:id/refresh", h.Admin.Account.Refresh)
accounts.GET("/:id/stats", h.Admin.Account.GetStats)
accounts.POST("/:id/clear-error", h.Admin.Account.ClearError)
accounts.GET("/:id/usage", h.Admin.Account.GetUsage)
accounts.GET("/:id/today-stats", h.Admin.Account.GetTodayStats)
accounts.POST("/:id/clear-rate-limit", h.Admin.Account.ClearRateLimit)
accounts.POST("/:id/schedulable", h.Admin.Account.SetSchedulable)
accounts.POST("/batch", h.Admin.Account.BatchCreate)
// OAuth routes
accounts.POST("/generate-auth-url", h.Admin.OAuth.GenerateAuthURL)
accounts.POST("/generate-setup-token-url", h.Admin.OAuth.GenerateSetupTokenURL)
accounts.POST("/exchange-code", h.Admin.OAuth.ExchangeCode)
accounts.POST("/exchange-setup-token-code", h.Admin.OAuth.ExchangeSetupTokenCode)
accounts.POST("/cookie-auth", h.Admin.OAuth.CookieAuth)
accounts.POST("/setup-token-cookie-auth", h.Admin.OAuth.SetupTokenCookieAuth)
}
// 代理管理
proxies := admin.Group("/proxies")
{
proxies.GET("", h.Admin.Proxy.List)
proxies.GET("/all", h.Admin.Proxy.GetAll)
proxies.GET("/:id", h.Admin.Proxy.GetByID)
proxies.POST("", h.Admin.Proxy.Create)
proxies.PUT("/:id", h.Admin.Proxy.Update)
proxies.DELETE("/:id", h.Admin.Proxy.Delete)
proxies.POST("/:id/test", h.Admin.Proxy.Test)
proxies.GET("/:id/stats", h.Admin.Proxy.GetStats)
proxies.GET("/:id/accounts", h.Admin.Proxy.GetProxyAccounts)
proxies.POST("/batch", h.Admin.Proxy.BatchCreate)
}
// 卡密管理
codes := admin.Group("/redeem-codes")
{
codes.GET("", h.Admin.Redeem.List)
codes.GET("/stats", h.Admin.Redeem.GetStats)
codes.GET("/export", h.Admin.Redeem.Export)
codes.GET("/:id", h.Admin.Redeem.GetByID)
codes.POST("/generate", h.Admin.Redeem.Generate)
codes.DELETE("/:id", h.Admin.Redeem.Delete)
codes.POST("/batch-delete", h.Admin.Redeem.BatchDelete)
codes.POST("/:id/expire", h.Admin.Redeem.Expire)
}
// 系统设置
adminSettings := admin.Group("/settings")
{
adminSettings.GET("", h.Admin.Setting.GetSettings)
adminSettings.PUT("", h.Admin.Setting.UpdateSettings)
adminSettings.POST("/test-smtp", h.Admin.Setting.TestSmtpConnection)
adminSettings.POST("/send-test-email", h.Admin.Setting.SendTestEmail)
}
// 系统管理
system := admin.Group("/system")
{
system.GET("/version", h.Admin.System.GetVersion)
system.GET("/check-updates", h.Admin.System.CheckUpdates)
system.POST("/update", h.Admin.System.PerformUpdate)
system.POST("/rollback", h.Admin.System.Rollback)
system.POST("/restart", h.Admin.System.RestartService)
}
// 订阅管理
subscriptions := admin.Group("/subscriptions")
{
subscriptions.GET("", h.Admin.Subscription.List)
subscriptions.GET("/:id", h.Admin.Subscription.GetByID)
subscriptions.GET("/:id/progress", h.Admin.Subscription.GetProgress)
subscriptions.POST("/assign", h.Admin.Subscription.Assign)
subscriptions.POST("/bulk-assign", h.Admin.Subscription.BulkAssign)
subscriptions.POST("/:id/extend", h.Admin.Subscription.Extend)
subscriptions.DELETE("/:id", h.Admin.Subscription.Revoke)
}
// 分组下的订阅列表
admin.GET("/groups/:id/subscriptions", h.Admin.Subscription.ListByGroup)
// 用户下的订阅列表
admin.GET("/users/:id/subscriptions", h.Admin.Subscription.ListByUser)
// 使用记录管理
usage := admin.Group("/usage")
{
usage.GET("", h.Admin.Usage.List)
usage.GET("/stats", h.Admin.Usage.Stats)
usage.GET("/search-users", h.Admin.Usage.SearchUsers)
usage.GET("/search-api-keys", h.Admin.Usage.SearchApiKeys)
}
}
}
// API网关Claude API兼容
gateway := r.Group("/v1")
gateway.Use(middleware.ApiKeyAuthWithSubscription(s.ApiKey, s.Subscription))
{
gateway.POST("/messages", h.Gateway.Messages)
gateway.GET("/models", h.Gateway.Models)
gateway.GET("/usage", h.Gateway.Usage)
}
}

38
backend/config.yaml Normal file
View File

@@ -0,0 +1,38 @@
server:
host: "0.0.0.0"
port: 8080
mode: "debug" # debug/release
database:
host: "127.0.0.1"
port: 5432
user: "postgres"
password: "XZeRr7nkjHWhm8fw"
dbname: "sub2api"
sslmode: "disable"
redis:
host: "127.0.0.1"
port: 6379
password: ""
db: 0
jwt:
secret: "your-secret-key-change-in-production"
expire_hour: 24
default:
admin_email: "admin@sub2api.com"
admin_password: "admin123"
user_concurrency: 5
user_balance: 0
api_key_prefix: "sk-"
rate_multiplier: 1.0
# Timezone configuration (similar to PHP's date_default_timezone_set)
# This affects ALL time operations:
# - Database timestamps
# - Usage statistics "today" boundary
# - Subscription expiry times
# Common values: Asia/Shanghai, America/New_York, Europe/London, UTC
timezone: "Asia/Shanghai"

74
backend/go.mod Normal file
View File

@@ -0,0 +1,74 @@
module sub2api
go 1.24.0
toolchain go1.24.11
require (
github.com/gin-gonic/gin v1.9.1
github.com/golang-jwt/jwt/v5 v5.2.0
github.com/google/uuid v1.6.0
github.com/imroc/req/v3 v3.56.0
github.com/lib/pq v1.10.9
github.com/redis/go-redis/v9 v9.3.0
github.com/spf13/viper v1.18.2
golang.org/x/crypto v0.44.0
golang.org/x/term v0.37.0
gopkg.in/yaml.v3 v3.0.1
gorm.io/driver/postgres v1.5.4
gorm.io/gorm v1.25.5
)
require (
github.com/andybalholm/brotli v1.2.0 // indirect
github.com/bytedance/sonic v1.9.1 // indirect
github.com/cespare/xxhash/v2 v2.2.0 // indirect
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/fsnotify/fsnotify v1.7.0 // indirect
github.com/gabriel-vasile/mimetype v1.4.2 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-playground/validator/v10 v10.14.0 // indirect
github.com/goccy/go-json v0.10.2 // indirect
github.com/google/go-querystring v1.1.0 // indirect
github.com/hashicorp/hcl v1.0.0 // indirect
github.com/icholy/digest v1.1.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
github.com/jackc/pgx/v5 v5.4.3 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/compress v1.18.1 // indirect
github.com/klauspost/cpuid/v2 v2.2.4 // indirect
github.com/leodido/go-urn v1.2.4 // indirect
github.com/magiconair/properties v1.8.7 // indirect
github.com/mattn/go-isatty v0.0.19 // indirect
github.com/mitchellh/mapstructure v1.5.0 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.1.0 // indirect
github.com/quic-go/qpack v0.5.1 // indirect
github.com/quic-go/quic-go v0.56.0 // indirect
github.com/refraction-networking/utls v1.8.1 // indirect
github.com/sagikazarmark/locafero v0.4.0 // indirect
github.com/sagikazarmark/slog-shim v0.1.0 // indirect
github.com/sourcegraph/conc v0.3.0 // indirect
github.com/spf13/afero v1.11.0 // indirect
github.com/spf13/cast v1.6.0 // indirect
github.com/spf13/pflag v1.0.5 // indirect
github.com/subosito/gotenv v1.6.0 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.11 // indirect
go.uber.org/atomic v1.9.0 // indirect
go.uber.org/multierr v1.9.0 // indirect
golang.org/x/arch v0.3.0 // indirect
golang.org/x/exp v0.0.0-20230905200255-921286631fa9 // indirect
golang.org/x/net v0.47.0 // indirect
golang.org/x/sys v0.38.0 // indirect
golang.org/x/text v0.31.0 // indirect
google.golang.org/protobuf v1.31.0 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect
)

187
backend/go.sum Normal file
View File

@@ -0,0 +1,187 @@
github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ=
github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY=
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0=
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44=
github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA=
github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM=
github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU=
github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA=
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg=
github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU=
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg/+t63MyGU2n5js=
github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU=
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/golang-jwt/jwt/v5 v5.2.0 h1:d/ix8ftRUorsN+5eMIlF4T6J8CAt9rch3My2winC1Jw=
github.com/golang-jwt/jwt/v5 v5.2.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8=
github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4=
github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ=
github.com/icholy/digest v1.1.0 h1:HfGg9Irj7i+IX1o1QAmPfIBNu/Q5A5Tu3n/MED9k9H4=
github.com/icholy/digest v1.1.0/go.mod h1:QNrsSGQ5v7v9cReDI0+eyjsXGUoRSUZQHeQ5C4XLa0Y=
github.com/imroc/req/v3 v3.56.0 h1:t6YdqqerYBXhZ9+VjqsQs5wlKxdUNEvsgBhxWc1AEEo=
github.com/imroc/req/v3 v3.56.0/go.mod h1:cUZSooE8hhzFNOrAbdxuemXDQxFXLQTnu3066jr7ZGk=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk=
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgx/v5 v5.4.3 h1:cxFyXhxlvAifxnkKKdlxv8XqUf59tDlYjnV5YYfsJJY=
github.com/jackc/pgx/v5 v5.4.3/go.mod h1:Ig06C2Vu0t5qXC60W8sqIthScaEnFvojjj9dSljmHRA=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/klauspost/compress v1.18.1 h1:bcSGx7UbpBqMChDtsF28Lw6v/G94LPrrbMbdC3JH2co=
github.com/klauspost/compress v1.18.1/go.mod h1:ZQFFVG+MdnR0P+l6wpXgIL4NTtwiKIdBnrBd8Nrxr+0=
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
github.com/klauspost/cpuid/v2 v2.2.4 h1:acbojRNwl3o09bUq+yDCtZFc1aiwaAAxtcn8YkZXnvk=
github.com/klauspost/cpuid/v2 v2.2.4/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q=
github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4=
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY=
github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0=
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/pelletier/go-toml/v2 v2.1.0 h1:FnwAJ4oYMvbT/34k9zzHuZNrhlz48GB3/s6at6/MHO4=
github.com/pelletier/go-toml/v2 v2.1.0/go.mod h1:tJU2Z3ZkXwnxa4DPO899bsyIoywizdUvyaeZurnPPDc=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/quic-go/qpack v0.5.1 h1:giqksBPnT/HDtZ6VhtFKgoLOWmlyo9Ei6u9PqzIMbhI=
github.com/quic-go/qpack v0.5.1/go.mod h1:+PC4XFrEskIVkcLzpEkbLqq1uCoxPhQuvK5rH1ZgaEg=
github.com/quic-go/quic-go v0.56.0 h1:q/TW+OLismmXAehgFLczhCDTYB3bFmua4D9lsNBWxvY=
github.com/quic-go/quic-go v0.56.0/go.mod h1:9gx5KsFQtw2oZ6GZTyh+7YEvOxWCL9WZAepnHxgAo6c=
github.com/redis/go-redis/v9 v9.3.0 h1:RiVDjmig62jIWp7Kk4XVLs0hzV6pI3PyTnnL0cnn0u0=
github.com/redis/go-redis/v9 v9.3.0/go.mod h1:hdY0cQFCN4fnSYT6TkisLufl/4W5UIXyv0b/CLO2V2M=
github.com/refraction-networking/utls v1.8.1 h1:yNY1kapmQU8JeM1sSw2H2asfTIwWxIkrMJI0pRUOCAo=
github.com/refraction-networking/utls v1.8.1/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM=
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
github.com/sagikazarmark/locafero v0.4.0 h1:HApY1R9zGo4DBgr7dqsTH/JJxLTTsOt7u6keLGt6kNQ=
github.com/sagikazarmark/locafero v0.4.0/go.mod h1:Pe1W6UlPYUk/+wc/6KFhbORCfqzgYEpgQ3O5fPuL3H4=
github.com/sagikazarmark/slog-shim v0.1.0 h1:diDBnUNK9N/354PgrxMywXnAwEr1QZcOr6gto+ugjYE=
github.com/sagikazarmark/slog-shim v0.1.0/go.mod h1:SrcSrq8aKtyuqEI1uvTDTK1arOWRIczQRv+GVI1AkeQ=
github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo=
github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0=
github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8=
github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY=
github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0=
github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ=
github.com/spf13/viper v1.18.2/go.mod h1:EKmWIqdnk5lOcmR72yw6hS+8OPYcwD0jteitLMVB+yk=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU=
github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU=
github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE=
go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y=
go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU=
go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI=
go.uber.org/multierr v1.9.0/go.mod h1:X2jQV1h+kxSjClGpnseKVIxpmcjrj7MNnI0bnlfKTVQ=
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/crypto v0.44.0 h1:A97SsFvM3AIwEEmTBiaxPPTYpDC47w720rdiiUvgoAU=
golang.org/x/crypto v0.44.0/go.mod h1:013i+Nw79BMiQiMsOPcVCB5ZIJbYkerPrGnOa00tvmc=
golang.org/x/exp v0.0.0-20230905200255-921286631fa9 h1:GoHiUyI/Tp2nVkLI2mCxVkOjsbSXD66ic0XW0js0R9g=
golang.org/x/exp v0.0.0-20230905200255-921286631fa9/go.mod h1:S2oDrQGGwySpoQPVqRShND87VCbxmc6bL1Yd2oYrm6k=
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/term v0.37.0 h1:8EGAD0qCmHYZg6J17DvsMy9/wJ7/D/4pV/wfnld5lTU=
golang.org/x/term v0.37.0/go.mod h1:5pB4lxRNYYVZuTLmy8oR2BH8dflOR+IbTYFD8fi3254=
golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM=
golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM=
golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8=
google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA=
gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gorm.io/driver/postgres v1.5.4 h1:Iyrp9Meh3GmbSuyIAGyjkN+n9K+GHX9b9MqsTL4EJCo=
gorm.io/driver/postgres v1.5.4/go.mod h1:Bgo89+h0CRcdA33Y6frlaHHVuTdOf87pmyzwW9C/BH0=
gorm.io/gorm v1.25.5 h1:zR9lOiiYf09VNh5Q1gphfyia1JpiClIWG9hQaxB/mls=
gorm.io/gorm v1.25.5/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
gotest.tools/v3 v3.5.1 h1:EENdUnS3pdur5nybKYIh2Vfgc8IUNBjxDPSjtiJcOzU=
gotest.tools/v3 v3.5.1/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU=
rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=

View File

@@ -0,0 +1,205 @@
package config
import (
"fmt"
"strings"
"github.com/spf13/viper"
)
type Config struct {
Server ServerConfig `mapstructure:"server"`
Database DatabaseConfig `mapstructure:"database"`
Redis RedisConfig `mapstructure:"redis"`
JWT JWTConfig `mapstructure:"jwt"`
Default DefaultConfig `mapstructure:"default"`
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
Pricing PricingConfig `mapstructure:"pricing"`
Gateway GatewayConfig `mapstructure:"gateway"`
Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC"
}
type PricingConfig struct {
// 价格数据远程URL默认使用LiteLLM镜像
RemoteURL string `mapstructure:"remote_url"`
// 哈希校验文件URL
HashURL string `mapstructure:"hash_url"`
// 本地数据目录
DataDir string `mapstructure:"data_dir"`
// 回退文件路径
FallbackFile string `mapstructure:"fallback_file"`
// 更新间隔(小时)
UpdateIntervalHours int `mapstructure:"update_interval_hours"`
// 哈希校验间隔(分钟)
HashCheckIntervalMinutes int `mapstructure:"hash_check_interval_minutes"`
}
type ServerConfig struct {
Host string `mapstructure:"host"`
Port int `mapstructure:"port"`
Mode string `mapstructure:"mode"` // debug/release
ReadHeaderTimeout int `mapstructure:"read_header_timeout"` // 读取请求头超时(秒)
IdleTimeout int `mapstructure:"idle_timeout"` // 空闲连接超时(秒)
}
// GatewayConfig API网关相关配置
type GatewayConfig struct {
// 等待上游响应头的超时时间0表示无超时
// 注意:这不影响流式数据传输,只控制等待响应头的时间
ResponseHeaderTimeout int `mapstructure:"response_header_timeout"`
}
func (s *ServerConfig) Address() string {
return fmt.Sprintf("%s:%d", s.Host, s.Port)
}
type DatabaseConfig struct {
Host string `mapstructure:"host"`
Port int `mapstructure:"port"`
User string `mapstructure:"user"`
Password string `mapstructure:"password"`
DBName string `mapstructure:"dbname"`
SSLMode string `mapstructure:"sslmode"`
}
func (d *DatabaseConfig) DSN() string {
return fmt.Sprintf(
"host=%s port=%d user=%s password=%s dbname=%s sslmode=%s",
d.Host, d.Port, d.User, d.Password, d.DBName, d.SSLMode,
)
}
// DSNWithTimezone returns DSN with timezone setting
func (d *DatabaseConfig) DSNWithTimezone(tz string) string {
if tz == "" {
tz = "Asia/Shanghai"
}
return fmt.Sprintf(
"host=%s port=%d user=%s password=%s dbname=%s sslmode=%s TimeZone=%s",
d.Host, d.Port, d.User, d.Password, d.DBName, d.SSLMode, tz,
)
}
type RedisConfig struct {
Host string `mapstructure:"host"`
Port int `mapstructure:"port"`
Password string `mapstructure:"password"`
DB int `mapstructure:"db"`
}
func (r *RedisConfig) Address() string {
return fmt.Sprintf("%s:%d", r.Host, r.Port)
}
type JWTConfig struct {
Secret string `mapstructure:"secret"`
ExpireHour int `mapstructure:"expire_hour"`
}
type DefaultConfig struct {
AdminEmail string `mapstructure:"admin_email"`
AdminPassword string `mapstructure:"admin_password"`
UserConcurrency int `mapstructure:"user_concurrency"`
UserBalance float64 `mapstructure:"user_balance"`
ApiKeyPrefix string `mapstructure:"api_key_prefix"`
RateMultiplier float64 `mapstructure:"rate_multiplier"`
}
type RateLimitConfig struct {
OverloadCooldownMinutes int `mapstructure:"overload_cooldown_minutes"` // 529过载冷却时间(分钟)
}
func Load() (*Config, error) {
viper.SetConfigName("config")
viper.SetConfigType("yaml")
viper.AddConfigPath(".")
viper.AddConfigPath("./config")
viper.AddConfigPath("/etc/sub2api")
// 环境变量支持
viper.AutomaticEnv()
viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
// 默认值
setDefaults()
if err := viper.ReadInConfig(); err != nil {
if _, ok := err.(viper.ConfigFileNotFoundError); !ok {
return nil, fmt.Errorf("read config error: %w", err)
}
// 配置文件不存在时使用默认值
}
var cfg Config
if err := viper.Unmarshal(&cfg); err != nil {
return nil, fmt.Errorf("unmarshal config error: %w", err)
}
if err := cfg.Validate(); err != nil {
return nil, fmt.Errorf("validate config error: %w", err)
}
return &cfg, nil
}
func setDefaults() {
// Server
viper.SetDefault("server.host", "0.0.0.0")
viper.SetDefault("server.port", 8080)
viper.SetDefault("server.mode", "debug")
viper.SetDefault("server.read_header_timeout", 30) // 30秒读取请求头
viper.SetDefault("server.idle_timeout", 120) // 120秒空闲超时
// Database
viper.SetDefault("database.host", "localhost")
viper.SetDefault("database.port", 5432)
viper.SetDefault("database.user", "postgres")
viper.SetDefault("database.password", "postgres")
viper.SetDefault("database.dbname", "sub2api")
viper.SetDefault("database.sslmode", "disable")
// Redis
viper.SetDefault("redis.host", "localhost")
viper.SetDefault("redis.port", 6379)
viper.SetDefault("redis.password", "")
viper.SetDefault("redis.db", 0)
// JWT
viper.SetDefault("jwt.secret", "change-me-in-production")
viper.SetDefault("jwt.expire_hour", 24)
// Default
viper.SetDefault("default.admin_email", "admin@sub2api.com")
viper.SetDefault("default.admin_password", "admin123")
viper.SetDefault("default.user_concurrency", 5)
viper.SetDefault("default.user_balance", 0)
viper.SetDefault("default.api_key_prefix", "sk-")
viper.SetDefault("default.rate_multiplier", 1.0)
// RateLimit
viper.SetDefault("rate_limit.overload_cooldown_minutes", 10)
// Pricing - 从 price-mirror 分支同步,该分支维护了 sha256 哈希文件用于增量更新检查
viper.SetDefault("pricing.remote_url", "https://raw.githubusercontent.com/Wei-Shaw/claude-relay-service/price-mirror/model_prices_and_context_window.json")
viper.SetDefault("pricing.hash_url", "https://raw.githubusercontent.com/Wei-Shaw/claude-relay-service/price-mirror/model_prices_and_context_window.sha256")
viper.SetDefault("pricing.data_dir", "./data")
viper.SetDefault("pricing.fallback_file", "./resources/model-pricing/model_prices_and_context_window.json")
viper.SetDefault("pricing.update_interval_hours", 24)
viper.SetDefault("pricing.hash_check_interval_minutes", 10)
// Timezone (default to Asia/Shanghai for Chinese users)
viper.SetDefault("timezone", "Asia/Shanghai")
// Gateway
viper.SetDefault("gateway.response_header_timeout", 300) // 300秒(5分钟)等待上游响应头LLM高负载时可能排队较久
}
func (c *Config) Validate() error {
if c.JWT.Secret == "" {
return fmt.Errorf("jwt.secret is required")
}
if c.JWT.Secret == "change-me-in-production" && c.Server.Mode == "release" {
return fmt.Errorf("jwt.secret must be changed in production")
}
return nil
}

View File

@@ -0,0 +1,537 @@
package admin
import (
"strconv"
"sub2api/internal/pkg/response"
"sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// OAuthHandler handles OAuth-related operations for accounts
type OAuthHandler struct {
oauthService *service.OAuthService
adminService service.AdminService
}
// NewOAuthHandler creates a new OAuth handler
func NewOAuthHandler(oauthService *service.OAuthService, adminService service.AdminService) *OAuthHandler {
return &OAuthHandler{
oauthService: oauthService,
adminService: adminService,
}
}
// AccountHandler handles admin account management
type AccountHandler struct {
adminService service.AdminService
oauthService *service.OAuthService
rateLimitService *service.RateLimitService
accountUsageService *service.AccountUsageService
accountTestService *service.AccountTestService
}
// NewAccountHandler creates a new admin account handler
func NewAccountHandler(adminService service.AdminService, oauthService *service.OAuthService, rateLimitService *service.RateLimitService, accountUsageService *service.AccountUsageService, accountTestService *service.AccountTestService) *AccountHandler {
return &AccountHandler{
adminService: adminService,
oauthService: oauthService,
rateLimitService: rateLimitService,
accountUsageService: accountUsageService,
accountTestService: accountTestService,
}
}
// CreateAccountRequest represents create account request
type CreateAccountRequest struct {
Name string `json:"name" binding:"required"`
Platform string `json:"platform" binding:"required"`
Type string `json:"type" binding:"required,oneof=oauth setup-token apikey"`
Credentials map[string]interface{} `json:"credentials" binding:"required"`
Extra map[string]interface{} `json:"extra"`
ProxyID *int64 `json:"proxy_id"`
Concurrency int `json:"concurrency"`
Priority int `json:"priority"`
GroupIDs []int64 `json:"group_ids"`
}
// UpdateAccountRequest represents update account request
// 使用指针类型来区分"未提供"和"设置为0"
type UpdateAccountRequest struct {
Name string `json:"name"`
Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey"`
Credentials map[string]interface{} `json:"credentials"`
Extra map[string]interface{} `json:"extra"`
ProxyID *int64 `json:"proxy_id"`
Concurrency *int `json:"concurrency"`
Priority *int `json:"priority"`
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
GroupIDs *[]int64 `json:"group_ids"`
}
// List handles listing all accounts with pagination
// GET /api/v1/admin/accounts
func (h *AccountHandler) List(c *gin.Context) {
page, pageSize := response.ParsePagination(c)
platform := c.Query("platform")
accountType := c.Query("type")
status := c.Query("status")
search := c.Query("search")
accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search)
if err != nil {
response.InternalError(c, "Failed to list accounts: "+err.Error())
return
}
response.Paginated(c, accounts, total, page, pageSize)
}
// GetByID handles getting an account by ID
// GET /api/v1/admin/accounts/:id
func (h *AccountHandler) GetByID(c *gin.Context) {
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid account ID")
return
}
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
if err != nil {
response.NotFound(c, "Account not found")
return
}
response.Success(c, account)
}
// Create handles creating a new account
// POST /api/v1/admin/accounts
func (h *AccountHandler) Create(c *gin.Context) {
var req CreateAccountRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
account, err := h.adminService.CreateAccount(c.Request.Context(), &service.CreateAccountInput{
Name: req.Name,
Platform: req.Platform,
Type: req.Type,
Credentials: req.Credentials,
Extra: req.Extra,
ProxyID: req.ProxyID,
Concurrency: req.Concurrency,
Priority: req.Priority,
GroupIDs: req.GroupIDs,
})
if err != nil {
response.BadRequest(c, "Failed to create account: "+err.Error())
return
}
response.Success(c, account)
}
// Update handles updating an account
// PUT /api/v1/admin/accounts/:id
func (h *AccountHandler) Update(c *gin.Context) {
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid account ID")
return
}
var req UpdateAccountRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
account, err := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{
Name: req.Name,
Type: req.Type,
Credentials: req.Credentials,
Extra: req.Extra,
ProxyID: req.ProxyID,
Concurrency: req.Concurrency, // 指针类型nil 表示未提供
Priority: req.Priority, // 指针类型nil 表示未提供
Status: req.Status,
GroupIDs: req.GroupIDs,
})
if err != nil {
response.InternalError(c, "Failed to update account: "+err.Error())
return
}
response.Success(c, account)
}
// Delete handles deleting an account
// DELETE /api/v1/admin/accounts/:id
func (h *AccountHandler) Delete(c *gin.Context) {
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid account ID")
return
}
err = h.adminService.DeleteAccount(c.Request.Context(), accountID)
if err != nil {
response.InternalError(c, "Failed to delete account: "+err.Error())
return
}
response.Success(c, gin.H{"message": "Account deleted successfully"})
}
// Test handles testing account connectivity with SSE streaming
// POST /api/v1/admin/accounts/:id/test
func (h *AccountHandler) Test(c *gin.Context) {
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid account ID")
return
}
// Use AccountTestService to test the account with SSE streaming
if err := h.accountTestService.TestAccountConnection(c, accountID); err != nil {
// Error already sent via SSE, just log
return
}
}
// Refresh handles refreshing account credentials
// POST /api/v1/admin/accounts/:id/refresh
func (h *AccountHandler) Refresh(c *gin.Context) {
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid account ID")
return
}
// Get account
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
if err != nil {
response.NotFound(c, "Account not found")
return
}
// Only refresh OAuth-based accounts (oauth and setup-token)
if !account.IsOAuth() {
response.BadRequest(c, "Cannot refresh non-OAuth account credentials")
return
}
// Use OAuth service to refresh token
tokenInfo, err := h.oauthService.RefreshAccountToken(c.Request.Context(), account)
if err != nil {
response.InternalError(c, "Failed to refresh credentials: "+err.Error())
return
}
// Update account credentials
newCredentials := map[string]interface{}{
"access_token": tokenInfo.AccessToken,
"token_type": tokenInfo.TokenType,
"expires_in": tokenInfo.ExpiresIn,
"expires_at": tokenInfo.ExpiresAt,
"refresh_token": tokenInfo.RefreshToken,
"scope": tokenInfo.Scope,
}
updatedAccount, err := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{
Credentials: newCredentials,
})
if err != nil {
response.InternalError(c, "Failed to update account credentials: "+err.Error())
return
}
response.Success(c, updatedAccount)
}
// GetStats handles getting account statistics
// GET /api/v1/admin/accounts/:id/stats
func (h *AccountHandler) GetStats(c *gin.Context) {
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid account ID")
return
}
// Return mock data for now
_ = accountID
response.Success(c, gin.H{
"total_requests": 0,
"successful_requests": 0,
"failed_requests": 0,
"total_tokens": 0,
"average_response_time": 0,
})
}
// ClearError handles clearing account error
// POST /api/v1/admin/accounts/:id/clear-error
func (h *AccountHandler) ClearError(c *gin.Context) {
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid account ID")
return
}
account, err := h.adminService.ClearAccountError(c.Request.Context(), accountID)
if err != nil {
response.InternalError(c, "Failed to clear error: "+err.Error())
return
}
response.Success(c, account)
}
// BatchCreate handles batch creating accounts
// POST /api/v1/admin/accounts/batch
func (h *AccountHandler) BatchCreate(c *gin.Context) {
var req struct {
Accounts []CreateAccountRequest `json:"accounts" binding:"required,min=1"`
}
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
// Return mock data for now
response.Success(c, gin.H{
"success": len(req.Accounts),
"failed": 0,
"results": []gin.H{},
})
}
// ========== OAuth Handlers ==========
// GenerateAuthURLRequest represents the request for generating auth URL
type GenerateAuthURLRequest struct {
ProxyID *int64 `json:"proxy_id"`
}
// GenerateAuthURL generates OAuth authorization URL with full scope
// POST /api/v1/admin/accounts/generate-auth-url
func (h *OAuthHandler) GenerateAuthURL(c *gin.Context) {
var req GenerateAuthURLRequest
if err := c.ShouldBindJSON(&req); err != nil {
// Allow empty body
req = GenerateAuthURLRequest{}
}
result, err := h.oauthService.GenerateAuthURL(c.Request.Context(), req.ProxyID)
if err != nil {
response.InternalError(c, "Failed to generate auth URL: "+err.Error())
return
}
response.Success(c, result)
}
// GenerateSetupTokenURL generates OAuth authorization URL for setup token (inference only)
// POST /api/v1/admin/accounts/generate-setup-token-url
func (h *OAuthHandler) GenerateSetupTokenURL(c *gin.Context) {
var req GenerateAuthURLRequest
if err := c.ShouldBindJSON(&req); err != nil {
// Allow empty body
req = GenerateAuthURLRequest{}
}
result, err := h.oauthService.GenerateSetupTokenURL(c.Request.Context(), req.ProxyID)
if err != nil {
response.InternalError(c, "Failed to generate setup token URL: "+err.Error())
return
}
response.Success(c, result)
}
// ExchangeCodeRequest represents the request for exchanging auth code
type ExchangeCodeRequest struct {
SessionID string `json:"session_id" binding:"required"`
Code string `json:"code" binding:"required"`
ProxyID *int64 `json:"proxy_id"`
}
// ExchangeCode exchanges authorization code for tokens
// POST /api/v1/admin/accounts/exchange-code
func (h *OAuthHandler) ExchangeCode(c *gin.Context) {
var req ExchangeCodeRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
tokenInfo, err := h.oauthService.ExchangeCode(c.Request.Context(), &service.ExchangeCodeInput{
SessionID: req.SessionID,
Code: req.Code,
ProxyID: req.ProxyID,
})
if err != nil {
response.BadRequest(c, "Failed to exchange code: "+err.Error())
return
}
response.Success(c, tokenInfo)
}
// ExchangeSetupTokenCode exchanges authorization code for setup token
// POST /api/v1/admin/accounts/exchange-setup-token-code
func (h *OAuthHandler) ExchangeSetupTokenCode(c *gin.Context) {
var req ExchangeCodeRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
tokenInfo, err := h.oauthService.ExchangeCode(c.Request.Context(), &service.ExchangeCodeInput{
SessionID: req.SessionID,
Code: req.Code,
ProxyID: req.ProxyID,
})
if err != nil {
response.BadRequest(c, "Failed to exchange code: "+err.Error())
return
}
response.Success(c, tokenInfo)
}
// CookieAuthRequest represents the request for cookie-based authentication
type CookieAuthRequest struct {
SessionKey string `json:"code" binding:"required"` // Using 'code' field as sessionKey (frontend sends it this way)
ProxyID *int64 `json:"proxy_id"`
}
// CookieAuth performs OAuth using sessionKey (cookie-based auto-auth)
// POST /api/v1/admin/accounts/cookie-auth
func (h *OAuthHandler) CookieAuth(c *gin.Context) {
var req CookieAuthRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
tokenInfo, err := h.oauthService.CookieAuth(c.Request.Context(), &service.CookieAuthInput{
SessionKey: req.SessionKey,
ProxyID: req.ProxyID,
Scope: "full",
})
if err != nil {
response.BadRequest(c, "Cookie auth failed: "+err.Error())
return
}
response.Success(c, tokenInfo)
}
// SetupTokenCookieAuth performs OAuth using sessionKey for setup token (inference only)
// POST /api/v1/admin/accounts/setup-token-cookie-auth
func (h *OAuthHandler) SetupTokenCookieAuth(c *gin.Context) {
var req CookieAuthRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
tokenInfo, err := h.oauthService.CookieAuth(c.Request.Context(), &service.CookieAuthInput{
SessionKey: req.SessionKey,
ProxyID: req.ProxyID,
Scope: "inference",
})
if err != nil {
response.BadRequest(c, "Cookie auth failed: "+err.Error())
return
}
response.Success(c, tokenInfo)
}
// GetUsage handles getting account usage information
// GET /api/v1/admin/accounts/:id/usage
func (h *AccountHandler) GetUsage(c *gin.Context) {
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid account ID")
return
}
usage, err := h.accountUsageService.GetUsage(c.Request.Context(), accountID)
if err != nil {
response.InternalError(c, "Failed to get usage: "+err.Error())
return
}
response.Success(c, usage)
}
// ClearRateLimit handles clearing account rate limit status
// POST /api/v1/admin/accounts/:id/clear-rate-limit
func (h *AccountHandler) ClearRateLimit(c *gin.Context) {
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid account ID")
return
}
err = h.rateLimitService.ClearRateLimit(c.Request.Context(), accountID)
if err != nil {
response.InternalError(c, "Failed to clear rate limit: "+err.Error())
return
}
response.Success(c, gin.H{"message": "Rate limit cleared successfully"})
}
// GetTodayStats handles getting account today statistics
// GET /api/v1/admin/accounts/:id/today-stats
func (h *AccountHandler) GetTodayStats(c *gin.Context) {
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid account ID")
return
}
stats, err := h.accountUsageService.GetTodayStats(c.Request.Context(), accountID)
if err != nil {
response.InternalError(c, "Failed to get today stats: "+err.Error())
return
}
response.Success(c, stats)
}
// SetSchedulableRequest represents the request body for setting schedulable status
type SetSchedulableRequest struct {
Schedulable bool `json:"schedulable"`
}
// SetSchedulable handles toggling account schedulable status
// POST /api/v1/admin/accounts/:id/schedulable
func (h *AccountHandler) SetSchedulable(c *gin.Context) {
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid account ID")
return
}
var req SetSchedulableRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
account, err := h.adminService.SetAccountSchedulable(c.Request.Context(), accountID, req.Schedulable)
if err != nil {
response.InternalError(c, "Failed to update schedulable status: "+err.Error())
return
}
response.Success(c, account)
}

View File

@@ -0,0 +1,274 @@
package admin
import (
"strconv"
"sub2api/internal/pkg/response"
"sub2api/internal/pkg/timezone"
"sub2api/internal/repository"
"sub2api/internal/service"
"time"
"github.com/gin-gonic/gin"
)
// DashboardHandler handles admin dashboard statistics
type DashboardHandler struct {
adminService service.AdminService
usageRepo *repository.UsageLogRepository
startTime time.Time // Server start time for uptime calculation
}
// NewDashboardHandler creates a new admin dashboard handler
func NewDashboardHandler(adminService service.AdminService, usageRepo *repository.UsageLogRepository) *DashboardHandler {
return &DashboardHandler{
adminService: adminService,
usageRepo: usageRepo,
startTime: time.Now(),
}
}
// parseTimeRange parses start_date, end_date query parameters
func parseTimeRange(c *gin.Context) (time.Time, time.Time) {
now := timezone.Now()
startDate := c.Query("start_date")
endDate := c.Query("end_date")
var startTime, endTime time.Time
if startDate != "" {
if t, err := timezone.ParseInLocation("2006-01-02", startDate); err == nil {
startTime = t
} else {
startTime = timezone.StartOfDay(now.AddDate(0, 0, -7))
}
} else {
startTime = timezone.StartOfDay(now.AddDate(0, 0, -7))
}
if endDate != "" {
if t, err := timezone.ParseInLocation("2006-01-02", endDate); err == nil {
endTime = t.Add(24 * time.Hour) // Include the end date
} else {
endTime = timezone.StartOfDay(now.AddDate(0, 0, 1))
}
} else {
endTime = timezone.StartOfDay(now.AddDate(0, 0, 1))
}
return startTime, endTime
}
// GetStats handles getting dashboard statistics
// GET /api/v1/admin/dashboard/stats
func (h *DashboardHandler) GetStats(c *gin.Context) {
stats, err := h.usageRepo.GetDashboardStats(c.Request.Context())
if err != nil {
response.Error(c, 500, "Failed to get dashboard statistics")
return
}
// Calculate uptime in seconds
uptime := int64(time.Since(h.startTime).Seconds())
response.Success(c, gin.H{
// 用户统计
"total_users": stats.TotalUsers,
"today_new_users": stats.TodayNewUsers,
"active_users": stats.ActiveUsers,
// API Key 统计
"total_api_keys": stats.TotalApiKeys,
"active_api_keys": stats.ActiveApiKeys,
// 账户统计
"total_accounts": stats.TotalAccounts,
"normal_accounts": stats.NormalAccounts,
"error_accounts": stats.ErrorAccounts,
"ratelimit_accounts": stats.RateLimitAccounts,
"overload_accounts": stats.OverloadAccounts,
// 累计 Token 使用统计
"total_requests": stats.TotalRequests,
"total_input_tokens": stats.TotalInputTokens,
"total_output_tokens": stats.TotalOutputTokens,
"total_cache_creation_tokens": stats.TotalCacheCreationTokens,
"total_cache_read_tokens": stats.TotalCacheReadTokens,
"total_tokens": stats.TotalTokens,
"total_cost": stats.TotalCost, // 标准计费
"total_actual_cost": stats.TotalActualCost, // 实际扣除
// 今日 Token 使用统计
"today_requests": stats.TodayRequests,
"today_input_tokens": stats.TodayInputTokens,
"today_output_tokens": stats.TodayOutputTokens,
"today_cache_creation_tokens": stats.TodayCacheCreationTokens,
"today_cache_read_tokens": stats.TodayCacheReadTokens,
"today_tokens": stats.TodayTokens,
"today_cost": stats.TodayCost, // 今日标准计费
"today_actual_cost": stats.TodayActualCost, // 今日实际扣除
// 系统运行统计
"average_duration_ms": stats.AverageDurationMs,
"uptime": uptime,
})
}
// GetRealtimeMetrics handles getting real-time system metrics
// GET /api/v1/admin/dashboard/realtime
func (h *DashboardHandler) GetRealtimeMetrics(c *gin.Context) {
// Return mock data for now
response.Success(c, gin.H{
"active_requests": 0,
"requests_per_minute": 0,
"average_response_time": 0,
"error_rate": 0.0,
})
}
// GetUsageTrend handles getting usage trend data
// GET /api/v1/admin/dashboard/trend
// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour)
func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
startTime, endTime := parseTimeRange(c)
granularity := c.DefaultQuery("granularity", "day")
trend, err := h.usageRepo.GetUsageTrend(c.Request.Context(), startTime, endTime, granularity)
if err != nil {
response.Error(c, 500, "Failed to get usage trend")
return
}
response.Success(c, gin.H{
"trend": trend,
"start_date": startTime.Format("2006-01-02"),
"end_date": endTime.Add(-24 * time.Hour).Format("2006-01-02"),
"granularity": granularity,
})
}
// GetModelStats handles getting model usage statistics
// GET /api/v1/admin/dashboard/models
// Query params: start_date, end_date (YYYY-MM-DD)
func (h *DashboardHandler) GetModelStats(c *gin.Context) {
startTime, endTime := parseTimeRange(c)
stats, err := h.usageRepo.GetModelStats(c.Request.Context(), startTime, endTime)
if err != nil {
response.Error(c, 500, "Failed to get model statistics")
return
}
response.Success(c, gin.H{
"models": stats,
"start_date": startTime.Format("2006-01-02"),
"end_date": endTime.Add(-24 * time.Hour).Format("2006-01-02"),
})
}
// GetApiKeyUsageTrend handles getting API key usage trend data
// GET /api/v1/admin/dashboard/api-keys-trend
// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), limit (default 5)
func (h *DashboardHandler) GetApiKeyUsageTrend(c *gin.Context) {
startTime, endTime := parseTimeRange(c)
granularity := c.DefaultQuery("granularity", "day")
limitStr := c.DefaultQuery("limit", "5")
limit, err := strconv.Atoi(limitStr)
if err != nil || limit <= 0 {
limit = 5
}
trend, err := h.usageRepo.GetApiKeyUsageTrend(c.Request.Context(), startTime, endTime, granularity, limit)
if err != nil {
response.Error(c, 500, "Failed to get API key usage trend")
return
}
response.Success(c, gin.H{
"trend": trend,
"start_date": startTime.Format("2006-01-02"),
"end_date": endTime.Add(-24 * time.Hour).Format("2006-01-02"),
"granularity": granularity,
})
}
// GetUserUsageTrend handles getting user usage trend data
// GET /api/v1/admin/dashboard/users-trend
// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), limit (default 12)
func (h *DashboardHandler) GetUserUsageTrend(c *gin.Context) {
startTime, endTime := parseTimeRange(c)
granularity := c.DefaultQuery("granularity", "day")
limitStr := c.DefaultQuery("limit", "12")
limit, err := strconv.Atoi(limitStr)
if err != nil || limit <= 0 {
limit = 12
}
trend, err := h.usageRepo.GetUserUsageTrend(c.Request.Context(), startTime, endTime, granularity, limit)
if err != nil {
response.Error(c, 500, "Failed to get user usage trend")
return
}
response.Success(c, gin.H{
"trend": trend,
"start_date": startTime.Format("2006-01-02"),
"end_date": endTime.Add(-24 * time.Hour).Format("2006-01-02"),
"granularity": granularity,
})
}
// BatchUsersUsageRequest represents the request body for batch user usage stats
type BatchUsersUsageRequest struct {
UserIDs []int64 `json:"user_ids" binding:"required"`
}
// GetBatchUsersUsage handles getting usage stats for multiple users
// POST /api/v1/admin/dashboard/users-usage
func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) {
var req BatchUsersUsageRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if len(req.UserIDs) == 0 {
response.Success(c, gin.H{"stats": map[string]interface{}{}})
return
}
stats, err := h.usageRepo.GetBatchUserUsageStats(c.Request.Context(), req.UserIDs)
if err != nil {
response.Error(c, 500, "Failed to get user usage stats")
return
}
response.Success(c, gin.H{"stats": stats})
}
// BatchApiKeysUsageRequest represents the request body for batch api key usage stats
type BatchApiKeysUsageRequest struct {
ApiKeyIDs []int64 `json:"api_key_ids" binding:"required"`
}
// GetBatchApiKeysUsage handles getting usage stats for multiple API keys
// POST /api/v1/admin/dashboard/api-keys-usage
func (h *DashboardHandler) GetBatchApiKeysUsage(c *gin.Context) {
var req BatchApiKeysUsageRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if len(req.ApiKeyIDs) == 0 {
response.Success(c, gin.H{"stats": map[string]interface{}{}})
return
}
stats, err := h.usageRepo.GetBatchApiKeyUsageStats(c.Request.Context(), req.ApiKeyIDs)
if err != nil {
response.Error(c, 500, "Failed to get API key usage stats")
return
}
response.Success(c, gin.H{"stats": stats})
}

View File

@@ -0,0 +1,233 @@
package admin
import (
"strconv"
"sub2api/internal/model"
"sub2api/internal/pkg/response"
"sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// GroupHandler handles admin group management
type GroupHandler struct {
adminService service.AdminService
}
// NewGroupHandler creates a new admin group handler
func NewGroupHandler(adminService service.AdminService) *GroupHandler {
return &GroupHandler{
adminService: adminService,
}
}
// CreateGroupRequest represents create group request
type CreateGroupRequest struct {
Name string `json:"name" binding:"required"`
Description string `json:"description"`
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini"`
RateMultiplier float64 `json:"rate_multiplier"`
IsExclusive bool `json:"is_exclusive"`
SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"`
DailyLimitUSD *float64 `json:"daily_limit_usd"`
WeeklyLimitUSD *float64 `json:"weekly_limit_usd"`
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
}
// UpdateGroupRequest represents update group request
type UpdateGroupRequest struct {
Name string `json:"name"`
Description string `json:"description"`
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini"`
RateMultiplier *float64 `json:"rate_multiplier"`
IsExclusive *bool `json:"is_exclusive"`
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"`
DailyLimitUSD *float64 `json:"daily_limit_usd"`
WeeklyLimitUSD *float64 `json:"weekly_limit_usd"`
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
}
// List handles listing all groups with pagination
// GET /api/v1/admin/groups
func (h *GroupHandler) List(c *gin.Context) {
page, pageSize := response.ParsePagination(c)
platform := c.Query("platform")
status := c.Query("status")
isExclusiveStr := c.Query("is_exclusive")
var isExclusive *bool
if isExclusiveStr != "" {
val := isExclusiveStr == "true"
isExclusive = &val
}
groups, total, err := h.adminService.ListGroups(c.Request.Context(), page, pageSize, platform, status, isExclusive)
if err != nil {
response.InternalError(c, "Failed to list groups: "+err.Error())
return
}
response.Paginated(c, groups, total, page, pageSize)
}
// GetAll handles getting all active groups without pagination
// GET /api/v1/admin/groups/all
func (h *GroupHandler) GetAll(c *gin.Context) {
platform := c.Query("platform")
var groups []model.Group
var err error
if platform != "" {
groups, err = h.adminService.GetAllGroupsByPlatform(c.Request.Context(), platform)
} else {
groups, err = h.adminService.GetAllGroups(c.Request.Context())
}
if err != nil {
response.InternalError(c, "Failed to get groups: "+err.Error())
return
}
response.Success(c, groups)
}
// GetByID handles getting a group by ID
// GET /api/v1/admin/groups/:id
func (h *GroupHandler) GetByID(c *gin.Context) {
groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid group ID")
return
}
group, err := h.adminService.GetGroup(c.Request.Context(), groupID)
if err != nil {
response.NotFound(c, "Group not found")
return
}
response.Success(c, group)
}
// Create handles creating a new group
// POST /api/v1/admin/groups
func (h *GroupHandler) Create(c *gin.Context) {
var req CreateGroupRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
group, err := h.adminService.CreateGroup(c.Request.Context(), &service.CreateGroupInput{
Name: req.Name,
Description: req.Description,
Platform: req.Platform,
RateMultiplier: req.RateMultiplier,
IsExclusive: req.IsExclusive,
SubscriptionType: req.SubscriptionType,
DailyLimitUSD: req.DailyLimitUSD,
WeeklyLimitUSD: req.WeeklyLimitUSD,
MonthlyLimitUSD: req.MonthlyLimitUSD,
})
if err != nil {
response.BadRequest(c, "Failed to create group: "+err.Error())
return
}
response.Success(c, group)
}
// Update handles updating a group
// PUT /api/v1/admin/groups/:id
func (h *GroupHandler) Update(c *gin.Context) {
groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid group ID")
return
}
var req UpdateGroupRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
group, err := h.adminService.UpdateGroup(c.Request.Context(), groupID, &service.UpdateGroupInput{
Name: req.Name,
Description: req.Description,
Platform: req.Platform,
RateMultiplier: req.RateMultiplier,
IsExclusive: req.IsExclusive,
Status: req.Status,
SubscriptionType: req.SubscriptionType,
DailyLimitUSD: req.DailyLimitUSD,
WeeklyLimitUSD: req.WeeklyLimitUSD,
MonthlyLimitUSD: req.MonthlyLimitUSD,
})
if err != nil {
response.InternalError(c, "Failed to update group: "+err.Error())
return
}
response.Success(c, group)
}
// Delete handles deleting a group
// DELETE /api/v1/admin/groups/:id
func (h *GroupHandler) Delete(c *gin.Context) {
groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid group ID")
return
}
err = h.adminService.DeleteGroup(c.Request.Context(), groupID)
if err != nil {
response.InternalError(c, "Failed to delete group: "+err.Error())
return
}
response.Success(c, gin.H{"message": "Group deleted successfully"})
}
// GetStats handles getting group statistics
// GET /api/v1/admin/groups/:id/stats
func (h *GroupHandler) GetStats(c *gin.Context) {
groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid group ID")
return
}
// Return mock data for now
response.Success(c, gin.H{
"total_api_keys": 0,
"active_api_keys": 0,
"total_requests": 0,
"total_cost": 0.0,
})
_ = groupID // TODO: implement actual stats
}
// GetGroupAPIKeys handles getting API keys in a group
// GET /api/v1/admin/groups/:id/api-keys
func (h *GroupHandler) GetGroupAPIKeys(c *gin.Context) {
groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid group ID")
return
}
page, pageSize := response.ParsePagination(c)
keys, total, err := h.adminService.GetGroupAPIKeys(c.Request.Context(), groupID, page, pageSize)
if err != nil {
response.InternalError(c, "Failed to get group API keys: "+err.Error())
return
}
response.Paginated(c, keys, total, page, pageSize)
}

View File

@@ -0,0 +1,300 @@
package admin
import (
"strconv"
"sub2api/internal/pkg/response"
"sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// ProxyHandler handles admin proxy management
type ProxyHandler struct {
adminService service.AdminService
}
// NewProxyHandler creates a new admin proxy handler
func NewProxyHandler(adminService service.AdminService) *ProxyHandler {
return &ProxyHandler{
adminService: adminService,
}
}
// CreateProxyRequest represents create proxy request
type CreateProxyRequest struct {
Name string `json:"name" binding:"required"`
Protocol string `json:"protocol" binding:"required,oneof=http https socks5"`
Host string `json:"host" binding:"required"`
Port int `json:"port" binding:"required,min=1,max=65535"`
Username string `json:"username"`
Password string `json:"password"`
}
// UpdateProxyRequest represents update proxy request
type UpdateProxyRequest struct {
Name string `json:"name"`
Protocol string `json:"protocol" binding:"omitempty,oneof=http https socks5"`
Host string `json:"host"`
Port int `json:"port" binding:"omitempty,min=1,max=65535"`
Username string `json:"username"`
Password string `json:"password"`
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
}
// List handles listing all proxies with pagination
// GET /api/v1/admin/proxies
func (h *ProxyHandler) List(c *gin.Context) {
page, pageSize := response.ParsePagination(c)
protocol := c.Query("protocol")
status := c.Query("status")
search := c.Query("search")
proxies, total, err := h.adminService.ListProxies(c.Request.Context(), page, pageSize, protocol, status, search)
if err != nil {
response.InternalError(c, "Failed to list proxies: "+err.Error())
return
}
response.Paginated(c, proxies, total, page, pageSize)
}
// GetAll handles getting all active proxies without pagination
// GET /api/v1/admin/proxies/all
// Optional query param: with_count=true to include account count per proxy
func (h *ProxyHandler) GetAll(c *gin.Context) {
withCount := c.Query("with_count") == "true"
if withCount {
proxies, err := h.adminService.GetAllProxiesWithAccountCount(c.Request.Context())
if err != nil {
response.InternalError(c, "Failed to get proxies: "+err.Error())
return
}
response.Success(c, proxies)
return
}
proxies, err := h.adminService.GetAllProxies(c.Request.Context())
if err != nil {
response.InternalError(c, "Failed to get proxies: "+err.Error())
return
}
response.Success(c, proxies)
}
// GetByID handles getting a proxy by ID
// GET /api/v1/admin/proxies/:id
func (h *ProxyHandler) GetByID(c *gin.Context) {
proxyID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid proxy ID")
return
}
proxy, err := h.adminService.GetProxy(c.Request.Context(), proxyID)
if err != nil {
response.NotFound(c, "Proxy not found")
return
}
response.Success(c, proxy)
}
// Create handles creating a new proxy
// POST /api/v1/admin/proxies
func (h *ProxyHandler) Create(c *gin.Context) {
var req CreateProxyRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
proxy, err := h.adminService.CreateProxy(c.Request.Context(), &service.CreateProxyInput{
Name: req.Name,
Protocol: req.Protocol,
Host: req.Host,
Port: req.Port,
Username: req.Username,
Password: req.Password,
})
if err != nil {
response.BadRequest(c, "Failed to create proxy: "+err.Error())
return
}
response.Success(c, proxy)
}
// Update handles updating a proxy
// PUT /api/v1/admin/proxies/:id
func (h *ProxyHandler) Update(c *gin.Context) {
proxyID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid proxy ID")
return
}
var req UpdateProxyRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
proxy, err := h.adminService.UpdateProxy(c.Request.Context(), proxyID, &service.UpdateProxyInput{
Name: req.Name,
Protocol: req.Protocol,
Host: req.Host,
Port: req.Port,
Username: req.Username,
Password: req.Password,
Status: req.Status,
})
if err != nil {
response.InternalError(c, "Failed to update proxy: "+err.Error())
return
}
response.Success(c, proxy)
}
// Delete handles deleting a proxy
// DELETE /api/v1/admin/proxies/:id
func (h *ProxyHandler) Delete(c *gin.Context) {
proxyID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid proxy ID")
return
}
err = h.adminService.DeleteProxy(c.Request.Context(), proxyID)
if err != nil {
response.InternalError(c, "Failed to delete proxy: "+err.Error())
return
}
response.Success(c, gin.H{"message": "Proxy deleted successfully"})
}
// Test handles testing proxy connectivity
// POST /api/v1/admin/proxies/:id/test
func (h *ProxyHandler) Test(c *gin.Context) {
proxyID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid proxy ID")
return
}
result, err := h.adminService.TestProxy(c.Request.Context(), proxyID)
if err != nil {
response.InternalError(c, "Failed to test proxy: "+err.Error())
return
}
response.Success(c, result)
}
// GetStats handles getting proxy statistics
// GET /api/v1/admin/proxies/:id/stats
func (h *ProxyHandler) GetStats(c *gin.Context) {
proxyID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid proxy ID")
return
}
// Return mock data for now
_ = proxyID
response.Success(c, gin.H{
"total_accounts": 0,
"active_accounts": 0,
"total_requests": 0,
"success_rate": 100.0,
"average_latency": 0,
})
}
// GetProxyAccounts handles getting accounts using a proxy
// GET /api/v1/admin/proxies/:id/accounts
func (h *ProxyHandler) GetProxyAccounts(c *gin.Context) {
proxyID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid proxy ID")
return
}
page, pageSize := response.ParsePagination(c)
accounts, total, err := h.adminService.GetProxyAccounts(c.Request.Context(), proxyID, page, pageSize)
if err != nil {
response.InternalError(c, "Failed to get proxy accounts: "+err.Error())
return
}
response.Paginated(c, accounts, total, page, pageSize)
}
// BatchCreateProxyItem represents a single proxy in batch create request
type BatchCreateProxyItem struct {
Protocol string `json:"protocol" binding:"required,oneof=http https socks5"`
Host string `json:"host" binding:"required"`
Port int `json:"port" binding:"required,min=1,max=65535"`
Username string `json:"username"`
Password string `json:"password"`
}
// BatchCreateRequest represents batch create proxies request
type BatchCreateRequest struct {
Proxies []BatchCreateProxyItem `json:"proxies" binding:"required,min=1"`
}
// BatchCreate handles batch creating proxies
// POST /api/v1/admin/proxies/batch
func (h *ProxyHandler) BatchCreate(c *gin.Context) {
var req BatchCreateRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
created := 0
skipped := 0
for _, item := range req.Proxies {
// Check for duplicates (same host, port, username, password)
exists, err := h.adminService.CheckProxyExists(c.Request.Context(), item.Host, item.Port, item.Username, item.Password)
if err != nil {
response.InternalError(c, "Failed to check proxy existence: "+err.Error())
return
}
if exists {
skipped++
continue
}
// Create proxy with default name
_, err = h.adminService.CreateProxy(c.Request.Context(), &service.CreateProxyInput{
Name: "default",
Protocol: item.Protocol,
Host: item.Host,
Port: item.Port,
Username: item.Username,
Password: item.Password,
})
if err != nil {
// If creation fails due to duplicate, count as skipped
skipped++
continue
}
created++
}
response.Success(c, gin.H{
"created": created,
"skipped": skipped,
})
}

View File

@@ -0,0 +1,219 @@
package admin
import (
"bytes"
"encoding/csv"
"fmt"
"strconv"
"sub2api/internal/pkg/response"
"sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// RedeemHandler handles admin redeem code management
type RedeemHandler struct {
adminService service.AdminService
}
// NewRedeemHandler creates a new admin redeem handler
func NewRedeemHandler(adminService service.AdminService) *RedeemHandler {
return &RedeemHandler{
adminService: adminService,
}
}
// GenerateRedeemCodesRequest represents generate redeem codes request
type GenerateRedeemCodesRequest struct {
Count int `json:"count" binding:"required,min=1,max=100"`
Type string `json:"type" binding:"required,oneof=balance concurrency subscription"`
Value float64 `json:"value" binding:"min=0"`
GroupID *int64 `json:"group_id"` // 订阅类型必填
ValidityDays int `json:"validity_days"` // 订阅类型使用默认30天
}
// List handles listing all redeem codes with pagination
// GET /api/v1/admin/redeem-codes
func (h *RedeemHandler) List(c *gin.Context) {
page, pageSize := response.ParsePagination(c)
codeType := c.Query("type")
status := c.Query("status")
search := c.Query("search")
codes, total, err := h.adminService.ListRedeemCodes(c.Request.Context(), page, pageSize, codeType, status, search)
if err != nil {
response.InternalError(c, "Failed to list redeem codes: "+err.Error())
return
}
response.Paginated(c, codes, total, page, pageSize)
}
// GetByID handles getting a redeem code by ID
// GET /api/v1/admin/redeem-codes/:id
func (h *RedeemHandler) GetByID(c *gin.Context) {
codeID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid redeem code ID")
return
}
code, err := h.adminService.GetRedeemCode(c.Request.Context(), codeID)
if err != nil {
response.NotFound(c, "Redeem code not found")
return
}
response.Success(c, code)
}
// Generate handles generating new redeem codes
// POST /api/v1/admin/redeem-codes/generate
func (h *RedeemHandler) Generate(c *gin.Context) {
var req GenerateRedeemCodesRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
codes, err := h.adminService.GenerateRedeemCodes(c.Request.Context(), &service.GenerateRedeemCodesInput{
Count: req.Count,
Type: req.Type,
Value: req.Value,
GroupID: req.GroupID,
ValidityDays: req.ValidityDays,
})
if err != nil {
response.InternalError(c, "Failed to generate redeem codes: "+err.Error())
return
}
response.Success(c, codes)
}
// Delete handles deleting a redeem code
// DELETE /api/v1/admin/redeem-codes/:id
func (h *RedeemHandler) Delete(c *gin.Context) {
codeID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid redeem code ID")
return
}
err = h.adminService.DeleteRedeemCode(c.Request.Context(), codeID)
if err != nil {
response.InternalError(c, "Failed to delete redeem code: "+err.Error())
return
}
response.Success(c, gin.H{"message": "Redeem code deleted successfully"})
}
// BatchDelete handles batch deleting redeem codes
// POST /api/v1/admin/redeem-codes/batch-delete
func (h *RedeemHandler) BatchDelete(c *gin.Context) {
var req struct {
IDs []int64 `json:"ids" binding:"required,min=1"`
}
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
deleted, err := h.adminService.BatchDeleteRedeemCodes(c.Request.Context(), req.IDs)
if err != nil {
response.InternalError(c, "Failed to batch delete redeem codes: "+err.Error())
return
}
response.Success(c, gin.H{
"deleted": deleted,
"message": "Redeem codes deleted successfully",
})
}
// Expire handles expiring a redeem code
// POST /api/v1/admin/redeem-codes/:id/expire
func (h *RedeemHandler) Expire(c *gin.Context) {
codeID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid redeem code ID")
return
}
code, err := h.adminService.ExpireRedeemCode(c.Request.Context(), codeID)
if err != nil {
response.InternalError(c, "Failed to expire redeem code: "+err.Error())
return
}
response.Success(c, code)
}
// GetStats handles getting redeem code statistics
// GET /api/v1/admin/redeem-codes/stats
func (h *RedeemHandler) GetStats(c *gin.Context) {
// Return mock data for now
response.Success(c, gin.H{
"total_codes": 0,
"active_codes": 0,
"used_codes": 0,
"expired_codes": 0,
"total_value_distributed": 0.0,
"by_type": gin.H{
"balance": 0,
"concurrency": 0,
"trial": 0,
},
})
}
// Export handles exporting redeem codes to CSV
// GET /api/v1/admin/redeem-codes/export
func (h *RedeemHandler) Export(c *gin.Context) {
codeType := c.Query("type")
status := c.Query("status")
// Get all codes without pagination (use large page size)
codes, _, err := h.adminService.ListRedeemCodes(c.Request.Context(), 1, 10000, codeType, status, "")
if err != nil {
response.InternalError(c, "Failed to export redeem codes: "+err.Error())
return
}
// Create CSV buffer
var buf bytes.Buffer
writer := csv.NewWriter(&buf)
// Write header
writer.Write([]string{"id", "code", "type", "value", "status", "used_by", "used_at", "created_at"})
// Write data rows
for _, code := range codes {
usedBy := ""
if code.UsedBy != nil {
usedBy = fmt.Sprintf("%d", *code.UsedBy)
}
usedAt := ""
if code.UsedAt != nil {
usedAt = code.UsedAt.Format("2006-01-02 15:04:05")
}
writer.Write([]string{
fmt.Sprintf("%d", code.ID),
code.Code,
code.Type,
fmt.Sprintf("%.2f", code.Value),
code.Status,
usedBy,
usedAt,
code.CreatedAt.Format("2006-01-02 15:04:05"),
})
}
writer.Flush()
c.Header("Content-Type", "text/csv")
c.Header("Content-Disposition", "attachment; filename=redeem_codes.csv")
c.Data(200, "text/csv", buf.Bytes())
}

View File

@@ -0,0 +1,258 @@
package admin
import (
"sub2api/internal/model"
"sub2api/internal/pkg/response"
"sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// SettingHandler 系统设置处理器
type SettingHandler struct {
settingService *service.SettingService
emailService *service.EmailService
}
// NewSettingHandler 创建系统设置处理器
func NewSettingHandler(settingService *service.SettingService, emailService *service.EmailService) *SettingHandler {
return &SettingHandler{
settingService: settingService,
emailService: emailService,
}
}
// GetSettings 获取所有系统设置
// GET /api/v1/admin/settings
func (h *SettingHandler) GetSettings(c *gin.Context) {
settings, err := h.settingService.GetAllSettings(c.Request.Context())
if err != nil {
response.InternalError(c, "Failed to get settings: "+err.Error())
return
}
response.Success(c, settings)
}
// UpdateSettingsRequest 更新设置请求
type UpdateSettingsRequest struct {
// 注册设置
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
// 邮件服务设置
SmtpHost string `json:"smtp_host"`
SmtpPort int `json:"smtp_port"`
SmtpUsername string `json:"smtp_username"`
SmtpPassword string `json:"smtp_password"`
SmtpFrom string `json:"smtp_from_email"`
SmtpFromName string `json:"smtp_from_name"`
SmtpUseTLS bool `json:"smtp_use_tls"`
// Cloudflare Turnstile 设置
TurnstileEnabled bool `json:"turnstile_enabled"`
TurnstileSiteKey string `json:"turnstile_site_key"`
TurnstileSecretKey string `json:"turnstile_secret_key"`
// OEM设置
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
ApiBaseUrl string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
// 默认配置
DefaultConcurrency int `json:"default_concurrency"`
DefaultBalance float64 `json:"default_balance"`
}
// UpdateSettings 更新系统设置
// PUT /api/v1/admin/settings
func (h *SettingHandler) UpdateSettings(c *gin.Context) {
var req UpdateSettingsRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
// 验证参数
if req.DefaultConcurrency < 1 {
req.DefaultConcurrency = 1
}
if req.DefaultBalance < 0 {
req.DefaultBalance = 0
}
if req.SmtpPort <= 0 {
req.SmtpPort = 587
}
settings := &model.SystemSettings{
RegistrationEnabled: req.RegistrationEnabled,
EmailVerifyEnabled: req.EmailVerifyEnabled,
SmtpHost: req.SmtpHost,
SmtpPort: req.SmtpPort,
SmtpUsername: req.SmtpUsername,
SmtpPassword: req.SmtpPassword,
SmtpFrom: req.SmtpFrom,
SmtpFromName: req.SmtpFromName,
SmtpUseTLS: req.SmtpUseTLS,
TurnstileEnabled: req.TurnstileEnabled,
TurnstileSiteKey: req.TurnstileSiteKey,
TurnstileSecretKey: req.TurnstileSecretKey,
SiteName: req.SiteName,
SiteLogo: req.SiteLogo,
SiteSubtitle: req.SiteSubtitle,
ApiBaseUrl: req.ApiBaseUrl,
ContactInfo: req.ContactInfo,
DefaultConcurrency: req.DefaultConcurrency,
DefaultBalance: req.DefaultBalance,
}
if err := h.settingService.UpdateSettings(c.Request.Context(), settings); err != nil {
response.InternalError(c, "Failed to update settings: "+err.Error())
return
}
// 重新获取设置返回
updatedSettings, err := h.settingService.GetAllSettings(c.Request.Context())
if err != nil {
response.InternalError(c, "Failed to get updated settings: "+err.Error())
return
}
response.Success(c, updatedSettings)
}
// TestSmtpRequest 测试SMTP连接请求
type TestSmtpRequest struct {
SmtpHost string `json:"smtp_host" binding:"required"`
SmtpPort int `json:"smtp_port"`
SmtpUsername string `json:"smtp_username"`
SmtpPassword string `json:"smtp_password"`
SmtpUseTLS bool `json:"smtp_use_tls"`
}
// TestSmtpConnection 测试SMTP连接
// POST /api/v1/admin/settings/test-smtp
func (h *SettingHandler) TestSmtpConnection(c *gin.Context) {
var req TestSmtpRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if req.SmtpPort <= 0 {
req.SmtpPort = 587
}
// 如果未提供密码,从数据库获取已保存的密码
password := req.SmtpPassword
if password == "" {
savedConfig, err := h.emailService.GetSmtpConfig(c.Request.Context())
if err == nil && savedConfig != nil {
password = savedConfig.Password
}
}
config := &service.SmtpConfig{
Host: req.SmtpHost,
Port: req.SmtpPort,
Username: req.SmtpUsername,
Password: password,
UseTLS: req.SmtpUseTLS,
}
err := h.emailService.TestSmtpConnectionWithConfig(config)
if err != nil {
response.BadRequest(c, "SMTP connection test failed: "+err.Error())
return
}
response.Success(c, gin.H{"message": "SMTP connection successful"})
}
// SendTestEmailRequest 发送测试邮件请求
type SendTestEmailRequest struct {
Email string `json:"email" binding:"required,email"`
SmtpHost string `json:"smtp_host" binding:"required"`
SmtpPort int `json:"smtp_port"`
SmtpUsername string `json:"smtp_username"`
SmtpPassword string `json:"smtp_password"`
SmtpFrom string `json:"smtp_from_email"`
SmtpFromName string `json:"smtp_from_name"`
SmtpUseTLS bool `json:"smtp_use_tls"`
}
// SendTestEmail 发送测试邮件
// POST /api/v1/admin/settings/send-test-email
func (h *SettingHandler) SendTestEmail(c *gin.Context) {
var req SendTestEmailRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if req.SmtpPort <= 0 {
req.SmtpPort = 587
}
// 如果未提供密码,从数据库获取已保存的密码
password := req.SmtpPassword
if password == "" {
savedConfig, err := h.emailService.GetSmtpConfig(c.Request.Context())
if err == nil && savedConfig != nil {
password = savedConfig.Password
}
}
config := &service.SmtpConfig{
Host: req.SmtpHost,
Port: req.SmtpPort,
Username: req.SmtpUsername,
Password: password,
From: req.SmtpFrom,
FromName: req.SmtpFromName,
UseTLS: req.SmtpUseTLS,
}
siteName := h.settingService.GetSiteName(c.Request.Context())
subject := "[" + siteName + "] Test Email"
body := `
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<style>
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; background-color: #f5f5f5; margin: 0; padding: 20px; }
.container { max-width: 600px; margin: 0 auto; background-color: #ffffff; border-radius: 8px; overflow: hidden; box-shadow: 0 2px 8px rgba(0,0,0,0.1); }
.header { background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); color: white; padding: 30px; text-align: center; }
.content { padding: 40px 30px; text-align: center; }
.success { color: #10b981; font-size: 48px; margin-bottom: 20px; }
.footer { background-color: #f8f9fa; padding: 20px; text-align: center; color: #999; font-size: 12px; }
</style>
</head>
<body>
<div class="container">
<div class="header">
<h1>` + siteName + `</h1>
</div>
<div class="content">
<div class="success">✓</div>
<h2>Email Configuration Successful!</h2>
<p>This is a test email to verify your SMTP settings are working correctly.</p>
</div>
<div class="footer">
<p>This is an automated test message.</p>
</div>
</div>
</body>
</html>
`
if err := h.emailService.SendEmailWithConfig(config, req.Email, subject, body); err != nil {
response.BadRequest(c, "Failed to send test email: "+err.Error())
return
}
response.Success(c, gin.H{"message": "Test email sent successfully"})
}

View File

@@ -0,0 +1,266 @@
package admin
import (
"strconv"
"sub2api/internal/model"
"sub2api/internal/pkg/response"
"sub2api/internal/repository"
"sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// toResponsePagination converts repository.PaginationResult to response.PaginationResult
func toResponsePagination(p *repository.PaginationResult) *response.PaginationResult {
if p == nil {
return nil
}
return &response.PaginationResult{
Total: p.Total,
Page: p.Page,
PageSize: p.PageSize,
Pages: p.Pages,
}
}
// SubscriptionHandler handles admin subscription management
type SubscriptionHandler struct {
subscriptionService *service.SubscriptionService
}
// NewSubscriptionHandler creates a new admin subscription handler
func NewSubscriptionHandler(subscriptionService *service.SubscriptionService) *SubscriptionHandler {
return &SubscriptionHandler{
subscriptionService: subscriptionService,
}
}
// AssignSubscriptionRequest represents assign subscription request
type AssignSubscriptionRequest struct {
UserID int64 `json:"user_id" binding:"required"`
GroupID int64 `json:"group_id" binding:"required"`
ValidityDays int `json:"validity_days"`
Notes string `json:"notes"`
}
// BulkAssignSubscriptionRequest represents bulk assign subscription request
type BulkAssignSubscriptionRequest struct {
UserIDs []int64 `json:"user_ids" binding:"required,min=1"`
GroupID int64 `json:"group_id" binding:"required"`
ValidityDays int `json:"validity_days"`
Notes string `json:"notes"`
}
// ExtendSubscriptionRequest represents extend subscription request
type ExtendSubscriptionRequest struct {
Days int `json:"days" binding:"required,min=1"`
}
// List handles listing all subscriptions with pagination and filters
// GET /api/v1/admin/subscriptions
func (h *SubscriptionHandler) List(c *gin.Context) {
page, pageSize := response.ParsePagination(c)
// Parse optional filters
var userID, groupID *int64
if userIDStr := c.Query("user_id"); userIDStr != "" {
if id, err := strconv.ParseInt(userIDStr, 10, 64); err == nil {
userID = &id
}
}
if groupIDStr := c.Query("group_id"); groupIDStr != "" {
if id, err := strconv.ParseInt(groupIDStr, 10, 64); err == nil {
groupID = &id
}
}
status := c.Query("status")
subscriptions, pagination, err := h.subscriptionService.List(c.Request.Context(), page, pageSize, userID, groupID, status)
if err != nil {
response.InternalError(c, "Failed to list subscriptions: "+err.Error())
return
}
response.PaginatedWithResult(c, subscriptions, toResponsePagination(pagination))
}
// GetByID handles getting a subscription by ID
// GET /api/v1/admin/subscriptions/:id
func (h *SubscriptionHandler) GetByID(c *gin.Context) {
subscriptionID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid subscription ID")
return
}
subscription, err := h.subscriptionService.GetByID(c.Request.Context(), subscriptionID)
if err != nil {
response.NotFound(c, "Subscription not found")
return
}
response.Success(c, subscription)
}
// GetProgress handles getting subscription usage progress
// GET /api/v1/admin/subscriptions/:id/progress
func (h *SubscriptionHandler) GetProgress(c *gin.Context) {
subscriptionID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid subscription ID")
return
}
progress, err := h.subscriptionService.GetSubscriptionProgress(c.Request.Context(), subscriptionID)
if err != nil {
response.NotFound(c, "Subscription not found")
return
}
response.Success(c, progress)
}
// Assign handles assigning a subscription to a user
// POST /api/v1/admin/subscriptions/assign
func (h *SubscriptionHandler) Assign(c *gin.Context) {
var req AssignSubscriptionRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
// Get admin user ID from context
adminID := getAdminIDFromContext(c)
subscription, err := h.subscriptionService.AssignSubscription(c.Request.Context(), &service.AssignSubscriptionInput{
UserID: req.UserID,
GroupID: req.GroupID,
ValidityDays: req.ValidityDays,
AssignedBy: adminID,
Notes: req.Notes,
})
if err != nil {
response.BadRequest(c, "Failed to assign subscription: "+err.Error())
return
}
response.Success(c, subscription)
}
// BulkAssign handles bulk assigning subscriptions to multiple users
// POST /api/v1/admin/subscriptions/bulk-assign
func (h *SubscriptionHandler) BulkAssign(c *gin.Context) {
var req BulkAssignSubscriptionRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
// Get admin user ID from context
adminID := getAdminIDFromContext(c)
result, err := h.subscriptionService.BulkAssignSubscription(c.Request.Context(), &service.BulkAssignSubscriptionInput{
UserIDs: req.UserIDs,
GroupID: req.GroupID,
ValidityDays: req.ValidityDays,
AssignedBy: adminID,
Notes: req.Notes,
})
if err != nil {
response.InternalError(c, "Failed to bulk assign subscriptions: "+err.Error())
return
}
response.Success(c, result)
}
// Extend handles extending a subscription
// POST /api/v1/admin/subscriptions/:id/extend
func (h *SubscriptionHandler) Extend(c *gin.Context) {
subscriptionID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid subscription ID")
return
}
var req ExtendSubscriptionRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
subscription, err := h.subscriptionService.ExtendSubscription(c.Request.Context(), subscriptionID, req.Days)
if err != nil {
response.InternalError(c, "Failed to extend subscription: "+err.Error())
return
}
response.Success(c, subscription)
}
// Revoke handles revoking a subscription
// DELETE /api/v1/admin/subscriptions/:id
func (h *SubscriptionHandler) Revoke(c *gin.Context) {
subscriptionID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid subscription ID")
return
}
err = h.subscriptionService.RevokeSubscription(c.Request.Context(), subscriptionID)
if err != nil {
response.InternalError(c, "Failed to revoke subscription: "+err.Error())
return
}
response.Success(c, gin.H{"message": "Subscription revoked successfully"})
}
// ListByGroup handles listing subscriptions for a specific group
// GET /api/v1/admin/groups/:id/subscriptions
func (h *SubscriptionHandler) ListByGroup(c *gin.Context) {
groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid group ID")
return
}
page, pageSize := response.ParsePagination(c)
subscriptions, pagination, err := h.subscriptionService.ListGroupSubscriptions(c.Request.Context(), groupID, page, pageSize)
if err != nil {
response.InternalError(c, "Failed to list group subscriptions: "+err.Error())
return
}
response.PaginatedWithResult(c, subscriptions, toResponsePagination(pagination))
}
// ListByUser handles listing subscriptions for a specific user
// GET /api/v1/admin/users/:id/subscriptions
func (h *SubscriptionHandler) ListByUser(c *gin.Context) {
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid user ID")
return
}
subscriptions, err := h.subscriptionService.ListUserSubscriptions(c.Request.Context(), userID)
if err != nil {
response.InternalError(c, "Failed to list user subscriptions: "+err.Error())
return
}
response.Success(c, subscriptions)
}
// Helper function to get admin ID from context
func getAdminIDFromContext(c *gin.Context) int64 {
if user, exists := c.Get("user"); exists {
if u, ok := user.(*model.User); ok && u != nil {
return u.ID
}
}
return 0
}

View File

@@ -0,0 +1,82 @@
package admin
import (
"net/http"
"sub2api/internal/pkg/response"
"sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/redis/go-redis/v9"
)
// SystemHandler handles system-related operations
type SystemHandler struct {
updateSvc *service.UpdateService
}
// NewSystemHandler creates a new SystemHandler
func NewSystemHandler(rdb *redis.Client, version, buildType string) *SystemHandler {
return &SystemHandler{
updateSvc: service.NewUpdateService(rdb, version, buildType),
}
}
// GetVersion returns the current version
// GET /api/v1/admin/system/version
func (h *SystemHandler) GetVersion(c *gin.Context) {
info, _ := h.updateSvc.CheckUpdate(c.Request.Context(), false)
response.Success(c, gin.H{
"version": info.CurrentVersion,
})
}
// CheckUpdates checks for available updates
// GET /api/v1/admin/system/check-updates
func (h *SystemHandler) CheckUpdates(c *gin.Context) {
force := c.Query("force") == "true"
info, err := h.updateSvc.CheckUpdate(c.Request.Context(), force)
if err != nil {
response.Error(c, http.StatusInternalServerError, err.Error())
return
}
response.Success(c, info)
}
// PerformUpdate downloads and applies the update
// POST /api/v1/admin/system/update
func (h *SystemHandler) PerformUpdate(c *gin.Context) {
if err := h.updateSvc.PerformUpdate(c.Request.Context()); err != nil {
response.Error(c, http.StatusInternalServerError, err.Error())
return
}
response.Success(c, gin.H{
"message": "Update completed. Please restart the service.",
"need_restart": true,
})
}
// Rollback restores the previous version
// POST /api/v1/admin/system/rollback
func (h *SystemHandler) Rollback(c *gin.Context) {
if err := h.updateSvc.Rollback(); err != nil {
response.Error(c, http.StatusInternalServerError, err.Error())
return
}
response.Success(c, gin.H{
"message": "Rollback completed. Please restart the service.",
"need_restart": true,
})
}
// RestartService restarts the systemd service
// POST /api/v1/admin/system/restart
func (h *SystemHandler) RestartService(c *gin.Context) {
if err := h.updateSvc.RestartService(); err != nil {
response.Error(c, http.StatusInternalServerError, err.Error())
return
}
response.Success(c, gin.H{
"message": "Service restart initiated",
})
}

View File

@@ -0,0 +1,262 @@
package admin
import (
"strconv"
"time"
"sub2api/internal/pkg/response"
"sub2api/internal/pkg/timezone"
"sub2api/internal/repository"
"sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// UsageHandler handles admin usage-related requests
type UsageHandler struct {
usageRepo *repository.UsageLogRepository
apiKeyRepo *repository.ApiKeyRepository
usageService *service.UsageService
adminService service.AdminService
}
// NewUsageHandler creates a new admin usage handler
func NewUsageHandler(
usageRepo *repository.UsageLogRepository,
apiKeyRepo *repository.ApiKeyRepository,
usageService *service.UsageService,
adminService service.AdminService,
) *UsageHandler {
return &UsageHandler{
usageRepo: usageRepo,
apiKeyRepo: apiKeyRepo,
usageService: usageService,
adminService: adminService,
}
}
// List handles listing all usage records with filters
// GET /api/v1/admin/usage
func (h *UsageHandler) List(c *gin.Context) {
page, pageSize := response.ParsePagination(c)
// Parse filters
var userID, apiKeyID int64
if userIDStr := c.Query("user_id"); userIDStr != "" {
id, err := strconv.ParseInt(userIDStr, 10, 64)
if err != nil {
response.BadRequest(c, "Invalid user_id")
return
}
userID = id
}
if apiKeyIDStr := c.Query("api_key_id"); apiKeyIDStr != "" {
id, err := strconv.ParseInt(apiKeyIDStr, 10, 64)
if err != nil {
response.BadRequest(c, "Invalid api_key_id")
return
}
apiKeyID = id
}
// Parse date range
var startTime, endTime *time.Time
if startDateStr := c.Query("start_date"); startDateStr != "" {
t, err := timezone.ParseInLocation("2006-01-02", startDateStr)
if err != nil {
response.BadRequest(c, "Invalid start_date format, use YYYY-MM-DD")
return
}
startTime = &t
}
if endDateStr := c.Query("end_date"); endDateStr != "" {
t, err := timezone.ParseInLocation("2006-01-02", endDateStr)
if err != nil {
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
return
}
// Set end time to end of day
t = t.Add(24*time.Hour - time.Nanosecond)
endTime = &t
}
params := repository.PaginationParams{Page: page, PageSize: pageSize}
filters := repository.UsageLogFilters{
UserID: userID,
ApiKeyID: apiKeyID,
StartTime: startTime,
EndTime: endTime,
}
records, result, err := h.usageRepo.ListWithFilters(c.Request.Context(), params, filters)
if err != nil {
response.InternalError(c, "Failed to list usage records: "+err.Error())
return
}
response.Paginated(c, records, result.Total, page, pageSize)
}
// Stats handles getting usage statistics with filters
// GET /api/v1/admin/usage/stats
func (h *UsageHandler) Stats(c *gin.Context) {
// Parse filters
var userID, apiKeyID int64
if userIDStr := c.Query("user_id"); userIDStr != "" {
id, err := strconv.ParseInt(userIDStr, 10, 64)
if err != nil {
response.BadRequest(c, "Invalid user_id")
return
}
userID = id
}
if apiKeyIDStr := c.Query("api_key_id"); apiKeyIDStr != "" {
id, err := strconv.ParseInt(apiKeyIDStr, 10, 64)
if err != nil {
response.BadRequest(c, "Invalid api_key_id")
return
}
apiKeyID = id
}
// Parse date range
now := timezone.Now()
var startTime, endTime time.Time
startDateStr := c.Query("start_date")
endDateStr := c.Query("end_date")
if startDateStr != "" && endDateStr != "" {
var err error
startTime, err = timezone.ParseInLocation("2006-01-02", startDateStr)
if err != nil {
response.BadRequest(c, "Invalid start_date format, use YYYY-MM-DD")
return
}
endTime, err = timezone.ParseInLocation("2006-01-02", endDateStr)
if err != nil {
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
return
}
endTime = endTime.Add(24*time.Hour - time.Nanosecond)
} else {
period := c.DefaultQuery("period", "today")
switch period {
case "today":
startTime = timezone.StartOfDay(now)
case "week":
startTime = now.AddDate(0, 0, -7)
case "month":
startTime = now.AddDate(0, -1, 0)
default:
startTime = timezone.StartOfDay(now)
}
endTime = now
}
if apiKeyID > 0 {
stats, err := h.usageService.GetStatsByApiKey(c.Request.Context(), apiKeyID, startTime, endTime)
if err != nil {
response.InternalError(c, "Failed to get usage statistics: "+err.Error())
return
}
response.Success(c, stats)
return
}
if userID > 0 {
stats, err := h.usageService.GetStatsByUser(c.Request.Context(), userID, startTime, endTime)
if err != nil {
response.InternalError(c, "Failed to get usage statistics: "+err.Error())
return
}
response.Success(c, stats)
return
}
// Get global stats
stats, err := h.usageRepo.GetGlobalStats(c.Request.Context(), startTime, endTime)
if err != nil {
response.InternalError(c, "Failed to get usage statistics: "+err.Error())
return
}
response.Success(c, stats)
}
// SearchUsers handles searching users by email keyword
// GET /api/v1/admin/usage/search-users
func (h *UsageHandler) SearchUsers(c *gin.Context) {
keyword := c.Query("q")
if keyword == "" {
response.Success(c, []interface{}{})
return
}
// Limit to 30 results
users, _, err := h.adminService.ListUsers(c.Request.Context(), 1, 30, "", "", keyword)
if err != nil {
response.InternalError(c, "Failed to search users: "+err.Error())
return
}
// Return simplified user list (only id and email)
type SimpleUser struct {
ID int64 `json:"id"`
Email string `json:"email"`
}
result := make([]SimpleUser, len(users))
for i, u := range users {
result[i] = SimpleUser{
ID: u.ID,
Email: u.Email,
}
}
response.Success(c, result)
}
// SearchApiKeys handles searching API keys by user
// GET /api/v1/admin/usage/search-api-keys
func (h *UsageHandler) SearchApiKeys(c *gin.Context) {
userIDStr := c.Query("user_id")
keyword := c.Query("q")
var userID int64
if userIDStr != "" {
id, err := strconv.ParseInt(userIDStr, 10, 64)
if err != nil {
response.BadRequest(c, "Invalid user_id")
return
}
userID = id
}
keys, err := h.apiKeyRepo.SearchApiKeys(c.Request.Context(), userID, keyword, 30)
if err != nil {
response.InternalError(c, "Failed to search API keys: "+err.Error())
return
}
// Return simplified API key list (only id and name)
type SimpleApiKey struct {
ID int64 `json:"id"`
Name string `json:"name"`
UserID int64 `json:"user_id"`
}
result := make([]SimpleApiKey, len(keys))
for i, k := range keys {
result[i] = SimpleApiKey{
ID: k.ID,
Name: k.Name,
UserID: k.UserID,
}
}
response.Success(c, result)
}

View File

@@ -0,0 +1,221 @@
package admin
import (
"strconv"
"sub2api/internal/pkg/response"
"sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// UserHandler handles admin user management
type UserHandler struct {
adminService service.AdminService
}
// NewUserHandler creates a new admin user handler
func NewUserHandler(adminService service.AdminService) *UserHandler {
return &UserHandler{
adminService: adminService,
}
}
// CreateUserRequest represents admin create user request
type CreateUserRequest struct {
Email string `json:"email" binding:"required,email"`
Password string `json:"password" binding:"required,min=6"`
Balance float64 `json:"balance"`
Concurrency int `json:"concurrency"`
AllowedGroups []int64 `json:"allowed_groups"`
}
// UpdateUserRequest represents admin update user request
// 使用指针类型来区分"未提供"和"设置为0"
type UpdateUserRequest struct {
Email string `json:"email" binding:"omitempty,email"`
Password string `json:"password" binding:"omitempty,min=6"`
Balance *float64 `json:"balance"`
Concurrency *int `json:"concurrency"`
Status string `json:"status" binding:"omitempty,oneof=active disabled"`
AllowedGroups *[]int64 `json:"allowed_groups"`
}
// UpdateBalanceRequest represents balance update request
type UpdateBalanceRequest struct {
Balance float64 `json:"balance" binding:"required"`
Operation string `json:"operation" binding:"required,oneof=set add subtract"`
}
// List handles listing all users with pagination
// GET /api/v1/admin/users
func (h *UserHandler) List(c *gin.Context) {
page, pageSize := response.ParsePagination(c)
status := c.Query("status")
role := c.Query("role")
search := c.Query("search")
users, total, err := h.adminService.ListUsers(c.Request.Context(), page, pageSize, status, role, search)
if err != nil {
response.InternalError(c, "Failed to list users: "+err.Error())
return
}
response.Paginated(c, users, total, page, pageSize)
}
// GetByID handles getting a user by ID
// GET /api/v1/admin/users/:id
func (h *UserHandler) GetByID(c *gin.Context) {
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid user ID")
return
}
user, err := h.adminService.GetUser(c.Request.Context(), userID)
if err != nil {
response.NotFound(c, "User not found")
return
}
response.Success(c, user)
}
// Create handles creating a new user
// POST /api/v1/admin/users
func (h *UserHandler) Create(c *gin.Context) {
var req CreateUserRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
user, err := h.adminService.CreateUser(c.Request.Context(), &service.CreateUserInput{
Email: req.Email,
Password: req.Password,
Balance: req.Balance,
Concurrency: req.Concurrency,
AllowedGroups: req.AllowedGroups,
})
if err != nil {
response.BadRequest(c, "Failed to create user: "+err.Error())
return
}
response.Success(c, user)
}
// Update handles updating a user
// PUT /api/v1/admin/users/:id
func (h *UserHandler) Update(c *gin.Context) {
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid user ID")
return
}
var req UpdateUserRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
// 使用指针类型直接传递nil 表示未提供该字段
user, err := h.adminService.UpdateUser(c.Request.Context(), userID, &service.UpdateUserInput{
Email: req.Email,
Password: req.Password,
Balance: req.Balance,
Concurrency: req.Concurrency,
Status: req.Status,
AllowedGroups: req.AllowedGroups,
})
if err != nil {
response.InternalError(c, "Failed to update user: "+err.Error())
return
}
response.Success(c, user)
}
// Delete handles deleting a user
// DELETE /api/v1/admin/users/:id
func (h *UserHandler) Delete(c *gin.Context) {
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid user ID")
return
}
err = h.adminService.DeleteUser(c.Request.Context(), userID)
if err != nil {
response.InternalError(c, "Failed to delete user: "+err.Error())
return
}
response.Success(c, gin.H{"message": "User deleted successfully"})
}
// UpdateBalance handles updating user balance
// POST /api/v1/admin/users/:id/balance
func (h *UserHandler) UpdateBalance(c *gin.Context) {
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid user ID")
return
}
var req UpdateBalanceRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
user, err := h.adminService.UpdateUserBalance(c.Request.Context(), userID, req.Balance, req.Operation)
if err != nil {
response.InternalError(c, "Failed to update balance: "+err.Error())
return
}
response.Success(c, user)
}
// GetUserAPIKeys handles getting user's API keys
// GET /api/v1/admin/users/:id/api-keys
func (h *UserHandler) GetUserAPIKeys(c *gin.Context) {
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid user ID")
return
}
page, pageSize := response.ParsePagination(c)
keys, total, err := h.adminService.GetUserAPIKeys(c.Request.Context(), userID, page, pageSize)
if err != nil {
response.InternalError(c, "Failed to get user API keys: "+err.Error())
return
}
response.Paginated(c, keys, total, page, pageSize)
}
// GetUserUsage handles getting user's usage statistics
// GET /api/v1/admin/users/:id/usage
func (h *UserHandler) GetUserUsage(c *gin.Context) {
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid user ID")
return
}
period := c.DefaultQuery("period", "month")
stats, err := h.adminService.GetUserUsageStats(c.Request.Context(), userID, period)
if err != nil {
response.InternalError(c, "Failed to get user usage: "+err.Error())
return
}
response.Success(c, stats)
}

View File

@@ -0,0 +1,235 @@
package handler
import (
"strconv"
"sub2api/internal/model"
"sub2api/internal/pkg/response"
"sub2api/internal/repository"
"sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// APIKeyHandler handles API key-related requests
type APIKeyHandler struct {
apiKeyService *service.ApiKeyService
}
// NewAPIKeyHandler creates a new APIKeyHandler
func NewAPIKeyHandler(apiKeyService *service.ApiKeyService) *APIKeyHandler {
return &APIKeyHandler{
apiKeyService: apiKeyService,
}
}
// CreateAPIKeyRequest represents the create API key request payload
type CreateAPIKeyRequest struct {
Name string `json:"name" binding:"required"`
GroupID *int64 `json:"group_id"` // nullable
CustomKey *string `json:"custom_key"` // 可选的自定义key
}
// UpdateAPIKeyRequest represents the update API key request payload
type UpdateAPIKeyRequest struct {
Name string `json:"name"`
GroupID *int64 `json:"group_id"`
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
}
// List handles listing user's API keys with pagination
// GET /api/v1/api-keys
func (h *APIKeyHandler) List(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok {
response.InternalError(c, "Invalid user context")
return
}
page, pageSize := response.ParsePagination(c)
params := repository.PaginationParams{Page: page, PageSize: pageSize}
keys, result, err := h.apiKeyService.List(c.Request.Context(), user.ID, params)
if err != nil {
response.InternalError(c, "Failed to list API keys: "+err.Error())
return
}
response.Paginated(c, keys, result.Total, page, pageSize)
}
// GetByID handles getting a single API key
// GET /api/v1/api-keys/:id
func (h *APIKeyHandler) GetByID(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok {
response.InternalError(c, "Invalid user context")
return
}
keyID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid key ID")
return
}
key, err := h.apiKeyService.GetByID(c.Request.Context(), keyID)
if err != nil {
response.NotFound(c, "API key not found")
return
}
// 验证所有权
if key.UserID != user.ID {
response.Forbidden(c, "Not authorized to access this key")
return
}
response.Success(c, key)
}
// Create handles creating a new API key
// POST /api/v1/api-keys
func (h *APIKeyHandler) Create(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok {
response.InternalError(c, "Invalid user context")
return
}
var req CreateAPIKeyRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
svcReq := service.CreateApiKeyRequest{
Name: req.Name,
GroupID: req.GroupID,
CustomKey: req.CustomKey,
}
key, err := h.apiKeyService.Create(c.Request.Context(), user.ID, svcReq)
if err != nil {
response.InternalError(c, "Failed to create API key: "+err.Error())
return
}
response.Success(c, key)
}
// Update handles updating an API key
// PUT /api/v1/api-keys/:id
func (h *APIKeyHandler) Update(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok {
response.InternalError(c, "Invalid user context")
return
}
keyID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid key ID")
return
}
var req UpdateAPIKeyRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
svcReq := service.UpdateApiKeyRequest{}
if req.Name != "" {
svcReq.Name = &req.Name
}
svcReq.GroupID = req.GroupID
if req.Status != "" {
svcReq.Status = &req.Status
}
key, err := h.apiKeyService.Update(c.Request.Context(), keyID, user.ID, svcReq)
if err != nil {
response.InternalError(c, "Failed to update API key: "+err.Error())
return
}
response.Success(c, key)
}
// Delete handles deleting an API key
// DELETE /api/v1/api-keys/:id
func (h *APIKeyHandler) Delete(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok {
response.InternalError(c, "Invalid user context")
return
}
keyID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid key ID")
return
}
err = h.apiKeyService.Delete(c.Request.Context(), keyID, user.ID)
if err != nil {
response.InternalError(c, "Failed to delete API key: "+err.Error())
return
}
response.Success(c, gin.H{"message": "API key deleted successfully"})
}
// GetAvailableGroups 获取用户可以绑定的分组列表
// GET /api/v1/groups/available
func (h *APIKeyHandler) GetAvailableGroups(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok {
response.InternalError(c, "Invalid user context")
return
}
groups, err := h.apiKeyService.GetAvailableGroups(c.Request.Context(), user.ID)
if err != nil {
response.InternalError(c, "Failed to get available groups: "+err.Error())
return
}
response.Success(c, groups)
}

View File

@@ -0,0 +1,158 @@
package handler
import (
"sub2api/internal/model"
"sub2api/internal/pkg/response"
"sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// AuthHandler handles authentication-related requests
type AuthHandler struct {
authService *service.AuthService
}
// NewAuthHandler creates a new AuthHandler
func NewAuthHandler(authService *service.AuthService) *AuthHandler {
return &AuthHandler{
authService: authService,
}
}
// RegisterRequest represents the registration request payload
type RegisterRequest struct {
Email string `json:"email" binding:"required,email"`
Password string `json:"password" binding:"required,min=6"`
VerifyCode string `json:"verify_code"`
TurnstileToken string `json:"turnstile_token"`
}
// SendVerifyCodeRequest 发送验证码请求
type SendVerifyCodeRequest struct {
Email string `json:"email" binding:"required,email"`
TurnstileToken string `json:"turnstile_token"`
}
// SendVerifyCodeResponse 发送验证码响应
type SendVerifyCodeResponse struct {
Message string `json:"message"`
Countdown int `json:"countdown"` // 倒计时秒数
}
// LoginRequest represents the login request payload
type LoginRequest struct {
Email string `json:"email" binding:"required,email"`
Password string `json:"password" binding:"required"`
TurnstileToken string `json:"turnstile_token"`
}
// AuthResponse 认证响应格式(匹配前端期望)
type AuthResponse struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
User *model.User `json:"user"`
}
// Register handles user registration
// POST /api/v1/auth/register
func (h *AuthHandler) Register(c *gin.Context) {
var req RegisterRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
// Turnstile 验证(当提供了邮箱验证码时跳过,因为发送验证码时已验证过)
if req.VerifyCode == "" {
if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, c.ClientIP()); err != nil {
response.BadRequest(c, "Turnstile verification failed: "+err.Error())
return
}
}
token, user, err := h.authService.RegisterWithVerification(c.Request.Context(), req.Email, req.Password, req.VerifyCode)
if err != nil {
response.BadRequest(c, "Registration failed: "+err.Error())
return
}
response.Success(c, AuthResponse{
AccessToken: token,
TokenType: "Bearer",
User: user,
})
}
// SendVerifyCode 发送邮箱验证码
// POST /api/v1/auth/send-verify-code
func (h *AuthHandler) SendVerifyCode(c *gin.Context) {
var req SendVerifyCodeRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
// Turnstile 验证
if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, c.ClientIP()); err != nil {
response.BadRequest(c, "Turnstile verification failed: "+err.Error())
return
}
result, err := h.authService.SendVerifyCodeAsync(c.Request.Context(), req.Email)
if err != nil {
response.BadRequest(c, "Failed to send verification code: "+err.Error())
return
}
response.Success(c, SendVerifyCodeResponse{
Message: "Verification code sent successfully",
Countdown: result.Countdown,
})
}
// Login handles user login
// POST /api/v1/auth/login
func (h *AuthHandler) Login(c *gin.Context) {
var req LoginRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
// Turnstile 验证
if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, c.ClientIP()); err != nil {
response.BadRequest(c, "Turnstile verification failed: "+err.Error())
return
}
token, user, err := h.authService.Login(c.Request.Context(), req.Email, req.Password)
if err != nil {
response.Unauthorized(c, "Login failed: "+err.Error())
return
}
response.Success(c, AuthResponse{
AccessToken: token,
TokenType: "Bearer",
User: user,
})
}
// GetCurrentUser handles getting current authenticated user
// GET /api/v1/auth/me
func (h *AuthHandler) GetCurrentUser(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok {
response.InternalError(c, "Invalid user context")
return
}
response.Success(c, user)
}

View File

@@ -0,0 +1,445 @@
package handler
import (
"context"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"time"
"sub2api/internal/middleware"
"sub2api/internal/model"
"sub2api/internal/service"
"github.com/gin-gonic/gin"
)
const (
// Maximum wait time for concurrency slot
maxConcurrencyWait = 60 * time.Second
// Ping interval during wait
pingInterval = 5 * time.Second
)
// GatewayHandler handles API gateway requests
type GatewayHandler struct {
gatewayService *service.GatewayService
userService *service.UserService
concurrencyService *service.ConcurrencyService
billingCacheService *service.BillingCacheService
}
// NewGatewayHandler creates a new GatewayHandler
func NewGatewayHandler(gatewayService *service.GatewayService, userService *service.UserService, concurrencyService *service.ConcurrencyService, billingCacheService *service.BillingCacheService) *GatewayHandler {
return &GatewayHandler{
gatewayService: gatewayService,
userService: userService,
concurrencyService: concurrencyService,
billingCacheService: billingCacheService,
}
}
// Messages handles Claude API compatible messages endpoint
// POST /v1/messages
func (h *GatewayHandler) Messages(c *gin.Context) {
// 从context获取apiKey和userApiKeyAuth中间件已设置
apiKey, ok := middleware.GetApiKeyFromContext(c)
if !ok {
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
return
}
user, ok := middleware.GetUserFromContext(c)
if !ok {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
return
}
// 读取请求体
body, err := io.ReadAll(c.Request.Body)
if err != nil {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
return
}
if len(body) == 0 {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
return
}
// 解析请求获取模型名和stream
var req struct {
Model string `json:"model"`
Stream bool `json:"stream"`
}
if err := json.Unmarshal(body, &req); err != nil {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
return
}
// Track if we've started streaming (for error handling)
streamStarted := false
// 获取订阅信息可能为nil- 提前获取用于后续检查
subscription, _ := middleware.GetSubscriptionFromContext(c)
// 0. 检查wait队列是否已满
maxWait := service.CalculateMaxWait(user.Concurrency)
canWait, err := h.concurrencyService.IncrementWaitCount(c.Request.Context(), user.ID, maxWait)
if err != nil {
log.Printf("Increment wait count failed: %v", err)
// On error, allow request to proceed
} else if !canWait {
h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
return
}
// 确保在函数退出时减少wait计数
defer h.concurrencyService.DecrementWaitCount(c.Request.Context(), user.ID)
// 1. 首先获取用户并发槽位
userReleaseFunc, err := h.acquireUserSlotWithWait(c, user, req.Stream, &streamStarted)
if err != nil {
log.Printf("User concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "user", streamStarted)
return
}
if userReleaseFunc != nil {
defer userReleaseFunc()
}
// 2. 【新增】Wait后二次检查余额/订阅
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), user, apiKey, apiKey.Group, subscription); err != nil {
log.Printf("Billing eligibility check failed after wait: %v", err)
h.handleStreamingAwareError(c, http.StatusForbidden, "billing_error", err.Error(), streamStarted)
return
}
// 计算粘性会话hash
sessionHash := h.gatewayService.GenerateSessionHash(body)
// 选择支持该模型的账号
account, err := h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model)
if err != nil {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
return
}
// 3. 获取账号并发槽位
accountReleaseFunc, err := h.acquireAccountSlotWithWait(c, account, req.Stream, &streamStarted)
if err != nil {
log.Printf("Account concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}
if accountReleaseFunc != nil {
defer accountReleaseFunc()
}
// 转发请求
result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body)
if err != nil {
// 错误响应已在Forward中处理这里只记录日志
log.Printf("Forward request failed: %v", err)
return
}
// 异步记录使用量subscription已在函数开头获取
go func() {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
ApiKey: apiKey,
User: user,
Account: account,
Subscription: subscription,
}); err != nil {
log.Printf("Record usage failed: %v", err)
}
}()
}
// acquireUserSlotWithWait acquires a user concurrency slot, waiting if necessary
// For streaming requests, sends ping events during the wait
// streamStarted is updated if streaming response has begun
func (h *GatewayHandler) acquireUserSlotWithWait(c *gin.Context, user *model.User, isStream bool, streamStarted *bool) (func(), error) {
ctx := c.Request.Context()
// Try to acquire immediately
result, err := h.concurrencyService.AcquireUserSlot(ctx, user.ID, user.Concurrency)
if err != nil {
return nil, err
}
if result.Acquired {
return result.ReleaseFunc, nil
}
// Need to wait - handle streaming ping if needed
return h.waitForSlotWithPing(c, "user", user.ID, user.Concurrency, isStream, streamStarted)
}
// acquireAccountSlotWithWait acquires an account concurrency slot, waiting if necessary
// For streaming requests, sends ping events during the wait
// streamStarted is updated if streaming response has begun
func (h *GatewayHandler) acquireAccountSlotWithWait(c *gin.Context, account *model.Account, isStream bool, streamStarted *bool) (func(), error) {
ctx := c.Request.Context()
// Try to acquire immediately
result, err := h.concurrencyService.AcquireAccountSlot(ctx, account.ID, account.Concurrency)
if err != nil {
return nil, err
}
if result.Acquired {
return result.ReleaseFunc, nil
}
// Need to wait - handle streaming ping if needed
return h.waitForSlotWithPing(c, "account", account.ID, account.Concurrency, isStream, streamStarted)
}
// concurrencyError represents a concurrency limit error with context
type concurrencyError struct {
SlotType string
IsTimeout bool
}
func (e *concurrencyError) Error() string {
if e.IsTimeout {
return fmt.Sprintf("timeout waiting for %s concurrency slot", e.SlotType)
}
return fmt.Sprintf("%s concurrency limit reached", e.SlotType)
}
// waitForSlotWithPing waits for a concurrency slot, sending ping events for streaming requests
// Note: For streaming requests, we send ping to keep the connection alive.
// streamStarted pointer is updated when streaming begins (for proper error handling by caller)
func (h *GatewayHandler) waitForSlotWithPing(c *gin.Context, slotType string, id int64, maxConcurrency int, isStream bool, streamStarted *bool) (func(), error) {
ctx, cancel := context.WithTimeout(c.Request.Context(), maxConcurrencyWait)
defer cancel()
// For streaming requests, set up SSE headers for ping
var flusher http.Flusher
if isStream {
var ok bool
flusher, ok = c.Writer.(http.Flusher)
if !ok {
return nil, fmt.Errorf("streaming not supported")
}
}
pingTicker := time.NewTicker(pingInterval)
defer pingTicker.Stop()
pollTicker := time.NewTicker(100 * time.Millisecond)
defer pollTicker.Stop()
for {
select {
case <-ctx.Done():
return nil, &concurrencyError{
SlotType: slotType,
IsTimeout: true,
}
case <-pingTicker.C:
// Send ping for streaming requests to keep connection alive
if isStream && flusher != nil {
// Set headers on first ping (lazy initialization)
if !*streamStarted {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("X-Accel-Buffering", "no")
*streamStarted = true
}
fmt.Fprintf(c.Writer, "data: {\"type\": \"ping\"}\n\n")
flusher.Flush()
}
case <-pollTicker.C:
// Try to acquire slot
var result *service.AcquireResult
var err error
if slotType == "user" {
result, err = h.concurrencyService.AcquireUserSlot(ctx, id, maxConcurrency)
} else {
result, err = h.concurrencyService.AcquireAccountSlot(ctx, id, maxConcurrency)
}
if err != nil {
return nil, err
}
if result.Acquired {
return result.ReleaseFunc, nil
}
}
}
}
// Models handles listing available models
// GET /v1/models
func (h *GatewayHandler) Models(c *gin.Context) {
models := []gin.H{
{
"id": "claude-opus-4-5-20251101",
"type": "model",
"display_name": "Claude Opus 4.5",
"created_at": "2025-11-01T00:00:00Z",
},
{
"id": "claude-sonnet-4-5-20250929",
"type": "model",
"display_name": "Claude Sonnet 4.5",
"created_at": "2025-09-29T00:00:00Z",
},
{
"id": "claude-haiku-4-5-20251001",
"type": "model",
"display_name": "Claude Haiku 4.5",
"created_at": "2025-10-01T00:00:00Z",
},
}
c.JSON(http.StatusOK, gin.H{
"data": models,
"object": "list",
})
}
// Usage handles getting account balance for CC Switch integration
// GET /v1/usage
func (h *GatewayHandler) Usage(c *gin.Context) {
apiKey, ok := middleware.GetApiKeyFromContext(c)
if !ok {
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
return
}
user, ok := middleware.GetUserFromContext(c)
if !ok {
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
return
}
// 订阅模式:返回订阅限额信息
if apiKey.Group != nil && apiKey.Group.IsSubscriptionType() {
subscription, ok := middleware.GetSubscriptionFromContext(c)
if !ok {
h.errorResponse(c, http.StatusForbidden, "subscription_error", "No active subscription")
return
}
remaining := h.calculateSubscriptionRemaining(apiKey.Group, subscription)
c.JSON(http.StatusOK, gin.H{
"isValid": true,
"planName": apiKey.Group.Name,
"remaining": remaining,
"unit": "USD",
})
return
}
// 余额模式:返回钱包余额
latestUser, err := h.userService.GetByID(c.Request.Context(), user.ID)
if err != nil {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to get user info")
return
}
c.JSON(http.StatusOK, gin.H{
"isValid": true,
"planName": "钱包余额",
"remaining": latestUser.Balance,
"unit": "USD",
})
}
// calculateSubscriptionRemaining 计算订阅剩余可用额度
// 逻辑:
// 1. 如果日/周/月任一限额达到100%返回0
// 2. 否则返回所有已配置周期中剩余额度的最小值
func (h *GatewayHandler) calculateSubscriptionRemaining(group *model.Group, sub *model.UserSubscription) float64 {
var remainingValues []float64
// 检查日限额
if group.HasDailyLimit() {
remaining := *group.DailyLimitUSD - sub.DailyUsageUSD
if remaining <= 0 {
return 0
}
remainingValues = append(remainingValues, remaining)
}
// 检查周限额
if group.HasWeeklyLimit() {
remaining := *group.WeeklyLimitUSD - sub.WeeklyUsageUSD
if remaining <= 0 {
return 0
}
remainingValues = append(remainingValues, remaining)
}
// 检查月限额
if group.HasMonthlyLimit() {
remaining := *group.MonthlyLimitUSD - sub.MonthlyUsageUSD
if remaining <= 0 {
return 0
}
remainingValues = append(remainingValues, remaining)
}
// 如果没有配置任何限额,返回-1表示无限制
if len(remainingValues) == 0 {
return -1
}
// 返回最小值
min := remainingValues[0]
for _, v := range remainingValues[1:] {
if v < min {
min = v
}
}
return min
}
// handleConcurrencyError handles concurrency-related errors with proper 429 response
func (h *GatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotType string, streamStarted bool) {
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error",
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
}
// handleStreamingAwareError handles errors that may occur after streaming has started
func (h *GatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) {
if streamStarted {
// Stream already started, send error as SSE event then close
flusher, ok := c.Writer.(http.Flusher)
if ok {
// Send error event in SSE format
errorEvent := fmt.Sprintf(`data: {"type": "error", "error": {"type": "%s", "message": "%s"}}`+"\n\n", errType, message)
fmt.Fprint(c.Writer, errorEvent)
flusher.Flush()
}
return
}
// Normal case: return JSON response with proper status code
h.errorResponse(c, status, errType, message)
}
// errorResponse 返回Claude API格式的错误响应
func (h *GatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) {
c.JSON(status, gin.H{
"type": "error",
"error": gin.H{
"type": errType,
"message": message,
},
})
}

View File

@@ -0,0 +1,70 @@
package handler
import (
"sub2api/internal/handler/admin"
"sub2api/internal/repository"
"sub2api/internal/service"
"github.com/redis/go-redis/v9"
)
// AdminHandlers contains all admin-related HTTP handlers
type AdminHandlers struct {
Dashboard *admin.DashboardHandler
User *admin.UserHandler
Group *admin.GroupHandler
Account *admin.AccountHandler
OAuth *admin.OAuthHandler
Proxy *admin.ProxyHandler
Redeem *admin.RedeemHandler
Setting *admin.SettingHandler
System *admin.SystemHandler
Subscription *admin.SubscriptionHandler
Usage *admin.UsageHandler
}
// Handlers contains all HTTP handlers
type Handlers struct {
Auth *AuthHandler
User *UserHandler
APIKey *APIKeyHandler
Usage *UsageHandler
Redeem *RedeemHandler
Subscription *SubscriptionHandler
Admin *AdminHandlers
Gateway *GatewayHandler
Setting *SettingHandler
}
// BuildInfo contains build-time information
type BuildInfo struct {
Version string
BuildType string // "source" for manual builds, "release" for CI builds
}
// NewHandlers creates a new Handlers instance with all handlers initialized
func NewHandlers(services *service.Services, repos *repository.Repositories, rdb *redis.Client, buildInfo BuildInfo) *Handlers {
return &Handlers{
Auth: NewAuthHandler(services.Auth),
User: NewUserHandler(services.User),
APIKey: NewAPIKeyHandler(services.ApiKey),
Usage: NewUsageHandler(services.Usage, repos.UsageLog, services.ApiKey),
Redeem: NewRedeemHandler(services.Redeem),
Subscription: NewSubscriptionHandler(services.Subscription),
Admin: &AdminHandlers{
Dashboard: admin.NewDashboardHandler(services.Admin, repos.UsageLog),
User: admin.NewUserHandler(services.Admin),
Group: admin.NewGroupHandler(services.Admin),
Account: admin.NewAccountHandler(services.Admin, services.OAuth, services.RateLimit, services.AccountUsage, services.AccountTest),
OAuth: admin.NewOAuthHandler(services.OAuth, services.Admin),
Proxy: admin.NewProxyHandler(services.Admin),
Redeem: admin.NewRedeemHandler(services.Admin),
Setting: admin.NewSettingHandler(services.Setting, services.Email),
System: admin.NewSystemHandler(rdb, buildInfo.Version, buildInfo.BuildType),
Subscription: admin.NewSubscriptionHandler(services.Subscription),
Usage: admin.NewUsageHandler(repos.UsageLog, repos.ApiKey, services.Usage, services.Admin),
},
Gateway: NewGatewayHandler(services.Gateway, services.User, services.Concurrency, services.BillingCache),
Setting: NewSettingHandler(services.Setting, buildInfo.Version),
}
}

View File

@@ -0,0 +1,92 @@
package handler
import (
"sub2api/internal/model"
"sub2api/internal/pkg/response"
"sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// RedeemHandler handles redeem code-related requests
type RedeemHandler struct {
redeemService *service.RedeemService
}
// NewRedeemHandler creates a new RedeemHandler
func NewRedeemHandler(redeemService *service.RedeemService) *RedeemHandler {
return &RedeemHandler{
redeemService: redeemService,
}
}
// RedeemRequest represents the redeem code request payload
type RedeemRequest struct {
Code string `json:"code" binding:"required"`
}
// RedeemResponse represents the redeem response
type RedeemResponse struct {
Message string `json:"message"`
Type string `json:"type"`
Value float64 `json:"value"`
NewBalance *float64 `json:"new_balance,omitempty"`
NewConcurrency *int `json:"new_concurrency,omitempty"`
}
// Redeem handles redeeming a code
// POST /api/v1/redeem
func (h *RedeemHandler) Redeem(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok {
response.InternalError(c, "Invalid user context")
return
}
var req RedeemRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
result, err := h.redeemService.Redeem(c.Request.Context(), user.ID, req.Code)
if err != nil {
response.BadRequest(c, "Failed to redeem code: "+err.Error())
return
}
response.Success(c, result)
}
// GetHistory returns the user's redemption history
// GET /api/v1/redeem/history
func (h *RedeemHandler) GetHistory(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok {
response.InternalError(c, "Invalid user context")
return
}
// Default limit is 25
limit := 25
codes, err := h.redeemService.GetUserHistory(c.Request.Context(), user.ID, limit)
if err != nil {
response.InternalError(c, "Failed to get history: "+err.Error())
return
}
response.Success(c, codes)
}

View File

@@ -0,0 +1,35 @@
package handler
import (
"sub2api/internal/pkg/response"
"sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// SettingHandler 公开设置处理器(无需认证)
type SettingHandler struct {
settingService *service.SettingService
version string
}
// NewSettingHandler 创建公开设置处理器
func NewSettingHandler(settingService *service.SettingService, version string) *SettingHandler {
return &SettingHandler{
settingService: settingService,
version: version,
}
}
// GetPublicSettings 获取公开设置
// GET /api/v1/settings/public
func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
settings, err := h.settingService.GetPublicSettings(c.Request.Context())
if err != nil {
response.InternalError(c, "Failed to get settings: "+err.Error())
return
}
settings.Version = h.version
response.Success(c, settings)
}

View File

@@ -0,0 +1,203 @@
package handler
import (
"sub2api/internal/model"
"sub2api/internal/pkg/response"
"sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// SubscriptionSummaryItem represents a subscription item in summary
type SubscriptionSummaryItem struct {
ID int64 `json:"id"`
GroupID int64 `json:"group_id"`
GroupName string `json:"group_name"`
Status string `json:"status"`
DailyUsedUSD float64 `json:"daily_used_usd,omitempty"`
DailyLimitUSD float64 `json:"daily_limit_usd,omitempty"`
WeeklyUsedUSD float64 `json:"weekly_used_usd,omitempty"`
WeeklyLimitUSD float64 `json:"weekly_limit_usd,omitempty"`
MonthlyUsedUSD float64 `json:"monthly_used_usd,omitempty"`
MonthlyLimitUSD float64 `json:"monthly_limit_usd,omitempty"`
ExpiresAt *string `json:"expires_at,omitempty"`
}
// SubscriptionProgressInfo represents subscription with progress info
type SubscriptionProgressInfo struct {
Subscription *model.UserSubscription `json:"subscription"`
Progress *service.SubscriptionProgress `json:"progress"`
}
// SubscriptionHandler handles user subscription operations
type SubscriptionHandler struct {
subscriptionService *service.SubscriptionService
}
// NewSubscriptionHandler creates a new user subscription handler
func NewSubscriptionHandler(subscriptionService *service.SubscriptionService) *SubscriptionHandler {
return &SubscriptionHandler{
subscriptionService: subscriptionService,
}
}
// List handles listing current user's subscriptions
// GET /api/v1/subscriptions
func (h *SubscriptionHandler) List(c *gin.Context) {
user, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not found in context")
return
}
u, ok := user.(*model.User)
if !ok {
response.InternalError(c, "Invalid user in context")
return
}
subscriptions, err := h.subscriptionService.ListUserSubscriptions(c.Request.Context(), u.ID)
if err != nil {
response.InternalError(c, "Failed to list subscriptions: "+err.Error())
return
}
response.Success(c, subscriptions)
}
// GetActive handles getting current user's active subscriptions
// GET /api/v1/subscriptions/active
func (h *SubscriptionHandler) GetActive(c *gin.Context) {
user, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not found in context")
return
}
u, ok := user.(*model.User)
if !ok {
response.InternalError(c, "Invalid user in context")
return
}
subscriptions, err := h.subscriptionService.ListActiveUserSubscriptions(c.Request.Context(), u.ID)
if err != nil {
response.InternalError(c, "Failed to get active subscriptions: "+err.Error())
return
}
response.Success(c, subscriptions)
}
// GetProgress handles getting subscription progress for current user
// GET /api/v1/subscriptions/progress
func (h *SubscriptionHandler) GetProgress(c *gin.Context) {
user, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not found in context")
return
}
u, ok := user.(*model.User)
if !ok {
response.InternalError(c, "Invalid user in context")
return
}
// Get all active subscriptions with progress
subscriptions, err := h.subscriptionService.ListActiveUserSubscriptions(c.Request.Context(), u.ID)
if err != nil {
response.InternalError(c, "Failed to get subscriptions: "+err.Error())
return
}
result := make([]SubscriptionProgressInfo, 0, len(subscriptions))
for i := range subscriptions {
sub := &subscriptions[i]
progress, err := h.subscriptionService.GetSubscriptionProgress(c.Request.Context(), sub.ID)
if err != nil {
// Skip subscriptions with errors
continue
}
result = append(result, SubscriptionProgressInfo{
Subscription: sub,
Progress: progress,
})
}
response.Success(c, result)
}
// GetSummary handles getting a summary of current user's subscription status
// GET /api/v1/subscriptions/summary
func (h *SubscriptionHandler) GetSummary(c *gin.Context) {
user, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not found in context")
return
}
u, ok := user.(*model.User)
if !ok {
response.InternalError(c, "Invalid user in context")
return
}
// Get all active subscriptions
subscriptions, err := h.subscriptionService.ListActiveUserSubscriptions(c.Request.Context(), u.ID)
if err != nil {
response.InternalError(c, "Failed to get subscriptions: "+err.Error())
return
}
var totalUsed float64
items := make([]SubscriptionSummaryItem, 0, len(subscriptions))
for _, sub := range subscriptions {
item := SubscriptionSummaryItem{
ID: sub.ID,
GroupID: sub.GroupID,
Status: sub.Status,
DailyUsedUSD: sub.DailyUsageUSD,
WeeklyUsedUSD: sub.WeeklyUsageUSD,
MonthlyUsedUSD: sub.MonthlyUsageUSD,
}
// Add group info if preloaded
if sub.Group != nil {
item.GroupName = sub.Group.Name
if sub.Group.DailyLimitUSD != nil {
item.DailyLimitUSD = *sub.Group.DailyLimitUSD
}
if sub.Group.WeeklyLimitUSD != nil {
item.WeeklyLimitUSD = *sub.Group.WeeklyLimitUSD
}
if sub.Group.MonthlyLimitUSD != nil {
item.MonthlyLimitUSD = *sub.Group.MonthlyLimitUSD
}
}
// Format expiration time
if !sub.ExpiresAt.IsZero() {
formatted := sub.ExpiresAt.Format("2006-01-02T15:04:05Z07:00")
item.ExpiresAt = &formatted
}
// Track total usage (use monthly as the most comprehensive)
totalUsed += sub.MonthlyUsageUSD
items = append(items, item)
}
summary := struct {
ActiveCount int `json:"active_count"`
TotalUsedUSD float64 `json:"total_used_usd"`
Subscriptions []SubscriptionSummaryItem `json:"subscriptions"`
}{
ActiveCount: len(subscriptions),
TotalUsedUSD: totalUsed,
Subscriptions: items,
}
response.Success(c, summary)
}

View File

@@ -0,0 +1,396 @@
package handler
import (
"strconv"
"time"
"sub2api/internal/model"
"sub2api/internal/pkg/response"
"sub2api/internal/pkg/timezone"
"sub2api/internal/repository"
"sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// UsageHandler handles usage-related requests
type UsageHandler struct {
usageService *service.UsageService
usageRepo *repository.UsageLogRepository
apiKeyService *service.ApiKeyService
}
// NewUsageHandler creates a new UsageHandler
func NewUsageHandler(usageService *service.UsageService, usageRepo *repository.UsageLogRepository, apiKeyService *service.ApiKeyService) *UsageHandler {
return &UsageHandler{
usageService: usageService,
usageRepo: usageRepo,
apiKeyService: apiKeyService,
}
}
// List handles listing usage records with pagination
// GET /api/v1/usage
func (h *UsageHandler) List(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok {
response.InternalError(c, "Invalid user context")
return
}
page, pageSize := response.ParsePagination(c)
var apiKeyID int64
if apiKeyIDStr := c.Query("api_key_id"); apiKeyIDStr != "" {
id, err := strconv.ParseInt(apiKeyIDStr, 10, 64)
if err != nil {
response.BadRequest(c, "Invalid api_key_id")
return
}
// [Security Fix] Verify API Key ownership to prevent horizontal privilege escalation
apiKey, err := h.apiKeyService.GetByID(c.Request.Context(), id)
if err != nil {
response.NotFound(c, "API key not found")
return
}
if apiKey.UserID != user.ID {
response.Forbidden(c, "Not authorized to access this API key's usage records")
return
}
apiKeyID = id
}
params := repository.PaginationParams{Page: page, PageSize: pageSize}
var records []model.UsageLog
var result *repository.PaginationResult
var err error
if apiKeyID > 0 {
records, result, err = h.usageService.ListByApiKey(c.Request.Context(), apiKeyID, params)
} else {
records, result, err = h.usageService.ListByUser(c.Request.Context(), user.ID, params)
}
if err != nil {
response.InternalError(c, "Failed to list usage records: "+err.Error())
return
}
response.Paginated(c, records, result.Total, page, pageSize)
}
// GetByID handles getting a single usage record
// GET /api/v1/usage/:id
func (h *UsageHandler) GetByID(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok {
response.InternalError(c, "Invalid user context")
return
}
usageID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid usage ID")
return
}
record, err := h.usageService.GetByID(c.Request.Context(), usageID)
if err != nil {
response.NotFound(c, "Usage record not found")
return
}
// 验证所有权
if record.UserID != user.ID {
response.Forbidden(c, "Not authorized to access this record")
return
}
response.Success(c, record)
}
// Stats handles getting usage statistics
// GET /api/v1/usage/stats
func (h *UsageHandler) Stats(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok {
response.InternalError(c, "Invalid user context")
return
}
var apiKeyID int64
if apiKeyIDStr := c.Query("api_key_id"); apiKeyIDStr != "" {
id, err := strconv.ParseInt(apiKeyIDStr, 10, 64)
if err != nil {
response.BadRequest(c, "Invalid api_key_id")
return
}
// [Security Fix] Verify API Key ownership to prevent horizontal privilege escalation
apiKey, err := h.apiKeyService.GetByID(c.Request.Context(), id)
if err != nil {
response.NotFound(c, "API key not found")
return
}
if apiKey.UserID != user.ID {
response.Forbidden(c, "Not authorized to access this API key's statistics")
return
}
apiKeyID = id
}
// 获取时间范围参数
now := timezone.Now()
var startTime, endTime time.Time
// 优先使用 start_date 和 end_date 参数
startDateStr := c.Query("start_date")
endDateStr := c.Query("end_date")
if startDateStr != "" && endDateStr != "" {
// 使用自定义日期范围
var err error
startTime, err = timezone.ParseInLocation("2006-01-02", startDateStr)
if err != nil {
response.BadRequest(c, "Invalid start_date format, use YYYY-MM-DD")
return
}
endTime, err = timezone.ParseInLocation("2006-01-02", endDateStr)
if err != nil {
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
return
}
// 设置结束时间为当天结束
endTime = endTime.Add(24*time.Hour - time.Nanosecond)
} else {
// 使用 period 参数
period := c.DefaultQuery("period", "today")
switch period {
case "today":
startTime = timezone.StartOfDay(now)
case "week":
startTime = now.AddDate(0, 0, -7)
case "month":
startTime = now.AddDate(0, -1, 0)
default:
startTime = timezone.StartOfDay(now)
}
endTime = now
}
var stats *service.UsageStats
var err error
if apiKeyID > 0 {
stats, err = h.usageService.GetStatsByApiKey(c.Request.Context(), apiKeyID, startTime, endTime)
} else {
stats, err = h.usageService.GetStatsByUser(c.Request.Context(), user.ID, startTime, endTime)
}
if err != nil {
response.InternalError(c, "Failed to get usage statistics: "+err.Error())
return
}
response.Success(c, stats)
}
// parseUserTimeRange parses start_date, end_date query parameters for user dashboard
func parseUserTimeRange(c *gin.Context) (time.Time, time.Time) {
now := timezone.Now()
startDate := c.Query("start_date")
endDate := c.Query("end_date")
var startTime, endTime time.Time
if startDate != "" {
if t, err := timezone.ParseInLocation("2006-01-02", startDate); err == nil {
startTime = t
} else {
startTime = timezone.StartOfDay(now.AddDate(0, 0, -7))
}
} else {
startTime = timezone.StartOfDay(now.AddDate(0, 0, -7))
}
if endDate != "" {
if t, err := timezone.ParseInLocation("2006-01-02", endDate); err == nil {
endTime = t.Add(24 * time.Hour) // Include the end date
} else {
endTime = timezone.StartOfDay(now.AddDate(0, 0, 1))
}
} else {
endTime = timezone.StartOfDay(now.AddDate(0, 0, 1))
}
return startTime, endTime
}
// DashboardStats handles getting user dashboard statistics
// GET /api/v1/usage/dashboard/stats
func (h *UsageHandler) DashboardStats(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok {
response.InternalError(c, "Invalid user context")
return
}
stats, err := h.usageRepo.GetUserDashboardStats(c.Request.Context(), user.ID)
if err != nil {
response.InternalError(c, "Failed to get dashboard statistics")
return
}
response.Success(c, stats)
}
// DashboardTrend handles getting user usage trend data
// GET /api/v1/usage/dashboard/trend
func (h *UsageHandler) DashboardTrend(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok {
response.InternalError(c, "Invalid user context")
return
}
startTime, endTime := parseUserTimeRange(c)
granularity := c.DefaultQuery("granularity", "day")
trend, err := h.usageRepo.GetUserUsageTrendByUserID(c.Request.Context(), user.ID, startTime, endTime, granularity)
if err != nil {
response.InternalError(c, "Failed to get usage trend")
return
}
response.Success(c, gin.H{
"trend": trend,
"start_date": startTime.Format("2006-01-02"),
"end_date": endTime.Add(-24 * time.Hour).Format("2006-01-02"),
"granularity": granularity,
})
}
// DashboardModels handles getting user model usage statistics
// GET /api/v1/usage/dashboard/models
func (h *UsageHandler) DashboardModels(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok {
response.InternalError(c, "Invalid user context")
return
}
startTime, endTime := parseUserTimeRange(c)
stats, err := h.usageRepo.GetUserModelStats(c.Request.Context(), user.ID, startTime, endTime)
if err != nil {
response.InternalError(c, "Failed to get model statistics")
return
}
response.Success(c, gin.H{
"models": stats,
"start_date": startTime.Format("2006-01-02"),
"end_date": endTime.Add(-24 * time.Hour).Format("2006-01-02"),
})
}
// BatchApiKeysUsageRequest represents the request for batch API keys usage
type BatchApiKeysUsageRequest struct {
ApiKeyIDs []int64 `json:"api_key_ids" binding:"required"`
}
// DashboardApiKeysUsage handles getting usage stats for user's own API keys
// POST /api/v1/usage/dashboard/api-keys-usage
func (h *UsageHandler) DashboardApiKeysUsage(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok {
response.InternalError(c, "Invalid user context")
return
}
var req BatchApiKeysUsageRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if len(req.ApiKeyIDs) == 0 {
response.Success(c, gin.H{"stats": map[string]interface{}{}})
return
}
// Verify ownership of all requested API keys
userApiKeys, _, err := h.apiKeyService.List(c.Request.Context(), user.ID, repository.PaginationParams{Page: 1, PageSize: 1000})
if err != nil {
response.InternalError(c, "Failed to verify API key ownership")
return
}
userApiKeyIDs := make(map[int64]bool)
for _, key := range userApiKeys {
userApiKeyIDs[key.ID] = true
}
// Filter to only include user's own API keys
validApiKeyIDs := make([]int64, 0)
for _, id := range req.ApiKeyIDs {
if userApiKeyIDs[id] {
validApiKeyIDs = append(validApiKeyIDs, id)
}
}
if len(validApiKeyIDs) == 0 {
response.Success(c, gin.H{"stats": map[string]interface{}{}})
return
}
stats, err := h.usageRepo.GetBatchApiKeyUsageStats(c.Request.Context(), validApiKeyIDs)
if err != nil {
response.InternalError(c, "Failed to get API key usage stats")
return
}
response.Success(c, gin.H{"stats": stats})
}

View File

@@ -0,0 +1,85 @@
package handler
import (
"sub2api/internal/model"
"sub2api/internal/pkg/response"
"sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// UserHandler handles user-related requests
type UserHandler struct {
userService *service.UserService
}
// NewUserHandler creates a new UserHandler
func NewUserHandler(userService *service.UserService) *UserHandler {
return &UserHandler{
userService: userService,
}
}
// ChangePasswordRequest represents the change password request payload
type ChangePasswordRequest struct {
OldPassword string `json:"old_password" binding:"required"`
NewPassword string `json:"new_password" binding:"required,min=6"`
}
// GetProfile handles getting user profile
// GET /api/v1/users/me
func (h *UserHandler) GetProfile(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok {
response.InternalError(c, "Invalid user context")
return
}
userData, err := h.userService.GetByID(c.Request.Context(), user.ID)
if err != nil {
response.InternalError(c, "Failed to get user profile: "+err.Error())
return
}
response.Success(c, userData)
}
// ChangePassword handles changing user password
// POST /api/v1/users/me/password
func (h *UserHandler) ChangePassword(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok {
response.InternalError(c, "Invalid user context")
return
}
var req ChangePasswordRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
svcReq := service.ChangePasswordRequest{
CurrentPassword: req.OldPassword,
NewPassword: req.NewPassword,
}
err := h.userService.ChangePassword(c.Request.Context(), user.ID, svcReq)
if err != nil {
response.BadRequest(c, "Failed to change password: "+err.Error())
return
}
response.Success(c, gin.H{"message": "Password changed successfully"})
}

View File

@@ -0,0 +1,28 @@
package middleware
import (
"sub2api/internal/model"
"github.com/gin-gonic/gin"
)
// AdminOnly 管理员权限中间件
// 必须在JWTAuth中间件之后使用
func AdminOnly() gin.HandlerFunc {
return func(c *gin.Context) {
// 从上下文获取用户
user, exists := GetUserFromContext(c)
if !exists {
AbortWithError(c, 401, "UNAUTHORIZED", "User not found in context")
return
}
// 检查是否为管理员
if user.Role != model.RoleAdmin {
AbortWithError(c, 403, "FORBIDDEN", "Admin access required")
return
}
c.Next()
}
}

View File

@@ -0,0 +1,161 @@
package middleware
import (
"context"
"errors"
"log"
"strings"
"sub2api/internal/model"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
// ApiKeyAuthService 定义API Key认证服务需要的接口
type ApiKeyAuthService interface {
GetByKey(ctx context.Context, key string) (*model.ApiKey, error)
}
// SubscriptionAuthService 定义订阅认证服务需要的接口
type SubscriptionAuthService interface {
GetActiveSubscription(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error)
ValidateSubscription(ctx context.Context, sub *model.UserSubscription) error
CheckAndActivateWindow(ctx context.Context, sub *model.UserSubscription) error
CheckAndResetWindows(ctx context.Context, sub *model.UserSubscription) error
CheckUsageLimits(ctx context.Context, sub *model.UserSubscription, group *model.Group, additionalCost float64) error
}
// ApiKeyAuth API Key认证中间件
func ApiKeyAuth(apiKeyRepo ApiKeyAuthService) gin.HandlerFunc {
return ApiKeyAuthWithSubscription(apiKeyRepo, nil)
}
// ApiKeyAuthWithSubscription API Key认证中间件支持订阅验证
func ApiKeyAuthWithSubscription(apiKeyRepo ApiKeyAuthService, subscriptionService SubscriptionAuthService) gin.HandlerFunc {
return func(c *gin.Context) {
// 尝试从Authorization header中提取API key (Bearer scheme)
authHeader := c.GetHeader("Authorization")
var apiKeyString string
if authHeader != "" {
// 验证Bearer scheme
parts := strings.SplitN(authHeader, " ", 2)
if len(parts) == 2 && parts[0] == "Bearer" {
apiKeyString = parts[1]
}
}
// 如果Authorization header中没有尝试从x-api-key header中提取
if apiKeyString == "" {
apiKeyString = c.GetHeader("x-api-key")
}
// 如果两个header都没有API key
if apiKeyString == "" {
AbortWithError(c, 401, "API_KEY_REQUIRED", "API key is required in Authorization header (Bearer scheme) or x-api-key header")
return
}
// 从数据库验证API key
apiKey, err := apiKeyRepo.GetByKey(c.Request.Context(), apiKeyString)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
AbortWithError(c, 401, "INVALID_API_KEY", "Invalid API key")
return
}
AbortWithError(c, 500, "INTERNAL_ERROR", "Failed to validate API key")
return
}
// 检查API key是否激活
if !apiKey.IsActive() {
AbortWithError(c, 401, "API_KEY_DISABLED", "API key is disabled")
return
}
// 检查关联的用户
if apiKey.User == nil {
AbortWithError(c, 401, "USER_NOT_FOUND", "User associated with API key not found")
return
}
// 检查用户状态
if !apiKey.User.IsActive() {
AbortWithError(c, 401, "USER_INACTIVE", "User account is not active")
return
}
// 判断计费方式:订阅模式 vs 余额模式
isSubscriptionType := apiKey.Group != nil && apiKey.Group.IsSubscriptionType()
if isSubscriptionType && subscriptionService != nil {
// 订阅模式:验证订阅
subscription, err := subscriptionService.GetActiveSubscription(
c.Request.Context(),
apiKey.User.ID,
apiKey.Group.ID,
)
if err != nil {
AbortWithError(c, 403, "SUBSCRIPTION_NOT_FOUND", "No active subscription found for this group")
return
}
// 验证订阅状态(是否过期、暂停等)
if err := subscriptionService.ValidateSubscription(c.Request.Context(), subscription); err != nil {
AbortWithError(c, 403, "SUBSCRIPTION_INVALID", err.Error())
return
}
// 激活滑动窗口(首次使用时)
if err := subscriptionService.CheckAndActivateWindow(c.Request.Context(), subscription); err != nil {
log.Printf("Failed to activate subscription windows: %v", err)
}
// 检查并重置过期窗口
if err := subscriptionService.CheckAndResetWindows(c.Request.Context(), subscription); err != nil {
log.Printf("Failed to reset subscription windows: %v", err)
}
// 预检查用量限制使用0作为额外费用进行预检查
if err := subscriptionService.CheckUsageLimits(c.Request.Context(), subscription, apiKey.Group, 0); err != nil {
AbortWithError(c, 429, "USAGE_LIMIT_EXCEEDED", err.Error())
return
}
// 将订阅信息存入上下文
c.Set(string(ContextKeySubscription), subscription)
} else {
// 余额模式:检查用户余额
if apiKey.User.Balance <= 0 {
AbortWithError(c, 403, "INSUFFICIENT_BALANCE", "Insufficient account balance")
return
}
}
// 将API key和用户信息存入上下文
c.Set(string(ContextKeyApiKey), apiKey)
c.Set(string(ContextKeyUser), apiKey.User)
c.Next()
}
}
// GetApiKeyFromContext 从上下文中获取API key
func GetApiKeyFromContext(c *gin.Context) (*model.ApiKey, bool) {
value, exists := c.Get(string(ContextKeyApiKey))
if !exists {
return nil, false
}
apiKey, ok := value.(*model.ApiKey)
return apiKey, ok
}
// GetSubscriptionFromContext 从上下文中获取订阅信息
func GetSubscriptionFromContext(c *gin.Context) (*model.UserSubscription, bool) {
value, exists := c.Get(string(ContextKeySubscription))
if !exists {
return nil, false
}
subscription, ok := value.(*model.UserSubscription)
return subscription, ok
}

View File

@@ -0,0 +1,24 @@
package middleware
import (
"github.com/gin-gonic/gin"
)
// CORS 跨域中间件
func CORS() gin.HandlerFunc {
return func(c *gin.Context) {
// 设置允许跨域的响应头
c.Writer.Header().Set("Access-Control-Allow-Origin", "*")
c.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With, X-API-Key")
c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE, PATCH")
// 处理预检请求
if c.Request.Method == "OPTIONS" {
c.AbortWithStatus(204)
return
}
c.Next()
}
}

View File

@@ -0,0 +1,76 @@
package middleware
import (
"context"
"strings"
"sub2api/internal/model"
"sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// JWTAuth JWT认证中间件
func JWTAuth(authService *service.AuthService, userRepo interface {
GetByID(ctx context.Context, id int64) (*model.User, error)
}) gin.HandlerFunc {
return func(c *gin.Context) {
// 从Authorization header中提取token
authHeader := c.GetHeader("Authorization")
if authHeader == "" {
AbortWithError(c, 401, "UNAUTHORIZED", "Authorization header is required")
return
}
// 验证Bearer scheme
parts := strings.SplitN(authHeader, " ", 2)
if len(parts) != 2 || parts[0] != "Bearer" {
AbortWithError(c, 401, "INVALID_AUTH_HEADER", "Authorization header format must be 'Bearer {token}'")
return
}
tokenString := parts[1]
if tokenString == "" {
AbortWithError(c, 401, "EMPTY_TOKEN", "Token cannot be empty")
return
}
// 验证token
claims, err := authService.ValidateToken(tokenString)
if err != nil {
if err == service.ErrTokenExpired {
AbortWithError(c, 401, "TOKEN_EXPIRED", "Token has expired")
return
}
AbortWithError(c, 401, "INVALID_TOKEN", "Invalid token")
return
}
// 从数据库获取最新的用户信息
user, err := userRepo.GetByID(c.Request.Context(), claims.UserID)
if err != nil {
AbortWithError(c, 401, "USER_NOT_FOUND", "User not found")
return
}
// 检查用户状态
if !user.IsActive() {
AbortWithError(c, 401, "USER_INACTIVE", "User account is not active")
return
}
// 将用户信息存入上下文
c.Set(string(ContextKeyUser), user)
c.Next()
}
}
// GetUserFromContext 从上下文中获取用户
func GetUserFromContext(c *gin.Context) (*model.User, bool) {
value, exists := c.Get(string(ContextKeyUser))
if !exists {
return nil, false
}
user, ok := value.(*model.User)
return user, ok
}

View File

@@ -0,0 +1,52 @@
package middleware
import (
"log"
"time"
"github.com/gin-gonic/gin"
)
// Logger 请求日志中间件
func Logger() gin.HandlerFunc {
return func(c *gin.Context) {
// 开始时间
startTime := time.Now()
// 处理请求
c.Next()
// 结束时间
endTime := time.Now()
// 执行时间
latency := endTime.Sub(startTime)
// 请求方法
method := c.Request.Method
// 请求路径
path := c.Request.URL.Path
// 状态码
statusCode := c.Writer.Status()
// 客户端IP
clientIP := c.ClientIP()
// 日志格式: [时间] 状态码 | 延迟 | IP | 方法 路径
log.Printf("[GIN] %v | %3d | %13v | %15s | %-7s %s",
endTime.Format("2006/01/02 - 15:04:05"),
statusCode,
latency,
clientIP,
method,
path,
)
// 如果有错误,额外记录错误信息
if len(c.Errors) > 0 {
log.Printf("[GIN] Errors: %v", c.Errors.String())
}
}
}

View File

@@ -0,0 +1,35 @@
package middleware
import "github.com/gin-gonic/gin"
// ContextKey 定义上下文键类型
type ContextKey string
const (
// ContextKeyUser 用户上下文键
ContextKeyUser ContextKey = "user"
// ContextKeyApiKey API密钥上下文键
ContextKeyApiKey ContextKey = "api_key"
// ContextKeySubscription 订阅上下文键
ContextKeySubscription ContextKey = "subscription"
)
// ErrorResponse 标准错误响应结构
type ErrorResponse struct {
Code string `json:"code"`
Message string `json:"message"`
}
// NewErrorResponse 创建错误响应
func NewErrorResponse(code, message string) ErrorResponse {
return ErrorResponse{
Code: code,
Message: message,
}
}
// AbortWithError 中断请求并返回JSON错误
func AbortWithError(c *gin.Context, statusCode int, code, message string) {
c.JSON(statusCode, NewErrorResponse(code, message))
c.Abort()
}

View File

@@ -0,0 +1,265 @@
package model
import (
"database/sql/driver"
"encoding/json"
"errors"
"time"
"gorm.io/gorm"
)
// JSONB 用于存储JSONB数据
type JSONB map[string]interface{}
func (j JSONB) Value() (driver.Value, error) {
if j == nil {
return nil, nil
}
return json.Marshal(j)
}
func (j *JSONB) Scan(value interface{}) error {
if value == nil {
*j = nil
return nil
}
bytes, ok := value.([]byte)
if !ok {
return errors.New("type assertion to []byte failed")
}
return json.Unmarshal(bytes, j)
}
type Account struct {
ID int64 `gorm:"primaryKey" json:"id"`
Name string `gorm:"size:100;not null" json:"name"`
Platform string `gorm:"size:50;not null" json:"platform"` // anthropic/openai/gemini
Type string `gorm:"size:20;not null" json:"type"` // oauth/apikey
Credentials JSONB `gorm:"type:jsonb;default:'{}'" json:"credentials"` // 凭证(加密存储)
Extra JSONB `gorm:"type:jsonb;default:'{}'" json:"extra"` // 扩展信息
ProxyID *int64 `gorm:"index" json:"proxy_id"`
Concurrency int `gorm:"default:3;not null" json:"concurrency"`
Priority int `gorm:"default:50;not null" json:"priority"` // 1-100越小越高
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled/error
ErrorMessage string `gorm:"type:text" json:"error_message"`
LastUsedAt *time.Time `gorm:"index" json:"last_used_at"`
CreatedAt time.Time `gorm:"not null" json:"created_at"`
UpdatedAt time.Time `gorm:"not null" json:"updated_at"`
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
// 调度控制
Schedulable bool `gorm:"default:true;not null" json:"schedulable"`
// 限流状态 (429)
RateLimitedAt *time.Time `gorm:"index" json:"rate_limited_at"`
RateLimitResetAt *time.Time `gorm:"index" json:"rate_limit_reset_at"`
// 过载状态 (529)
OverloadUntil *time.Time `gorm:"index" json:"overload_until"`
// 5小时时间窗口
SessionWindowStart *time.Time `json:"session_window_start"`
SessionWindowEnd *time.Time `json:"session_window_end"`
SessionWindowStatus string `gorm:"size:20" json:"session_window_status"` // allowed/allowed_warning/rejected
// 关联
Proxy *Proxy `gorm:"foreignKey:ProxyID" json:"proxy,omitempty"`
AccountGroups []AccountGroup `gorm:"foreignKey:AccountID" json:"account_groups,omitempty"`
// 虚拟字段 (不存储到数据库)
GroupIDs []int64 `gorm:"-" json:"group_ids,omitempty"`
}
func (Account) TableName() string {
return "accounts"
}
// IsActive 检查是否激活
func (a *Account) IsActive() bool {
return a.Status == "active"
}
// IsSchedulable 检查账号是否可调度
func (a *Account) IsSchedulable() bool {
if !a.IsActive() || !a.Schedulable {
return false
}
now := time.Now()
if a.OverloadUntil != nil && now.Before(*a.OverloadUntil) {
return false
}
if a.RateLimitResetAt != nil && now.Before(*a.RateLimitResetAt) {
return false
}
return true
}
// IsRateLimited 检查是否处于限流状态
func (a *Account) IsRateLimited() bool {
if a.RateLimitResetAt == nil {
return false
}
return time.Now().Before(*a.RateLimitResetAt)
}
// IsOverloaded 检查是否处于过载状态
func (a *Account) IsOverloaded() bool {
if a.OverloadUntil == nil {
return false
}
return time.Now().Before(*a.OverloadUntil)
}
// IsOAuth 检查是否为OAuth类型账号包括oauth和setup-token
func (a *Account) IsOAuth() bool {
return a.Type == AccountTypeOAuth || a.Type == AccountTypeSetupToken
}
// CanGetUsage 检查账号是否可以获取usage信息只有oauth类型可以setup-token没有profile权限
func (a *Account) CanGetUsage() bool {
return a.Type == AccountTypeOAuth
}
// GetCredential 获取凭证字段
func (a *Account) GetCredential(key string) string {
if a.Credentials == nil {
return ""
}
if v, ok := a.Credentials[key]; ok {
if s, ok := v.(string); ok {
return s
}
}
return ""
}
// GetModelMapping 获取模型映射配置
// 返回格式: map[请求模型名]实际模型名
func (a *Account) GetModelMapping() map[string]string {
if a.Credentials == nil {
return nil
}
raw, ok := a.Credentials["model_mapping"]
if !ok || raw == nil {
return nil
}
// 处理map[string]interface{}类型
if m, ok := raw.(map[string]interface{}); ok {
result := make(map[string]string)
for k, v := range m {
if s, ok := v.(string); ok {
result[k] = s
}
}
if len(result) > 0 {
return result
}
}
return nil
}
// IsModelSupported 检查请求的模型是否被该账号支持
// 如果没有设置模型映射,则支持所有模型
func (a *Account) IsModelSupported(requestedModel string) bool {
mapping := a.GetModelMapping()
if mapping == nil || len(mapping) == 0 {
return true // 没有映射配置,支持所有模型
}
_, exists := mapping[requestedModel]
return exists
}
// GetMappedModel 获取映射后的实际模型名
// 如果没有映射,返回原始模型名
func (a *Account) GetMappedModel(requestedModel string) string {
mapping := a.GetModelMapping()
if mapping == nil || len(mapping) == 0 {
return requestedModel
}
if mappedModel, exists := mapping[requestedModel]; exists {
return mappedModel
}
return requestedModel
}
// GetBaseURL 获取API基础URL用于apikey类型账号
func (a *Account) GetBaseURL() string {
if a.Type != AccountTypeApiKey {
return ""
}
baseURL := a.GetCredential("base_url")
if baseURL == "" {
return "https://api.anthropic.com" // 默认URL
}
return baseURL
}
// GetExtraString 从Extra字段获取字符串值
func (a *Account) GetExtraString(key string) string {
if a.Extra == nil {
return ""
}
if v, ok := a.Extra[key]; ok {
if s, ok := v.(string); ok {
return s
}
}
return ""
}
// IsCustomErrorCodesEnabled 检查是否启用自定义错误码功能(仅适用于 apikey 类型)
func (a *Account) IsCustomErrorCodesEnabled() bool {
if a.Type != AccountTypeApiKey || a.Credentials == nil {
return false
}
if v, ok := a.Credentials["custom_error_codes_enabled"]; ok {
if enabled, ok := v.(bool); ok {
return enabled
}
}
return false
}
// GetCustomErrorCodes 获取自定义错误码列表
func (a *Account) GetCustomErrorCodes() []int {
if a.Credentials == nil {
return nil
}
raw, ok := a.Credentials["custom_error_codes"]
if !ok || raw == nil {
return nil
}
// 处理 []interface{} 类型JSON反序列化后的格式
if arr, ok := raw.([]interface{}); ok {
result := make([]int, 0, len(arr))
for _, v := range arr {
// JSON 数字默认解析为 float64
if f, ok := v.(float64); ok {
result = append(result, int(f))
}
}
return result
}
return nil
}
// ShouldHandleErrorCode 检查指定错误码是否应该被处理(停止调度/标记限流等)
// 如果未启用自定义错误码或列表为空,返回 true使用默认策略
// 如果启用且列表非空,只有在列表中的错误码才返回 true
func (a *Account) ShouldHandleErrorCode(statusCode int) bool {
if !a.IsCustomErrorCodesEnabled() {
return true // 未启用,使用默认策略
}
codes := a.GetCustomErrorCodes()
if len(codes) == 0 {
return true // 启用但列表为空fallback到默认策略
}
// 检查是否在自定义列表中
for _, code := range codes {
if code == statusCode {
return true
}
}
return false
}

View File

@@ -0,0 +1,20 @@
package model
import (
"time"
)
type AccountGroup struct {
AccountID int64 `gorm:"primaryKey" json:"account_id"`
GroupID int64 `gorm:"primaryKey" json:"group_id"`
Priority int `gorm:"default:50;not null" json:"priority"` // 分组内优先级
CreatedAt time.Time `gorm:"not null" json:"created_at"`
// 关联
Account *Account `gorm:"foreignKey:AccountID" json:"account,omitempty"`
Group *Group `gorm:"foreignKey:GroupID" json:"group,omitempty"`
}
func (AccountGroup) TableName() string {
return "account_groups"
}

View File

@@ -0,0 +1,32 @@
package model
import (
"time"
"gorm.io/gorm"
)
type ApiKey struct {
ID int64 `gorm:"primaryKey" json:"id"`
UserID int64 `gorm:"index;not null" json:"user_id"`
Key string `gorm:"uniqueIndex;size:128;not null" json:"key"` // sk-xxx
Name string `gorm:"size:100;not null" json:"name"`
GroupID *int64 `gorm:"index" json:"group_id"`
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled
CreatedAt time.Time `gorm:"not null" json:"created_at"`
UpdatedAt time.Time `gorm:"not null" json:"updated_at"`
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
// 关联
User *User `gorm:"foreignKey:UserID" json:"user,omitempty"`
Group *Group `gorm:"foreignKey:GroupID" json:"group,omitempty"`
}
func (ApiKey) TableName() string {
return "api_keys"
}
// IsActive 检查是否激活
func (k *ApiKey) IsActive() bool {
return k.Status == "active"
}

View File

@@ -0,0 +1,73 @@
package model
import (
"time"
"gorm.io/gorm"
)
// 订阅类型常量
const (
SubscriptionTypeStandard = "standard" // 标准计费模式(按余额扣费)
SubscriptionTypeSubscription = "subscription" // 订阅模式(按限额控制)
)
type Group struct {
ID int64 `gorm:"primaryKey" json:"id"`
Name string `gorm:"uniqueIndex;size:100;not null" json:"name"`
Description string `gorm:"type:text" json:"description"`
Platform string `gorm:"size:50;default:anthropic;not null" json:"platform"` // anthropic/openai/gemini
RateMultiplier float64 `gorm:"type:decimal(10,4);default:1.0;not null" json:"rate_multiplier"`
IsExclusive bool `gorm:"default:false;not null" json:"is_exclusive"`
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled
// 订阅功能字段
SubscriptionType string `gorm:"size:20;default:standard;not null" json:"subscription_type"` // standard/subscription
DailyLimitUSD *float64 `gorm:"type:decimal(20,8)" json:"daily_limit_usd"`
WeeklyLimitUSD *float64 `gorm:"type:decimal(20,8)" json:"weekly_limit_usd"`
MonthlyLimitUSD *float64 `gorm:"type:decimal(20,8)" json:"monthly_limit_usd"`
CreatedAt time.Time `gorm:"not null" json:"created_at"`
UpdatedAt time.Time `gorm:"not null" json:"updated_at"`
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
// 关联
AccountGroups []AccountGroup `gorm:"foreignKey:GroupID" json:"account_groups,omitempty"`
// 虚拟字段 (不存储到数据库)
AccountCount int64 `gorm:"-" json:"account_count,omitempty"`
}
func (Group) TableName() string {
return "groups"
}
// IsActive 检查是否激活
func (g *Group) IsActive() bool {
return g.Status == "active"
}
// IsSubscriptionType 检查是否为订阅类型分组
func (g *Group) IsSubscriptionType() bool {
return g.SubscriptionType == SubscriptionTypeSubscription
}
// IsFreeSubscription 检查是否为免费订阅(不扣余额但有限额)
func (g *Group) IsFreeSubscription() bool {
return g.IsSubscriptionType() && g.RateMultiplier == 0
}
// HasDailyLimit 检查是否有日限额
func (g *Group) HasDailyLimit() bool {
return g.DailyLimitUSD != nil && *g.DailyLimitUSD > 0
}
// HasWeeklyLimit 检查是否有周限额
func (g *Group) HasWeeklyLimit() bool {
return g.WeeklyLimitUSD != nil && *g.WeeklyLimitUSD > 0
}
// HasMonthlyLimit 检查是否有月限额
func (g *Group) HasMonthlyLimit() bool {
return g.MonthlyLimitUSD != nil && *g.MonthlyLimitUSD > 0
}

View File

@@ -0,0 +1,64 @@
package model
import (
"gorm.io/gorm"
)
// AutoMigrate 自动迁移所有模型
func AutoMigrate(db *gorm.DB) error {
return db.AutoMigrate(
&User{},
&ApiKey{},
&Group{},
&Account{},
&AccountGroup{},
&Proxy{},
&RedeemCode{},
&UsageLog{},
&Setting{},
&UserSubscription{},
)
}
// 状态常量
const (
StatusActive = "active"
StatusDisabled = "disabled"
StatusError = "error"
StatusUnused = "unused"
StatusUsed = "used"
StatusExpired = "expired"
)
// 角色常量
const (
RoleAdmin = "admin"
RoleUser = "user"
)
// 平台常量
const (
PlatformAnthropic = "anthropic"
PlatformOpenAI = "openai"
PlatformGemini = "gemini"
)
// 账号类型常量
const (
AccountTypeOAuth = "oauth" // OAuth类型账号full scope: profile + inference
AccountTypeSetupToken = "setup-token" // Setup Token类型账号inference only scope
AccountTypeApiKey = "apikey" // API Key类型账号
)
// 卡密类型常量
const (
RedeemTypeBalance = "balance"
RedeemTypeConcurrency = "concurrency"
RedeemTypeSubscription = "subscription"
)
// 管理员调整类型常量
const (
AdjustmentTypeAdminBalance = "admin_balance" // 管理员调整余额
AdjustmentTypeAdminConcurrency = "admin_concurrency" // 管理员调整并发数
)

View File

@@ -0,0 +1,45 @@
package model
import (
"fmt"
"time"
"gorm.io/gorm"
)
type Proxy struct {
ID int64 `gorm:"primaryKey" json:"id"`
Name string `gorm:"size:100;not null" json:"name"`
Protocol string `gorm:"size:20;not null" json:"protocol"` // http/https/socks5
Host string `gorm:"size:255;not null" json:"host"`
Port int `gorm:"not null" json:"port"`
Username string `gorm:"size:100" json:"username"`
Password string `gorm:"size:100" json:"-"`
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled
CreatedAt time.Time `gorm:"not null" json:"created_at"`
UpdatedAt time.Time `gorm:"not null" json:"updated_at"`
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
}
func (Proxy) TableName() string {
return "proxies"
}
// IsActive 检查是否激活
func (p *Proxy) IsActive() bool {
return p.Status == "active"
}
// URL 返回代理URL
func (p *Proxy) URL() string {
if p.Username != "" && p.Password != "" {
return fmt.Sprintf("%s://%s:%s@%s:%d", p.Protocol, p.Username, p.Password, p.Host, p.Port)
}
return fmt.Sprintf("%s://%s:%d", p.Protocol, p.Host, p.Port)
}
// ProxyWithAccountCount extends Proxy with account count information
type ProxyWithAccountCount struct {
Proxy
AccountCount int64 `json:"account_count"`
}

View File

@@ -0,0 +1,47 @@
package model
import (
"crypto/rand"
"encoding/hex"
"time"
)
type RedeemCode struct {
ID int64 `gorm:"primaryKey" json:"id"`
Code string `gorm:"uniqueIndex;size:32;not null" json:"code"`
Type string `gorm:"size:20;default:balance;not null" json:"type"` // balance/concurrency/subscription
Value float64 `gorm:"type:decimal(20,8);not null" json:"value"` // 面值(USD)或并发数或有效天数
Status string `gorm:"size:20;default:unused;not null" json:"status"` // unused/used
UsedBy *int64 `gorm:"index" json:"used_by"`
UsedAt *time.Time `json:"used_at"`
CreatedAt time.Time `gorm:"not null" json:"created_at"`
// 订阅类型专用字段
GroupID *int64 `gorm:"index" json:"group_id"` // 订阅分组ID (仅subscription类型使用)
ValidityDays int `gorm:"default:30" json:"validity_days"` // 订阅有效天数 (仅subscription类型使用)
// 关联
User *User `gorm:"foreignKey:UsedBy" json:"user,omitempty"`
Group *Group `gorm:"foreignKey:GroupID" json:"group,omitempty"`
}
func (RedeemCode) TableName() string {
return "redeem_codes"
}
// IsUsed 检查是否已使用
func (r *RedeemCode) IsUsed() bool {
return r.Status == "used"
}
// CanUse 检查是否可以使用
func (r *RedeemCode) CanUse() bool {
return r.Status == "unused"
}
// GenerateRedeemCode 生成唯一的兑换码
func GenerateRedeemCode() string {
b := make([]byte, 16)
rand.Read(b)
return hex.EncodeToString(b)
}

View File

@@ -0,0 +1,95 @@
package model
import (
"time"
)
// Setting 系统设置模型Key-Value存储
type Setting struct {
ID int64 `gorm:"primaryKey" json:"id"`
Key string `gorm:"uniqueIndex;size:100;not null" json:"key"`
Value string `gorm:"type:text;not null" json:"value"`
UpdatedAt time.Time `gorm:"not null" json:"updated_at"`
}
func (Setting) TableName() string {
return "settings"
}
// 设置Key常量
const (
// 注册设置
SettingKeyRegistrationEnabled = "registration_enabled" // 是否开放注册
SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证
// 邮件服务设置
SettingKeySmtpHost = "smtp_host" // SMTP服务器地址
SettingKeySmtpPort = "smtp_port" // SMTP端口
SettingKeySmtpUsername = "smtp_username" // SMTP用户名
SettingKeySmtpPassword = "smtp_password" // SMTP密码加密存储
SettingKeySmtpFrom = "smtp_from" // 发件人地址
SettingKeySmtpFromName = "smtp_from_name" // 发件人名称
SettingKeySmtpUseTLS = "smtp_use_tls" // 是否使用TLS
// Cloudflare Turnstile 设置
SettingKeyTurnstileEnabled = "turnstile_enabled" // 是否启用 Turnstile 验证
SettingKeyTurnstileSiteKey = "turnstile_site_key" // Turnstile Site Key
SettingKeyTurnstileSecretKey = "turnstile_secret_key" // Turnstile Secret Key
// OEM设置
SettingKeySiteName = "site_name" // 网站名称
SettingKeySiteLogo = "site_logo" // 网站Logo (base64)
SettingKeySiteSubtitle = "site_subtitle" // 网站副标题
SettingKeyApiBaseUrl = "api_base_url" // API端点地址用于客户端配置和导入
SettingKeyContactInfo = "contact_info" // 客服联系方式
// 默认配置
SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量
SettingKeyDefaultBalance = "default_balance" // 新用户默认余额
)
// SystemSettings 系统设置结构体用于API响应
type SystemSettings struct {
// 注册设置
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
// 邮件服务设置
SmtpHost string `json:"smtp_host"`
SmtpPort int `json:"smtp_port"`
SmtpUsername string `json:"smtp_username"`
SmtpPassword string `json:"smtp_password,omitempty"` // 不返回明文密码
SmtpFrom string `json:"smtp_from_email"`
SmtpFromName string `json:"smtp_from_name"`
SmtpUseTLS bool `json:"smtp_use_tls"`
// Cloudflare Turnstile 设置
TurnstileEnabled bool `json:"turnstile_enabled"`
TurnstileSiteKey string `json:"turnstile_site_key"`
TurnstileSecretKey string `json:"turnstile_secret_key,omitempty"` // 不返回明文密钥
// OEM设置
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
ApiBaseUrl string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
// 默认配置
DefaultConcurrency int `json:"default_concurrency"`
DefaultBalance float64 `json:"default_balance"`
}
// PublicSettings 公开设置(无需登录即可获取)
type PublicSettings struct {
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
TurnstileEnabled bool `json:"turnstile_enabled"`
TurnstileSiteKey string `json:"turnstile_site_key"`
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
ApiBaseUrl string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
Version string `json:"version"`
}

View File

@@ -0,0 +1,67 @@
package model
import (
"time"
)
// 消费类型常量
const (
BillingTypeBalance int8 = 0 // 钱包余额
BillingTypeSubscription int8 = 1 // 订阅套餐
)
type UsageLog struct {
ID int64 `gorm:"primaryKey" json:"id"`
UserID int64 `gorm:"index;not null" json:"user_id"`
ApiKeyID int64 `gorm:"index;not null" json:"api_key_id"`
AccountID int64 `gorm:"index;not null" json:"account_id"`
RequestID string `gorm:"size:64" json:"request_id"`
Model string `gorm:"size:100;index;not null" json:"model"`
// 订阅关联(可选)
GroupID *int64 `gorm:"index" json:"group_id"`
SubscriptionID *int64 `gorm:"index" json:"subscription_id"`
// Token使用量4类
InputTokens int `gorm:"default:0;not null" json:"input_tokens"`
OutputTokens int `gorm:"default:0;not null" json:"output_tokens"`
CacheCreationTokens int `gorm:"default:0;not null" json:"cache_creation_tokens"`
CacheReadTokens int `gorm:"default:0;not null" json:"cache_read_tokens"`
// 详细的缓存创建分类
CacheCreation5mTokens int `gorm:"default:0;not null" json:"cache_creation_5m_tokens"`
CacheCreation1hTokens int `gorm:"default:0;not null" json:"cache_creation_1h_tokens"`
// 费用USD
InputCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"input_cost"`
OutputCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"output_cost"`
CacheCreationCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"cache_creation_cost"`
CacheReadCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"cache_read_cost"`
TotalCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"total_cost"` // 原始总费用
ActualCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"actual_cost"` // 实际扣除费用
RateMultiplier float64 `gorm:"type:decimal(10,4);default:1;not null" json:"rate_multiplier"` // 计费倍率
// 元数据
BillingType int8 `gorm:"type:smallint;default:0;not null" json:"billing_type"` // 0=余额 1=订阅
Stream bool `gorm:"default:false;not null" json:"stream"`
DurationMs *int `json:"duration_ms"`
FirstTokenMs *int `json:"first_token_ms"` // 首字时间(流式请求)
CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
// 关联
User *User `gorm:"foreignKey:UserID" json:"user,omitempty"`
ApiKey *ApiKey `gorm:"foreignKey:ApiKeyID" json:"api_key,omitempty"`
Account *Account `gorm:"foreignKey:AccountID" json:"account,omitempty"`
Group *Group `gorm:"foreignKey:GroupID" json:"group,omitempty"`
Subscription *UserSubscription `gorm:"foreignKey:SubscriptionID" json:"subscription,omitempty"`
}
func (UsageLog) TableName() string {
return "usage_logs"
}
// TotalTokens 总token数
func (u *UsageLog) TotalTokens() int {
return u.InputTokens + u.OutputTokens + u.CacheCreationTokens + u.CacheReadTokens
}

View File

@@ -0,0 +1,74 @@
package model
import (
"time"
"github.com/lib/pq"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
)
type User struct {
ID int64 `gorm:"primaryKey" json:"id"`
Email string `gorm:"uniqueIndex;size:255;not null" json:"email"`
PasswordHash string `gorm:"size:255;not null" json:"-"`
Role string `gorm:"size:20;default:user;not null" json:"role"` // admin/user
Balance float64 `gorm:"type:decimal(20,8);default:0;not null" json:"balance"`
Concurrency int `gorm:"default:5;not null" json:"concurrency"`
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled
AllowedGroups pq.Int64Array `gorm:"type:bigint[]" json:"allowed_groups"`
CreatedAt time.Time `gorm:"not null" json:"created_at"`
UpdatedAt time.Time `gorm:"not null" json:"updated_at"`
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
// 关联
ApiKeys []ApiKey `gorm:"foreignKey:UserID" json:"api_keys,omitempty"`
}
func (User) TableName() string {
return "users"
}
// IsAdmin 检查是否管理员
func (u *User) IsAdmin() bool {
return u.Role == "admin"
}
// IsActive 检查是否激活
func (u *User) IsActive() bool {
return u.Status == "active"
}
// CanBindGroup 检查是否可以绑定指定分组
// 对于标准类型分组:
// - 如果 AllowedGroups 设置了值(非空数组),只能绑定列表中的分组
// - 如果 AllowedGroups 为 nil 或空数组,可以绑定所有非专属分组
func (u *User) CanBindGroup(groupID int64, isExclusive bool) bool {
// 如果设置了 allowed_groups 且不为空,只能绑定指定的分组
if len(u.AllowedGroups) > 0 {
for _, id := range u.AllowedGroups {
if id == groupID {
return true
}
}
return false
}
// 如果没有设置 allowed_groups 或为空数组,可以绑定所有非专属分组
return !isExclusive
}
// SetPassword 设置密码(哈希存储)
func (u *User) SetPassword(password string) error {
hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return err
}
u.PasswordHash = string(hash)
return nil
}
// CheckPassword 验证密码
func (u *User) CheckPassword(password string) bool {
err := bcrypt.CompareHashAndPassword([]byte(u.PasswordHash), []byte(password))
return err == nil
}

View File

@@ -0,0 +1,157 @@
package model
import (
"time"
)
// 订阅状态常量
const (
SubscriptionStatusActive = "active"
SubscriptionStatusExpired = "expired"
SubscriptionStatusSuspended = "suspended"
)
// UserSubscription 用户订阅模型
type UserSubscription struct {
ID int64 `gorm:"primaryKey" json:"id"`
UserID int64 `gorm:"index;not null" json:"user_id"`
GroupID int64 `gorm:"index;not null" json:"group_id"`
// 订阅有效期
StartsAt time.Time `gorm:"not null" json:"starts_at"`
ExpiresAt time.Time `gorm:"not null" json:"expires_at"`
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/expired/suspended
// 滑动窗口起始时间nil = 未激活)
DailyWindowStart *time.Time `json:"daily_window_start"`
WeeklyWindowStart *time.Time `json:"weekly_window_start"`
MonthlyWindowStart *time.Time `json:"monthly_window_start"`
// 当前窗口已用额度USD基于 total_cost 计算)
DailyUsageUSD float64 `gorm:"type:decimal(20,10);default:0;not null" json:"daily_usage_usd"`
WeeklyUsageUSD float64 `gorm:"type:decimal(20,10);default:0;not null" json:"weekly_usage_usd"`
MonthlyUsageUSD float64 `gorm:"type:decimal(20,10);default:0;not null" json:"monthly_usage_usd"`
// 管理员分配信息
AssignedBy *int64 `gorm:"index" json:"assigned_by"`
AssignedAt time.Time `gorm:"not null" json:"assigned_at"`
Notes string `gorm:"type:text" json:"notes"`
CreatedAt time.Time `gorm:"not null" json:"created_at"`
UpdatedAt time.Time `gorm:"not null" json:"updated_at"`
// 关联
User *User `gorm:"foreignKey:UserID" json:"user,omitempty"`
Group *Group `gorm:"foreignKey:GroupID" json:"group,omitempty"`
AssignedByUser *User `gorm:"foreignKey:AssignedBy" json:"assigned_by_user,omitempty"`
}
func (UserSubscription) TableName() string {
return "user_subscriptions"
}
// IsActive 检查订阅是否有效状态为active且未过期
func (s *UserSubscription) IsActive() bool {
return s.Status == SubscriptionStatusActive && time.Now().Before(s.ExpiresAt)
}
// IsExpired 检查订阅是否已过期
func (s *UserSubscription) IsExpired() bool {
return time.Now().After(s.ExpiresAt)
}
// DaysRemaining 返回订阅剩余天数
func (s *UserSubscription) DaysRemaining() int {
if s.IsExpired() {
return 0
}
return int(time.Until(s.ExpiresAt).Hours() / 24)
}
// IsWindowActivated 检查窗口是否已激活
func (s *UserSubscription) IsWindowActivated() bool {
return s.DailyWindowStart != nil || s.WeeklyWindowStart != nil || s.MonthlyWindowStart != nil
}
// NeedsDailyReset 检查日窗口是否需要重置
func (s *UserSubscription) NeedsDailyReset() bool {
if s.DailyWindowStart == nil {
return false
}
return time.Since(*s.DailyWindowStart) >= 24*time.Hour
}
// NeedsWeeklyReset 检查周窗口是否需要重置
func (s *UserSubscription) NeedsWeeklyReset() bool {
if s.WeeklyWindowStart == nil {
return false
}
return time.Since(*s.WeeklyWindowStart) >= 7*24*time.Hour
}
// NeedsMonthlyReset 检查月窗口是否需要重置
func (s *UserSubscription) NeedsMonthlyReset() bool {
if s.MonthlyWindowStart == nil {
return false
}
return time.Since(*s.MonthlyWindowStart) >= 30*24*time.Hour
}
// DailyResetTime 返回日窗口重置时间
func (s *UserSubscription) DailyResetTime() *time.Time {
if s.DailyWindowStart == nil {
return nil
}
t := s.DailyWindowStart.Add(24 * time.Hour)
return &t
}
// WeeklyResetTime 返回周窗口重置时间
func (s *UserSubscription) WeeklyResetTime() *time.Time {
if s.WeeklyWindowStart == nil {
return nil
}
t := s.WeeklyWindowStart.Add(7 * 24 * time.Hour)
return &t
}
// MonthlyResetTime 返回月窗口重置时间
func (s *UserSubscription) MonthlyResetTime() *time.Time {
if s.MonthlyWindowStart == nil {
return nil
}
t := s.MonthlyWindowStart.Add(30 * 24 * time.Hour)
return &t
}
// CheckDailyLimit 检查是否超出日限额
func (s *UserSubscription) CheckDailyLimit(group *Group, additionalCost float64) bool {
if !group.HasDailyLimit() {
return true // 无限制
}
return s.DailyUsageUSD+additionalCost <= *group.DailyLimitUSD
}
// CheckWeeklyLimit 检查是否超出周限额
func (s *UserSubscription) CheckWeeklyLimit(group *Group, additionalCost float64) bool {
if !group.HasWeeklyLimit() {
return true // 无限制
}
return s.WeeklyUsageUSD+additionalCost <= *group.WeeklyLimitUSD
}
// CheckMonthlyLimit 检查是否超出月限额
func (s *UserSubscription) CheckMonthlyLimit(group *Group, additionalCost float64) bool {
if !group.HasMonthlyLimit() {
return true // 无限制
}
return s.MonthlyUsageUSD+additionalCost <= *group.MonthlyLimitUSD
}
// CheckAllLimits 检查所有限额
func (s *UserSubscription) CheckAllLimits(group *Group, additionalCost float64) (daily, weekly, monthly bool) {
daily = s.CheckDailyLimit(group, additionalCost)
weekly = s.CheckWeeklyLimit(group, additionalCost)
monthly = s.CheckMonthlyLimit(group, additionalCost)
return
}

View File

@@ -0,0 +1,223 @@
package oauth
import (
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"fmt"
"net/url"
"strings"
"sync"
"time"
)
// Claude OAuth Constants (from CRS project)
const (
// OAuth Client ID for Claude
ClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
// OAuth endpoints
AuthorizeURL = "https://claude.ai/oauth/authorize"
TokenURL = "https://console.anthropic.com/v1/oauth/token"
RedirectURI = "https://console.anthropic.com/oauth/code/callback"
// Scopes
ScopeProfile = "user:profile"
ScopeInference = "user:inference"
// Session TTL
SessionTTL = 30 * time.Minute
)
// OAuthSession stores OAuth flow state
type OAuthSession struct {
State string `json:"state"`
CodeVerifier string `json:"code_verifier"`
Scope string `json:"scope"`
ProxyURL string `json:"proxy_url,omitempty"`
CreatedAt time.Time `json:"created_at"`
}
// SessionStore manages OAuth sessions in memory
type SessionStore struct {
mu sync.RWMutex
sessions map[string]*OAuthSession
}
// NewSessionStore creates a new session store
func NewSessionStore() *SessionStore {
store := &SessionStore{
sessions: make(map[string]*OAuthSession),
}
// Start cleanup goroutine
go store.cleanup()
return store
}
// Set stores a session
func (s *SessionStore) Set(sessionID string, session *OAuthSession) {
s.mu.Lock()
defer s.mu.Unlock()
s.sessions[sessionID] = session
}
// Get retrieves a session
func (s *SessionStore) Get(sessionID string) (*OAuthSession, bool) {
s.mu.RLock()
defer s.mu.RUnlock()
session, ok := s.sessions[sessionID]
if !ok {
return nil, false
}
// Check if expired
if time.Since(session.CreatedAt) > SessionTTL {
return nil, false
}
return session, true
}
// Delete removes a session
func (s *SessionStore) Delete(sessionID string) {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.sessions, sessionID)
}
// cleanup removes expired sessions periodically
func (s *SessionStore) cleanup() {
ticker := time.NewTicker(5 * time.Minute)
for range ticker.C {
s.mu.Lock()
for id, session := range s.sessions {
if time.Since(session.CreatedAt) > SessionTTL {
delete(s.sessions, id)
}
}
s.mu.Unlock()
}
}
// GenerateRandomBytes generates cryptographically secure random bytes
func GenerateRandomBytes(n int) ([]byte, error) {
b := make([]byte, n)
_, err := rand.Read(b)
if err != nil {
return nil, err
}
return b, nil
}
// GenerateState generates a random state string for OAuth
func GenerateState() (string, error) {
bytes, err := GenerateRandomBytes(32)
if err != nil {
return "", err
}
return hex.EncodeToString(bytes), nil
}
// GenerateSessionID generates a unique session ID
func GenerateSessionID() (string, error) {
bytes, err := GenerateRandomBytes(16)
if err != nil {
return "", err
}
return hex.EncodeToString(bytes), nil
}
// GenerateCodeVerifier generates a PKCE code verifier (32 bytes -> base64url)
func GenerateCodeVerifier() (string, error) {
bytes, err := GenerateRandomBytes(32)
if err != nil {
return "", err
}
return base64URLEncode(bytes), nil
}
// GenerateCodeChallenge generates a PKCE code challenge using S256 method
func GenerateCodeChallenge(verifier string) string {
hash := sha256.Sum256([]byte(verifier))
return base64URLEncode(hash[:])
}
// base64URLEncode encodes bytes to base64url without padding
func base64URLEncode(data []byte) string {
encoded := base64.URLEncoding.EncodeToString(data)
// Remove padding
return strings.TrimRight(encoded, "=")
}
// BuildAuthorizationURL builds the OAuth authorization URL
func BuildAuthorizationURL(state, codeChallenge, scope string) string {
params := url.Values{}
params.Set("response_type", "code")
params.Set("client_id", ClientID)
params.Set("redirect_uri", RedirectURI)
params.Set("scope", scope)
params.Set("state", state)
params.Set("code_challenge", codeChallenge)
params.Set("code_challenge_method", "S256")
return fmt.Sprintf("%s?%s", AuthorizeURL, params.Encode())
}
// TokenRequest represents the token exchange request body
type TokenRequest struct {
GrantType string `json:"grant_type"`
ClientID string `json:"client_id"`
Code string `json:"code"`
RedirectURI string `json:"redirect_uri"`
CodeVerifier string `json:"code_verifier"`
State string `json:"state"`
}
// TokenResponse represents the token response from OAuth provider
type TokenResponse struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
ExpiresIn int64 `json:"expires_in"`
RefreshToken string `json:"refresh_token,omitempty"`
Scope string `json:"scope,omitempty"`
// Organization and Account info from OAuth response
Organization *OrgInfo `json:"organization,omitempty"`
Account *AccountInfo `json:"account,omitempty"`
}
// OrgInfo represents organization info from OAuth response
type OrgInfo struct {
UUID string `json:"uuid"`
}
// AccountInfo represents account info from OAuth response
type AccountInfo struct {
UUID string `json:"uuid"`
}
// RefreshTokenRequest represents the refresh token request
type RefreshTokenRequest struct {
GrantType string `json:"grant_type"`
RefreshToken string `json:"refresh_token"`
ClientID string `json:"client_id"`
}
// BuildTokenRequest creates a token exchange request
func BuildTokenRequest(code, codeVerifier, state string) *TokenRequest {
return &TokenRequest{
GrantType: "authorization_code",
ClientID: ClientID,
Code: code,
RedirectURI: RedirectURI,
CodeVerifier: codeVerifier,
State: state,
}
}
// BuildRefreshTokenRequest creates a refresh token request
func BuildRefreshTokenRequest(refreshToken string) *RefreshTokenRequest {
return &RefreshTokenRequest{
GrantType: "refresh_token",
RefreshToken: refreshToken,
ClientID: ClientID,
}
}

View File

@@ -0,0 +1,157 @@
package response
import (
"math"
"net/http"
"github.com/gin-gonic/gin"
)
// Response 标准API响应格式
type Response struct {
Code int `json:"code"`
Message string `json:"message"`
Data interface{} `json:"data,omitempty"`
}
// PaginatedData 分页数据格式(匹配前端期望)
type PaginatedData struct {
Items interface{} `json:"items"`
Total int64 `json:"total"`
Page int `json:"page"`
PageSize int `json:"page_size"`
Pages int `json:"pages"`
}
// Success 返回成功响应
func Success(c *gin.Context, data interface{}) {
c.JSON(http.StatusOK, Response{
Code: 0,
Message: "success",
Data: data,
})
}
// Created 返回创建成功响应
func Created(c *gin.Context, data interface{}) {
c.JSON(http.StatusCreated, Response{
Code: 0,
Message: "success",
Data: data,
})
}
// Error 返回错误响应
func Error(c *gin.Context, statusCode int, message string) {
c.JSON(statusCode, Response{
Code: statusCode,
Message: message,
})
}
// BadRequest 返回400错误
func BadRequest(c *gin.Context, message string) {
Error(c, http.StatusBadRequest, message)
}
// Unauthorized 返回401错误
func Unauthorized(c *gin.Context, message string) {
Error(c, http.StatusUnauthorized, message)
}
// Forbidden 返回403错误
func Forbidden(c *gin.Context, message string) {
Error(c, http.StatusForbidden, message)
}
// NotFound 返回404错误
func NotFound(c *gin.Context, message string) {
Error(c, http.StatusNotFound, message)
}
// InternalError 返回500错误
func InternalError(c *gin.Context, message string) {
Error(c, http.StatusInternalServerError, message)
}
// Paginated 返回分页数据
func Paginated(c *gin.Context, items interface{}, total int64, page, pageSize int) {
pages := int(math.Ceil(float64(total) / float64(pageSize)))
if pages < 1 {
pages = 1
}
Success(c, PaginatedData{
Items: items,
Total: total,
Page: page,
PageSize: pageSize,
Pages: pages,
})
}
// PaginationResult 分页结果与repository.PaginationResult兼容
type PaginationResult struct {
Total int64
Page int
PageSize int
Pages int
}
// PaginatedWithResult 使用PaginationResult返回分页数据
func PaginatedWithResult(c *gin.Context, items interface{}, pagination *PaginationResult) {
if pagination == nil {
Success(c, PaginatedData{
Items: items,
Total: 0,
Page: 1,
PageSize: 20,
Pages: 1,
})
return
}
Success(c, PaginatedData{
Items: items,
Total: pagination.Total,
Page: pagination.Page,
PageSize: pagination.PageSize,
Pages: pagination.Pages,
})
}
// ParsePagination 解析分页参数
func ParsePagination(c *gin.Context) (page, pageSize int) {
page = 1
pageSize = 20
if p := c.Query("page"); p != "" {
if val, err := parseInt(p); err == nil && val > 0 {
page = val
}
}
// 支持 page_size 和 limit 两种参数名
if ps := c.Query("page_size"); ps != "" {
if val, err := parseInt(ps); err == nil && val > 0 && val <= 100 {
pageSize = val
}
} else if l := c.Query("limit"); l != "" {
if val, err := parseInt(l); err == nil && val > 0 && val <= 100 {
pageSize = val
}
}
return page, pageSize
}
func parseInt(s string) (int, error) {
var result int
for _, c := range s {
if c < '0' || c > '9' {
return 0, nil
}
result = result*10 + int(c-'0')
}
return result, nil
}

View File

@@ -0,0 +1,124 @@
// Package timezone provides global timezone management for the application.
// Similar to PHP's date_default_timezone_set, this package allows setting
// a global timezone that affects all time.Now() calls.
package timezone
import (
"fmt"
"log"
"time"
)
var (
// location is the global timezone location
location *time.Location
// tzName stores the timezone name for logging/debugging
tzName string
)
// Init initializes the global timezone setting.
// This should be called once at application startup.
// Example timezone values: "Asia/Shanghai", "America/New_York", "UTC"
func Init(tz string) error {
if tz == "" {
tz = "Asia/Shanghai" // Default timezone
}
loc, err := time.LoadLocation(tz)
if err != nil {
return fmt.Errorf("invalid timezone %q: %w", tz, err)
}
// Set the global Go time.Local to our timezone
// This affects time.Now() throughout the application
time.Local = loc
location = loc
tzName = tz
log.Printf("Timezone initialized: %s (UTC offset: %s)", tz, getUTCOffset(loc))
return nil
}
// getUTCOffset returns the current UTC offset for a location
func getUTCOffset(loc *time.Location) string {
_, offset := time.Now().In(loc).Zone()
hours := offset / 3600
minutes := (offset % 3600) / 60
if minutes < 0 {
minutes = -minutes
}
sign := "+"
if hours < 0 {
sign = "-"
hours = -hours
}
return fmt.Sprintf("%s%02d:%02d", sign, hours, minutes)
}
// Now returns the current time in the configured timezone.
// This is equivalent to time.Now() after Init() is called,
// but provided for explicit timezone-aware code.
func Now() time.Time {
if location == nil {
return time.Now()
}
return time.Now().In(location)
}
// Location returns the configured timezone location.
func Location() *time.Location {
if location == nil {
return time.Local
}
return location
}
// Name returns the configured timezone name.
func Name() string {
if tzName == "" {
return "Local"
}
return tzName
}
// StartOfDay returns the start of the given day (00:00:00) in the configured timezone.
func StartOfDay(t time.Time) time.Time {
loc := Location()
t = t.In(loc)
return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, loc)
}
// Today returns the start of today (00:00:00) in the configured timezone.
func Today() time.Time {
return StartOfDay(Now())
}
// EndOfDay returns the end of the given day (23:59:59.999999999) in the configured timezone.
func EndOfDay(t time.Time) time.Time {
loc := Location()
t = t.In(loc)
return time.Date(t.Year(), t.Month(), t.Day(), 23, 59, 59, 999999999, loc)
}
// StartOfWeek returns the start of the week (Monday 00:00:00) for the given time.
func StartOfWeek(t time.Time) time.Time {
loc := Location()
t = t.In(loc)
weekday := int(t.Weekday())
if weekday == 0 {
weekday = 7 // Sunday is day 7
}
return time.Date(t.Year(), t.Month(), t.Day()-weekday+1, 0, 0, 0, 0, loc)
}
// StartOfMonth returns the start of the month (1st day 00:00:00) for the given time.
func StartOfMonth(t time.Time) time.Time {
loc := Location()
t = t.In(loc)
return time.Date(t.Year(), t.Month(), 1, 0, 0, 0, 0, loc)
}
// ParseInLocation parses a time string in the configured timezone.
func ParseInLocation(layout, value string) (time.Time, error) {
return time.ParseInLocation(layout, value, Location())
}

View File

@@ -0,0 +1,127 @@
package timezone
import (
"testing"
"time"
)
func TestInit(t *testing.T) {
// Test with valid timezone
err := Init("Asia/Shanghai")
if err != nil {
t.Fatalf("Init failed with valid timezone: %v", err)
}
// Verify time.Local was set
if time.Local.String() != "Asia/Shanghai" {
t.Errorf("time.Local not set correctly, got %s", time.Local.String())
}
// Verify our location variable
if Location().String() != "Asia/Shanghai" {
t.Errorf("Location() not set correctly, got %s", Location().String())
}
// Test Name()
if Name() != "Asia/Shanghai" {
t.Errorf("Name() not set correctly, got %s", Name())
}
}
func TestInitInvalidTimezone(t *testing.T) {
err := Init("Invalid/Timezone")
if err == nil {
t.Error("Init should fail with invalid timezone")
}
}
func TestTimeNowAffected(t *testing.T) {
// Reset to UTC first
Init("UTC")
utcNow := time.Now()
// Switch to Shanghai (UTC+8)
Init("Asia/Shanghai")
shanghaiNow := time.Now()
// The times should be the same instant, but different timezone representation
// Shanghai should be 8 hours ahead in display
_, utcOffset := utcNow.Zone()
_, shanghaiOffset := shanghaiNow.Zone()
expectedDiff := 8 * 3600 // 8 hours in seconds
actualDiff := shanghaiOffset - utcOffset
if actualDiff != expectedDiff {
t.Errorf("Timezone offset difference incorrect: expected %d, got %d", expectedDiff, actualDiff)
}
}
func TestToday(t *testing.T) {
Init("Asia/Shanghai")
today := Today()
now := Now()
// Today should be at 00:00:00
if today.Hour() != 0 || today.Minute() != 0 || today.Second() != 0 {
t.Errorf("Today() not at start of day: %v", today)
}
// Today should be same date as now
if today.Year() != now.Year() || today.Month() != now.Month() || today.Day() != now.Day() {
t.Errorf("Today() date mismatch: today=%v, now=%v", today, now)
}
}
func TestStartOfDay(t *testing.T) {
Init("Asia/Shanghai")
// Create a time at 15:30:45
testTime := time.Date(2024, 6, 15, 15, 30, 45, 123456789, Location())
startOfDay := StartOfDay(testTime)
expected := time.Date(2024, 6, 15, 0, 0, 0, 0, Location())
if !startOfDay.Equal(expected) {
t.Errorf("StartOfDay incorrect: expected %v, got %v", expected, startOfDay)
}
}
func TestTruncateVsStartOfDay(t *testing.T) {
// This test demonstrates why Truncate(24*time.Hour) can be problematic
// and why StartOfDay is more reliable for timezone-aware code
Init("Asia/Shanghai")
now := Now()
// Truncate operates on UTC, not local time
truncated := now.Truncate(24 * time.Hour)
// StartOfDay operates on local time
startOfDay := StartOfDay(now)
// These will likely be different for non-UTC timezones
t.Logf("Now: %v", now)
t.Logf("Truncate(24h): %v", truncated)
t.Logf("StartOfDay: %v", startOfDay)
// The truncated time may not be at local midnight
// StartOfDay is always at local midnight
if startOfDay.Hour() != 0 {
t.Errorf("StartOfDay should be at hour 0, got %d", startOfDay.Hour())
}
}
func TestDSTAwareness(t *testing.T) {
// Test with a timezone that has DST (America/New_York)
err := Init("America/New_York")
if err != nil {
t.Skipf("America/New_York timezone not available: %v", err)
}
// Just verify it doesn't crash
_ = Today()
_ = Now()
_ = StartOfDay(Now())
}

View File

@@ -0,0 +1,268 @@
package repository
import (
"context"
"sub2api/internal/model"
"time"
"gorm.io/gorm"
)
type AccountRepository struct {
db *gorm.DB
}
func NewAccountRepository(db *gorm.DB) *AccountRepository {
return &AccountRepository{db: db}
}
func (r *AccountRepository) Create(ctx context.Context, account *model.Account) error {
return r.db.WithContext(ctx).Create(account).Error
}
func (r *AccountRepository) GetByID(ctx context.Context, id int64) (*model.Account, error) {
var account model.Account
err := r.db.WithContext(ctx).Preload("Proxy").Preload("AccountGroups").First(&account, id).Error
if err != nil {
return nil, err
}
// 填充 GroupIDs 虚拟字段
account.GroupIDs = make([]int64, 0, len(account.AccountGroups))
for _, ag := range account.AccountGroups {
account.GroupIDs = append(account.GroupIDs, ag.GroupID)
}
return &account, nil
}
func (r *AccountRepository) Update(ctx context.Context, account *model.Account) error {
return r.db.WithContext(ctx).Save(account).Error
}
func (r *AccountRepository) Delete(ctx context.Context, id int64) error {
// 先删除账号与分组的绑定关系
if err := r.db.WithContext(ctx).Where("account_id = ?", id).Delete(&model.AccountGroup{}).Error; err != nil {
return err
}
// 再删除账号
return r.db.WithContext(ctx).Delete(&model.Account{}, id).Error
}
func (r *AccountRepository) List(ctx context.Context, params PaginationParams) ([]model.Account, *PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", "", "")
}
// ListWithFilters lists accounts with optional filtering by platform, type, status, and search query
func (r *AccountRepository) ListWithFilters(ctx context.Context, params PaginationParams, platform, accountType, status, search string) ([]model.Account, *PaginationResult, error) {
var accounts []model.Account
var total int64
db := r.db.WithContext(ctx).Model(&model.Account{})
// Apply filters
if platform != "" {
db = db.Where("platform = ?", platform)
}
if accountType != "" {
db = db.Where("type = ?", accountType)
}
if status != "" {
db = db.Where("status = ?", status)
}
if search != "" {
searchPattern := "%" + search + "%"
db = db.Where("name ILIKE ?", searchPattern)
}
if err := db.Count(&total).Error; err != nil {
return nil, nil, err
}
if err := db.Preload("Proxy").Preload("AccountGroups").Offset(params.Offset()).Limit(params.Limit()).Order("id DESC").Find(&accounts).Error; err != nil {
return nil, nil, err
}
// 填充每个 Account 的 GroupIDs 虚拟字段
for i := range accounts {
accounts[i].GroupIDs = make([]int64, 0, len(accounts[i].AccountGroups))
for _, ag := range accounts[i].AccountGroups {
accounts[i].GroupIDs = append(accounts[i].GroupIDs, ag.GroupID)
}
}
pages := int(total) / params.Limit()
if int(total)%params.Limit() > 0 {
pages++
}
return accounts, &PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),
Pages: pages,
}, nil
}
func (r *AccountRepository) ListByGroup(ctx context.Context, groupID int64) ([]model.Account, error) {
var accounts []model.Account
err := r.db.WithContext(ctx).
Joins("JOIN account_groups ON account_groups.account_id = accounts.id").
Where("account_groups.group_id = ? AND accounts.status = ?", groupID, model.StatusActive).
Preload("Proxy").
Order("account_groups.priority ASC, accounts.priority ASC").
Find(&accounts).Error
return accounts, err
}
func (r *AccountRepository) ListActive(ctx context.Context) ([]model.Account, error) {
var accounts []model.Account
err := r.db.WithContext(ctx).
Where("status = ?", model.StatusActive).
Preload("Proxy").
Order("priority ASC").
Find(&accounts).Error
return accounts, err
}
func (r *AccountRepository) UpdateLastUsed(ctx context.Context, id int64) error {
now := time.Now()
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).Update("last_used_at", now).Error
}
func (r *AccountRepository) SetError(ctx context.Context, id int64, errorMsg string) error {
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
Updates(map[string]interface{}{
"status": model.StatusError,
"error_message": errorMsg,
}).Error
}
func (r *AccountRepository) AddToGroup(ctx context.Context, accountID, groupID int64, priority int) error {
ag := &model.AccountGroup{
AccountID: accountID,
GroupID: groupID,
Priority: priority,
}
return r.db.WithContext(ctx).Create(ag).Error
}
func (r *AccountRepository) RemoveFromGroup(ctx context.Context, accountID, groupID int64) error {
return r.db.WithContext(ctx).Where("account_id = ? AND group_id = ?", accountID, groupID).
Delete(&model.AccountGroup{}).Error
}
func (r *AccountRepository) GetGroups(ctx context.Context, accountID int64) ([]model.Group, error) {
var groups []model.Group
err := r.db.WithContext(ctx).
Joins("JOIN account_groups ON account_groups.group_id = groups.id").
Where("account_groups.account_id = ?", accountID).
Find(&groups).Error
return groups, err
}
func (r *AccountRepository) ListByPlatform(ctx context.Context, platform string) ([]model.Account, error) {
var accounts []model.Account
err := r.db.WithContext(ctx).
Where("platform = ? AND status = ?", platform, model.StatusActive).
Preload("Proxy").
Order("priority ASC").
Find(&accounts).Error
return accounts, err
}
func (r *AccountRepository) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error {
// 删除现有绑定
if err := r.db.WithContext(ctx).Where("account_id = ?", accountID).Delete(&model.AccountGroup{}).Error; err != nil {
return err
}
// 添加新绑定
if len(groupIDs) > 0 {
accountGroups := make([]model.AccountGroup, 0, len(groupIDs))
for i, groupID := range groupIDs {
accountGroups = append(accountGroups, model.AccountGroup{
AccountID: accountID,
GroupID: groupID,
Priority: i + 1, // 使用索引作为优先级
})
}
return r.db.WithContext(ctx).Create(&accountGroups).Error
}
return nil
}
// ListSchedulable 获取所有可调度的账号
func (r *AccountRepository) ListSchedulable(ctx context.Context) ([]model.Account, error) {
var accounts []model.Account
now := time.Now()
err := r.db.WithContext(ctx).
Where("status = ? AND schedulable = ?", model.StatusActive, true).
Where("(overload_until IS NULL OR overload_until <= ?)", now).
Where("(rate_limit_reset_at IS NULL OR rate_limit_reset_at <= ?)", now).
Preload("Proxy").
Order("priority ASC").
Find(&accounts).Error
return accounts, err
}
// ListSchedulableByGroupID 按组获取可调度的账号
func (r *AccountRepository) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]model.Account, error) {
var accounts []model.Account
now := time.Now()
err := r.db.WithContext(ctx).
Joins("JOIN account_groups ON account_groups.account_id = accounts.id").
Where("account_groups.group_id = ?", groupID).
Where("accounts.status = ? AND accounts.schedulable = ?", model.StatusActive, true).
Where("(accounts.overload_until IS NULL OR accounts.overload_until <= ?)", now).
Where("(accounts.rate_limit_reset_at IS NULL OR accounts.rate_limit_reset_at <= ?)", now).
Preload("Proxy").
Order("account_groups.priority ASC, accounts.priority ASC").
Find(&accounts).Error
return accounts, err
}
// SetRateLimited 标记账号为限流状态(429)
func (r *AccountRepository) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
now := time.Now()
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
Updates(map[string]interface{}{
"rate_limited_at": now,
"rate_limit_reset_at": resetAt,
}).Error
}
// SetOverloaded 标记账号为过载状态(529)
func (r *AccountRepository) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
Update("overload_until", until).Error
}
// ClearRateLimit 清除账号的限流状态
func (r *AccountRepository) ClearRateLimit(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
Updates(map[string]interface{}{
"rate_limited_at": nil,
"rate_limit_reset_at": nil,
"overload_until": nil,
}).Error
}
// UpdateSessionWindow 更新账号的5小时时间窗口信息
func (r *AccountRepository) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
updates := map[string]interface{}{
"session_window_status": status,
}
if start != nil {
updates["session_window_start"] = start
}
if end != nil {
updates["session_window_end"] = end
}
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).Updates(updates).Error
}
// SetSchedulable 设置账号的调度开关
func (r *AccountRepository) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
Update("schedulable", schedulable).Error
}

View File

@@ -0,0 +1,149 @@
package repository
import (
"context"
"sub2api/internal/model"
"gorm.io/gorm"
)
type ApiKeyRepository struct {
db *gorm.DB
}
func NewApiKeyRepository(db *gorm.DB) *ApiKeyRepository {
return &ApiKeyRepository{db: db}
}
func (r *ApiKeyRepository) Create(ctx context.Context, key *model.ApiKey) error {
return r.db.WithContext(ctx).Create(key).Error
}
func (r *ApiKeyRepository) GetByID(ctx context.Context, id int64) (*model.ApiKey, error) {
var key model.ApiKey
err := r.db.WithContext(ctx).Preload("User").Preload("Group").First(&key, id).Error
if err != nil {
return nil, err
}
return &key, nil
}
func (r *ApiKeyRepository) GetByKey(ctx context.Context, key string) (*model.ApiKey, error) {
var apiKey model.ApiKey
err := r.db.WithContext(ctx).Preload("User").Preload("Group").Where("key = ?", key).First(&apiKey).Error
if err != nil {
return nil, err
}
return &apiKey, nil
}
func (r *ApiKeyRepository) Update(ctx context.Context, key *model.ApiKey) error {
return r.db.WithContext(ctx).Model(key).Select("name", "group_id", "status", "updated_at").Updates(key).Error
}
func (r *ApiKeyRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.ApiKey{}, id).Error
}
func (r *ApiKeyRepository) ListByUserID(ctx context.Context, userID int64, params PaginationParams) ([]model.ApiKey, *PaginationResult, error) {
var keys []model.ApiKey
var total int64
db := r.db.WithContext(ctx).Model(&model.ApiKey{}).Where("user_id = ?", userID)
if err := db.Count(&total).Error; err != nil {
return nil, nil, err
}
if err := db.Preload("Group").Offset(params.Offset()).Limit(params.Limit()).Order("id DESC").Find(&keys).Error; err != nil {
return nil, nil, err
}
pages := int(total) / params.Limit()
if int(total)%params.Limit() > 0 {
pages++
}
return keys, &PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),
Pages: pages,
}, nil
}
func (r *ApiKeyRepository) CountByUserID(ctx context.Context, userID int64) (int64, error) {
var count int64
err := r.db.WithContext(ctx).Model(&model.ApiKey{}).Where("user_id = ?", userID).Count(&count).Error
return count, err
}
func (r *ApiKeyRepository) ExistsByKey(ctx context.Context, key string) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&model.ApiKey{}).Where("key = ?", key).Count(&count).Error
return count > 0, err
}
func (r *ApiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params PaginationParams) ([]model.ApiKey, *PaginationResult, error) {
var keys []model.ApiKey
var total int64
db := r.db.WithContext(ctx).Model(&model.ApiKey{}).Where("group_id = ?", groupID)
if err := db.Count(&total).Error; err != nil {
return nil, nil, err
}
if err := db.Preload("User").Offset(params.Offset()).Limit(params.Limit()).Order("id DESC").Find(&keys).Error; err != nil {
return nil, nil, err
}
pages := int(total) / params.Limit()
if int(total)%params.Limit() > 0 {
pages++
}
return keys, &PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),
Pages: pages,
}, nil
}
// SearchApiKeys searches API keys by user ID and/or keyword (name)
func (r *ApiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]model.ApiKey, error) {
var keys []model.ApiKey
db := r.db.WithContext(ctx).Model(&model.ApiKey{})
if userID > 0 {
db = db.Where("user_id = ?", userID)
}
if keyword != "" {
searchPattern := "%" + keyword + "%"
db = db.Where("name ILIKE ?", searchPattern)
}
if err := db.Limit(limit).Order("id DESC").Find(&keys).Error; err != nil {
return nil, err
}
return keys, nil
}
// ClearGroupIDByGroupID 将指定分组的所有 API Key 的 group_id 设为 nil
func (r *ApiKeyRepository) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
result := r.db.WithContext(ctx).Model(&model.ApiKey{}).
Where("group_id = ?", groupID).
Update("group_id", nil)
return result.RowsAffected, result.Error
}
// CountByGroupID 获取分组的 API Key 数量
func (r *ApiKeyRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
var count int64
err := r.db.WithContext(ctx).Model(&model.ApiKey{}).Where("group_id = ?", groupID).Count(&count).Error
return count, err
}

View File

@@ -0,0 +1,137 @@
package repository
import (
"context"
"sub2api/internal/model"
"gorm.io/gorm"
)
type GroupRepository struct {
db *gorm.DB
}
func NewGroupRepository(db *gorm.DB) *GroupRepository {
return &GroupRepository{db: db}
}
func (r *GroupRepository) Create(ctx context.Context, group *model.Group) error {
return r.db.WithContext(ctx).Create(group).Error
}
func (r *GroupRepository) GetByID(ctx context.Context, id int64) (*model.Group, error) {
var group model.Group
err := r.db.WithContext(ctx).First(&group, id).Error
if err != nil {
return nil, err
}
return &group, nil
}
func (r *GroupRepository) Update(ctx context.Context, group *model.Group) error {
return r.db.WithContext(ctx).Save(group).Error
}
func (r *GroupRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.Group{}, id).Error
}
func (r *GroupRepository) List(ctx context.Context, params PaginationParams) ([]model.Group, *PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", nil)
}
// ListWithFilters lists groups with optional filtering by platform, status, and is_exclusive
func (r *GroupRepository) ListWithFilters(ctx context.Context, params PaginationParams, platform, status string, isExclusive *bool) ([]model.Group, *PaginationResult, error) {
var groups []model.Group
var total int64
db := r.db.WithContext(ctx).Model(&model.Group{})
// Apply filters
if platform != "" {
db = db.Where("platform = ?", platform)
}
if status != "" {
db = db.Where("status = ?", status)
}
if isExclusive != nil {
db = db.Where("is_exclusive = ?", *isExclusive)
}
if err := db.Count(&total).Error; err != nil {
return nil, nil, err
}
if err := db.Offset(params.Offset()).Limit(params.Limit()).Order("id ASC").Find(&groups).Error; err != nil {
return nil, nil, err
}
// 获取每个分组的账号数量
for i := range groups {
count, _ := r.GetAccountCount(ctx, groups[i].ID)
groups[i].AccountCount = count
}
pages := int(total) / params.Limit()
if int(total)%params.Limit() > 0 {
pages++
}
return groups, &PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),
Pages: pages,
}, nil
}
func (r *GroupRepository) ListActive(ctx context.Context) ([]model.Group, error) {
var groups []model.Group
err := r.db.WithContext(ctx).Where("status = ?", model.StatusActive).Order("id ASC").Find(&groups).Error
if err != nil {
return nil, err
}
// 获取每个分组的账号数量
for i := range groups {
count, _ := r.GetAccountCount(ctx, groups[i].ID)
groups[i].AccountCount = count
}
return groups, nil
}
func (r *GroupRepository) ListActiveByPlatform(ctx context.Context, platform string) ([]model.Group, error) {
var groups []model.Group
err := r.db.WithContext(ctx).Where("status = ? AND platform = ?", model.StatusActive, platform).Order("id ASC").Find(&groups).Error
if err != nil {
return nil, err
}
// 获取每个分组的账号数量
for i := range groups {
count, _ := r.GetAccountCount(ctx, groups[i].ID)
groups[i].AccountCount = count
}
return groups, nil
}
func (r *GroupRepository) ExistsByName(ctx context.Context, name string) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&model.Group{}).Where("name = ?", name).Count(&count).Error
return count > 0, err
}
func (r *GroupRepository) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
var count int64
err := r.db.WithContext(ctx).Model(&model.AccountGroup{}).Where("group_id = ?", groupID).Count(&count).Error
return count, err
}
// DeleteAccountGroupsByGroupID 删除分组与账号的关联关系
func (r *GroupRepository) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
result := r.db.WithContext(ctx).Where("group_id = ?", groupID).Delete(&model.AccountGroup{})
return result.RowsAffected, result.Error
}
// DB 返回底层数据库连接,用于事务处理
func (r *GroupRepository) DB() *gorm.DB {
return r.db
}

View File

@@ -0,0 +1,161 @@
package repository
import (
"context"
"sub2api/internal/model"
"gorm.io/gorm"
)
type ProxyRepository struct {
db *gorm.DB
}
func NewProxyRepository(db *gorm.DB) *ProxyRepository {
return &ProxyRepository{db: db}
}
func (r *ProxyRepository) Create(ctx context.Context, proxy *model.Proxy) error {
return r.db.WithContext(ctx).Create(proxy).Error
}
func (r *ProxyRepository) GetByID(ctx context.Context, id int64) (*model.Proxy, error) {
var proxy model.Proxy
err := r.db.WithContext(ctx).First(&proxy, id).Error
if err != nil {
return nil, err
}
return &proxy, nil
}
func (r *ProxyRepository) Update(ctx context.Context, proxy *model.Proxy) error {
return r.db.WithContext(ctx).Save(proxy).Error
}
func (r *ProxyRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.Proxy{}, id).Error
}
func (r *ProxyRepository) List(ctx context.Context, params PaginationParams) ([]model.Proxy, *PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", "")
}
// ListWithFilters lists proxies with optional filtering by protocol, status, and search query
func (r *ProxyRepository) ListWithFilters(ctx context.Context, params PaginationParams, protocol, status, search string) ([]model.Proxy, *PaginationResult, error) {
var proxies []model.Proxy
var total int64
db := r.db.WithContext(ctx).Model(&model.Proxy{})
// Apply filters
if protocol != "" {
db = db.Where("protocol = ?", protocol)
}
if status != "" {
db = db.Where("status = ?", status)
}
if search != "" {
searchPattern := "%" + search + "%"
db = db.Where("name ILIKE ?", searchPattern)
}
if err := db.Count(&total).Error; err != nil {
return nil, nil, err
}
if err := db.Offset(params.Offset()).Limit(params.Limit()).Order("id DESC").Find(&proxies).Error; err != nil {
return nil, nil, err
}
pages := int(total) / params.Limit()
if int(total)%params.Limit() > 0 {
pages++
}
return proxies, &PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),
Pages: pages,
}, nil
}
func (r *ProxyRepository) ListActive(ctx context.Context) ([]model.Proxy, error) {
var proxies []model.Proxy
err := r.db.WithContext(ctx).Where("status = ?", model.StatusActive).Find(&proxies).Error
return proxies, err
}
// ExistsByHostPortAuth checks if a proxy with the same host, port, username, and password exists
func (r *ProxyRepository) ExistsByHostPortAuth(ctx context.Context, host string, port int, username, password string) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&model.Proxy{}).
Where("host = ? AND port = ? AND username = ? AND password = ?", host, port, username, password).
Count(&count).Error
if err != nil {
return false, err
}
return count > 0, nil
}
// CountAccountsByProxyID returns the number of accounts using a specific proxy
func (r *ProxyRepository) CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error) {
var count int64
err := r.db.WithContext(ctx).Model(&model.Account{}).
Where("proxy_id = ?", proxyID).
Count(&count).Error
return count, err
}
// GetAccountCountsForProxies returns a map of proxy ID to account count for all proxies
func (r *ProxyRepository) GetAccountCountsForProxies(ctx context.Context) (map[int64]int64, error) {
type result struct {
ProxyID int64 `gorm:"column:proxy_id"`
Count int64 `gorm:"column:count"`
}
var results []result
err := r.db.WithContext(ctx).
Model(&model.Account{}).
Select("proxy_id, COUNT(*) as count").
Where("proxy_id IS NOT NULL").
Group("proxy_id").
Scan(&results).Error
if err != nil {
return nil, err
}
counts := make(map[int64]int64)
for _, r := range results {
counts[r.ProxyID] = r.Count
}
return counts, nil
}
// ListActiveWithAccountCount returns all active proxies with account count, sorted by creation time descending
func (r *ProxyRepository) ListActiveWithAccountCount(ctx context.Context) ([]model.ProxyWithAccountCount, error) {
var proxies []model.Proxy
err := r.db.WithContext(ctx).
Where("status = ?", model.StatusActive).
Order("created_at DESC").
Find(&proxies).Error
if err != nil {
return nil, err
}
// Get account counts
counts, err := r.GetAccountCountsForProxies(ctx)
if err != nil {
return nil, err
}
// Build result with account counts
result := make([]model.ProxyWithAccountCount, len(proxies))
for i, proxy := range proxies {
result[i] = model.ProxyWithAccountCount{
Proxy: proxy,
AccountCount: counts[proxy.ID],
}
}
return result, nil
}

View File

@@ -0,0 +1,133 @@
package repository
import (
"context"
"sub2api/internal/model"
"time"
"gorm.io/gorm"
)
type RedeemCodeRepository struct {
db *gorm.DB
}
func NewRedeemCodeRepository(db *gorm.DB) *RedeemCodeRepository {
return &RedeemCodeRepository{db: db}
}
func (r *RedeemCodeRepository) Create(ctx context.Context, code *model.RedeemCode) error {
return r.db.WithContext(ctx).Create(code).Error
}
func (r *RedeemCodeRepository) CreateBatch(ctx context.Context, codes []model.RedeemCode) error {
return r.db.WithContext(ctx).Create(&codes).Error
}
func (r *RedeemCodeRepository) GetByID(ctx context.Context, id int64) (*model.RedeemCode, error) {
var code model.RedeemCode
err := r.db.WithContext(ctx).First(&code, id).Error
if err != nil {
return nil, err
}
return &code, nil
}
func (r *RedeemCodeRepository) GetByCode(ctx context.Context, code string) (*model.RedeemCode, error) {
var redeemCode model.RedeemCode
err := r.db.WithContext(ctx).Where("code = ?", code).First(&redeemCode).Error
if err != nil {
return nil, err
}
return &redeemCode, nil
}
func (r *RedeemCodeRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.RedeemCode{}, id).Error
}
func (r *RedeemCodeRepository) List(ctx context.Context, params PaginationParams) ([]model.RedeemCode, *PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", "")
}
// ListWithFilters lists redeem codes with optional filtering by type, status, and search query
func (r *RedeemCodeRepository) ListWithFilters(ctx context.Context, params PaginationParams, codeType, status, search string) ([]model.RedeemCode, *PaginationResult, error) {
var codes []model.RedeemCode
var total int64
db := r.db.WithContext(ctx).Model(&model.RedeemCode{})
// Apply filters
if codeType != "" {
db = db.Where("type = ?", codeType)
}
if status != "" {
db = db.Where("status = ?", status)
}
if search != "" {
searchPattern := "%" + search + "%"
db = db.Where("code ILIKE ?", searchPattern)
}
if err := db.Count(&total).Error; err != nil {
return nil, nil, err
}
if err := db.Preload("User").Preload("Group").Offset(params.Offset()).Limit(params.Limit()).Order("id DESC").Find(&codes).Error; err != nil {
return nil, nil, err
}
pages := int(total) / params.Limit()
if int(total)%params.Limit() > 0 {
pages++
}
return codes, &PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),
Pages: pages,
}, nil
}
func (r *RedeemCodeRepository) Update(ctx context.Context, code *model.RedeemCode) error {
return r.db.WithContext(ctx).Save(code).Error
}
func (r *RedeemCodeRepository) Use(ctx context.Context, id, userID int64) error {
now := time.Now()
result := r.db.WithContext(ctx).Model(&model.RedeemCode{}).
Where("id = ? AND status = ?", id, model.StatusUnused).
Updates(map[string]interface{}{
"status": model.StatusUsed,
"used_by": userID,
"used_at": now,
})
if result.Error != nil {
return result.Error
}
if result.RowsAffected == 0 {
return gorm.ErrRecordNotFound // 兑换码不存在或已被使用
}
return nil
}
// ListByUser returns all redeem codes used by a specific user
func (r *RedeemCodeRepository) ListByUser(ctx context.Context, userID int64, limit int) ([]model.RedeemCode, error) {
var codes []model.RedeemCode
if limit <= 0 {
limit = 10
}
err := r.db.WithContext(ctx).
Preload("Group").
Where("used_by = ?", userID).
Order("used_at DESC").
Limit(limit).
Find(&codes).Error
if err != nil {
return nil, err
}
return codes, nil
}

View File

@@ -0,0 +1,74 @@
package repository
import (
"gorm.io/gorm"
)
// Repositories 所有仓库的集合
type Repositories struct {
User *UserRepository
ApiKey *ApiKeyRepository
Group *GroupRepository
Account *AccountRepository
Proxy *ProxyRepository
RedeemCode *RedeemCodeRepository
UsageLog *UsageLogRepository
Setting *SettingRepository
UserSubscription *UserSubscriptionRepository
}
// NewRepositories 创建所有仓库
func NewRepositories(db *gorm.DB) *Repositories {
return &Repositories{
User: NewUserRepository(db),
ApiKey: NewApiKeyRepository(db),
Group: NewGroupRepository(db),
Account: NewAccountRepository(db),
Proxy: NewProxyRepository(db),
RedeemCode: NewRedeemCodeRepository(db),
UsageLog: NewUsageLogRepository(db),
Setting: NewSettingRepository(db),
UserSubscription: NewUserSubscriptionRepository(db),
}
}
// PaginationParams 分页参数
type PaginationParams struct {
Page int
PageSize int
}
// PaginationResult 分页结果
type PaginationResult struct {
Total int64
Page int
PageSize int
Pages int
}
// DefaultPagination 默认分页参数
func DefaultPagination() PaginationParams {
return PaginationParams{
Page: 1,
PageSize: 20,
}
}
// Offset 计算偏移量
func (p PaginationParams) Offset() int {
if p.Page < 1 {
p.Page = 1
}
return (p.Page - 1) * p.PageSize
}
// Limit 获取限制数
func (p PaginationParams) Limit() int {
if p.PageSize < 1 {
return 20
}
if p.PageSize > 100 {
return 100
}
return p.PageSize
}

View File

@@ -0,0 +1,108 @@
package repository
import (
"context"
"sub2api/internal/model"
"time"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
// SettingRepository 系统设置数据访问层
type SettingRepository struct {
db *gorm.DB
}
// NewSettingRepository 创建系统设置仓库实例
func NewSettingRepository(db *gorm.DB) *SettingRepository {
return &SettingRepository{db: db}
}
// Get 根据Key获取设置值
func (r *SettingRepository) Get(ctx context.Context, key string) (*model.Setting, error) {
var setting model.Setting
err := r.db.WithContext(ctx).Where("key = ?", key).First(&setting).Error
if err != nil {
return nil, err
}
return &setting, nil
}
// GetValue 获取设置值字符串
func (r *SettingRepository) GetValue(ctx context.Context, key string) (string, error) {
setting, err := r.Get(ctx, key)
if err != nil {
return "", err
}
return setting.Value, nil
}
// Set 设置值(存在则更新,不存在则创建)
func (r *SettingRepository) Set(ctx context.Context, key, value string) error {
setting := &model.Setting{
Key: key,
Value: value,
UpdatedAt: time.Now(),
}
return r.db.WithContext(ctx).Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "key"}},
DoUpdates: clause.AssignmentColumns([]string{"value", "updated_at"}),
}).Create(setting).Error
}
// GetMultiple 批量获取设置
func (r *SettingRepository) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
var settings []model.Setting
err := r.db.WithContext(ctx).Where("key IN ?", keys).Find(&settings).Error
if err != nil {
return nil, err
}
result := make(map[string]string)
for _, s := range settings {
result[s.Key] = s.Value
}
return result, nil
}
// SetMultiple 批量设置值
func (r *SettingRepository) SetMultiple(ctx context.Context, settings map[string]string) error {
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
for key, value := range settings {
setting := &model.Setting{
Key: key,
Value: value,
UpdatedAt: time.Now(),
}
if err := tx.Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "key"}},
DoUpdates: clause.AssignmentColumns([]string{"value", "updated_at"}),
}).Create(setting).Error; err != nil {
return err
}
}
return nil
})
}
// GetAll 获取所有设置
func (r *SettingRepository) GetAll(ctx context.Context) (map[string]string, error) {
var settings []model.Setting
err := r.db.WithContext(ctx).Find(&settings).Error
if err != nil {
return nil, err
}
result := make(map[string]string)
for _, s := range settings {
result[s.Key] = s.Value
}
return result, nil
}
// Delete 删除设置
func (r *SettingRepository) Delete(ctx context.Context, key string) error {
return r.db.WithContext(ctx).Where("key = ?", key).Delete(&model.Setting{}).Error
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,130 @@
package repository
import (
"context"
"sub2api/internal/model"
"gorm.io/gorm"
)
type UserRepository struct {
db *gorm.DB
}
func NewUserRepository(db *gorm.DB) *UserRepository {
return &UserRepository{db: db}
}
func (r *UserRepository) Create(ctx context.Context, user *model.User) error {
return r.db.WithContext(ctx).Create(user).Error
}
func (r *UserRepository) GetByID(ctx context.Context, id int64) (*model.User, error) {
var user model.User
err := r.db.WithContext(ctx).First(&user, id).Error
if err != nil {
return nil, err
}
return &user, nil
}
func (r *UserRepository) GetByEmail(ctx context.Context, email string) (*model.User, error) {
var user model.User
err := r.db.WithContext(ctx).Where("email = ?", email).First(&user).Error
if err != nil {
return nil, err
}
return &user, nil
}
func (r *UserRepository) Update(ctx context.Context, user *model.User) error {
return r.db.WithContext(ctx).Save(user).Error
}
func (r *UserRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.User{}, id).Error
}
func (r *UserRepository) List(ctx context.Context, params PaginationParams) ([]model.User, *PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", "")
}
// ListWithFilters lists users with optional filtering by status, role, and search query
func (r *UserRepository) ListWithFilters(ctx context.Context, params PaginationParams, status, role, search string) ([]model.User, *PaginationResult, error) {
var users []model.User
var total int64
db := r.db.WithContext(ctx).Model(&model.User{})
// Apply filters
if status != "" {
db = db.Where("status = ?", status)
}
if role != "" {
db = db.Where("role = ?", role)
}
if search != "" {
searchPattern := "%" + search + "%"
db = db.Where("email ILIKE ?", searchPattern)
}
if err := db.Count(&total).Error; err != nil {
return nil, nil, err
}
if err := db.Offset(params.Offset()).Limit(params.Limit()).Order("id DESC").Find(&users).Error; err != nil {
return nil, nil, err
}
pages := int(total) / params.Limit()
if int(total)%params.Limit() > 0 {
pages++
}
return users, &PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),
Pages: pages,
}, nil
}
func (r *UserRepository) UpdateBalance(ctx context.Context, id int64, amount float64) error {
return r.db.WithContext(ctx).Model(&model.User{}).Where("id = ?", id).
Update("balance", gorm.Expr("balance + ?", amount)).Error
}
// DeductBalance 扣减用户余额,仅当余额充足时执行
func (r *UserRepository) DeductBalance(ctx context.Context, id int64, amount float64) error {
result := r.db.WithContext(ctx).Model(&model.User{}).
Where("id = ? AND balance >= ?", id, amount).
Update("balance", gorm.Expr("balance - ?", amount))
if result.Error != nil {
return result.Error
}
if result.RowsAffected == 0 {
return gorm.ErrRecordNotFound // 余额不足或用户不存在
}
return nil
}
func (r *UserRepository) UpdateConcurrency(ctx context.Context, id int64, amount int) error {
return r.db.WithContext(ctx).Model(&model.User{}).Where("id = ?", id).
Update("concurrency", gorm.Expr("concurrency + ?", amount)).Error
}
func (r *UserRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&model.User{}).Where("email = ?", email).Count(&count).Error
return count > 0, err
}
// RemoveGroupFromAllowedGroups 从所有用户的 allowed_groups 数组中移除指定的分组ID
// 使用 PostgreSQL 的 array_remove 函数
func (r *UserRepository) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) {
result := r.db.WithContext(ctx).Model(&model.User{}).
Where("? = ANY(allowed_groups)", groupID).
Update("allowed_groups", gorm.Expr("array_remove(allowed_groups, ?)", groupID))
return result.RowsAffected, result.Error
}

View File

@@ -0,0 +1,322 @@
package repository
import (
"context"
"time"
"sub2api/internal/model"
"gorm.io/gorm"
)
// UserSubscriptionRepository 用户订阅仓库
type UserSubscriptionRepository struct {
db *gorm.DB
}
// NewUserSubscriptionRepository 创建用户订阅仓库
func NewUserSubscriptionRepository(db *gorm.DB) *UserSubscriptionRepository {
return &UserSubscriptionRepository{db: db}
}
// Create 创建订阅
func (r *UserSubscriptionRepository) Create(ctx context.Context, sub *model.UserSubscription) error {
return r.db.WithContext(ctx).Create(sub).Error
}
// GetByID 根据ID获取订阅
func (r *UserSubscriptionRepository) GetByID(ctx context.Context, id int64) (*model.UserSubscription, error) {
var sub model.UserSubscription
err := r.db.WithContext(ctx).
Preload("User").
Preload("Group").
Preload("AssignedByUser").
First(&sub, id).Error
if err != nil {
return nil, err
}
return &sub, nil
}
// GetByUserIDAndGroupID 根据用户ID和分组ID获取订阅
func (r *UserSubscriptionRepository) GetByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error) {
var sub model.UserSubscription
err := r.db.WithContext(ctx).
Preload("Group").
Where("user_id = ? AND group_id = ?", userID, groupID).
First(&sub).Error
if err != nil {
return nil, err
}
return &sub, nil
}
// GetActiveByUserIDAndGroupID 获取用户对特定分组的有效订阅
func (r *UserSubscriptionRepository) GetActiveByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error) {
var sub model.UserSubscription
err := r.db.WithContext(ctx).
Preload("Group").
Where("user_id = ? AND group_id = ? AND status = ? AND expires_at > ?",
userID, groupID, model.SubscriptionStatusActive, time.Now()).
First(&sub).Error
if err != nil {
return nil, err
}
return &sub, nil
}
// Update 更新订阅
func (r *UserSubscriptionRepository) Update(ctx context.Context, sub *model.UserSubscription) error {
sub.UpdatedAt = time.Now()
return r.db.WithContext(ctx).Save(sub).Error
}
// Delete 删除订阅
func (r *UserSubscriptionRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.UserSubscription{}, id).Error
}
// ListByUserID 获取用户的所有订阅
func (r *UserSubscriptionRepository) ListByUserID(ctx context.Context, userID int64) ([]model.UserSubscription, error) {
var subs []model.UserSubscription
err := r.db.WithContext(ctx).
Preload("Group").
Where("user_id = ?", userID).
Order("created_at DESC").
Find(&subs).Error
return subs, err
}
// ListActiveByUserID 获取用户的所有有效订阅
func (r *UserSubscriptionRepository) ListActiveByUserID(ctx context.Context, userID int64) ([]model.UserSubscription, error) {
var subs []model.UserSubscription
err := r.db.WithContext(ctx).
Preload("Group").
Where("user_id = ? AND status = ? AND expires_at > ?",
userID, model.SubscriptionStatusActive, time.Now()).
Order("created_at DESC").
Find(&subs).Error
return subs, err
}
// ListByGroupID 获取分组的所有订阅(分页)
func (r *UserSubscriptionRepository) ListByGroupID(ctx context.Context, groupID int64, params PaginationParams) ([]model.UserSubscription, *PaginationResult, error) {
var subs []model.UserSubscription
var total int64
query := r.db.WithContext(ctx).Model(&model.UserSubscription{}).Where("group_id = ?", groupID)
if err := query.Count(&total).Error; err != nil {
return nil, nil, err
}
err := query.
Preload("User").
Preload("Group").
Order("created_at DESC").
Offset(params.Offset()).
Limit(params.Limit()).
Find(&subs).Error
if err != nil {
return nil, nil, err
}
pages := int(total) / params.Limit()
if int(total)%params.Limit() > 0 {
pages++
}
return subs, &PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),
Pages: pages,
}, nil
}
// List 获取所有订阅(分页,支持筛选)
func (r *UserSubscriptionRepository) List(ctx context.Context, params PaginationParams, userID, groupID *int64, status string) ([]model.UserSubscription, *PaginationResult, error) {
var subs []model.UserSubscription
var total int64
query := r.db.WithContext(ctx).Model(&model.UserSubscription{})
if userID != nil {
query = query.Where("user_id = ?", *userID)
}
if groupID != nil {
query = query.Where("group_id = ?", *groupID)
}
if status != "" {
query = query.Where("status = ?", status)
}
if err := query.Count(&total).Error; err != nil {
return nil, nil, err
}
err := query.
Preload("User").
Preload("Group").
Preload("AssignedByUser").
Order("created_at DESC").
Offset(params.Offset()).
Limit(params.Limit()).
Find(&subs).Error
if err != nil {
return nil, nil, err
}
pages := int(total) / params.Limit()
if int(total)%params.Limit() > 0 {
pages++
}
return subs, &PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),
Pages: pages,
}, nil
}
// IncrementUsage 增加使用量
func (r *UserSubscriptionRepository) IncrementUsage(ctx context.Context, id int64, costUSD float64) error {
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id).
Updates(map[string]interface{}{
"daily_usage_usd": gorm.Expr("daily_usage_usd + ?", costUSD),
"weekly_usage_usd": gorm.Expr("weekly_usage_usd + ?", costUSD),
"monthly_usage_usd": gorm.Expr("monthly_usage_usd + ?", costUSD),
"updated_at": time.Now(),
}).Error
}
// ResetDailyUsage 重置日使用量
func (r *UserSubscriptionRepository) ResetDailyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id).
Updates(map[string]interface{}{
"daily_usage_usd": 0,
"daily_window_start": newWindowStart,
"updated_at": time.Now(),
}).Error
}
// ResetWeeklyUsage 重置周使用量
func (r *UserSubscriptionRepository) ResetWeeklyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id).
Updates(map[string]interface{}{
"weekly_usage_usd": 0,
"weekly_window_start": newWindowStart,
"updated_at": time.Now(),
}).Error
}
// ResetMonthlyUsage 重置月使用量
func (r *UserSubscriptionRepository) ResetMonthlyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id).
Updates(map[string]interface{}{
"monthly_usage_usd": 0,
"monthly_window_start": newWindowStart,
"updated_at": time.Now(),
}).Error
}
// ActivateWindows 激活所有窗口(首次使用时)
func (r *UserSubscriptionRepository) ActivateWindows(ctx context.Context, id int64, activateTime time.Time) error {
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id).
Updates(map[string]interface{}{
"daily_window_start": activateTime,
"weekly_window_start": activateTime,
"monthly_window_start": activateTime,
"updated_at": time.Now(),
}).Error
}
// UpdateStatus 更新订阅状态
func (r *UserSubscriptionRepository) UpdateStatus(ctx context.Context, id int64, status string) error {
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id).
Updates(map[string]interface{}{
"status": status,
"updated_at": time.Now(),
}).Error
}
// ExtendExpiry 延长订阅过期时间
func (r *UserSubscriptionRepository) ExtendExpiry(ctx context.Context, id int64, newExpiresAt time.Time) error {
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id).
Updates(map[string]interface{}{
"expires_at": newExpiresAt,
"updated_at": time.Now(),
}).Error
}
// UpdateNotes 更新订阅备注
func (r *UserSubscriptionRepository) UpdateNotes(ctx context.Context, id int64, notes string) error {
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id).
Updates(map[string]interface{}{
"notes": notes,
"updated_at": time.Now(),
}).Error
}
// ListExpired 获取所有已过期但状态仍为active的订阅
func (r *UserSubscriptionRepository) ListExpired(ctx context.Context) ([]model.UserSubscription, error) {
var subs []model.UserSubscription
err := r.db.WithContext(ctx).
Where("status = ? AND expires_at <= ?", model.SubscriptionStatusActive, time.Now()).
Find(&subs).Error
return subs, err
}
// BatchUpdateExpiredStatus 批量更新过期订阅状态
func (r *UserSubscriptionRepository) BatchUpdateExpiredStatus(ctx context.Context) (int64, error) {
result := r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("status = ? AND expires_at <= ?", model.SubscriptionStatusActive, time.Now()).
Updates(map[string]interface{}{
"status": model.SubscriptionStatusExpired,
"updated_at": time.Now(),
})
return result.RowsAffected, result.Error
}
// ExistsByUserIDAndGroupID 检查用户是否已有该分组的订阅
func (r *UserSubscriptionRepository) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("user_id = ? AND group_id = ?", userID, groupID).
Count(&count).Error
return count > 0, err
}
// CountByGroupID 获取分组的订阅数量
func (r *UserSubscriptionRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
var count int64
err := r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("group_id = ?", groupID).
Count(&count).Error
return count, err
}
// CountActiveByGroupID 获取分组的有效订阅数量
func (r *UserSubscriptionRepository) CountActiveByGroupID(ctx context.Context, groupID int64) (int64, error) {
var count int64
err := r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("group_id = ? AND status = ? AND expires_at > ?",
groupID, model.SubscriptionStatusActive, time.Now()).
Count(&count).Error
return count, err
}
// DeleteByGroupID 删除分组相关的所有订阅记录
func (r *UserSubscriptionRepository) DeleteByGroupID(ctx context.Context, groupID int64) (int64, error) {
result := r.db.WithContext(ctx).Where("group_id = ?", groupID).Delete(&model.UserSubscription{})
return result.RowsAffected, result.Error
}

View File

@@ -0,0 +1,284 @@
package service
import (
"context"
"errors"
"fmt"
"sub2api/internal/model"
"sub2api/internal/repository"
"gorm.io/gorm"
)
var (
ErrAccountNotFound = errors.New("account not found")
)
// CreateAccountRequest 创建账号请求
type CreateAccountRequest struct {
Name string `json:"name"`
Platform string `json:"platform"`
Type string `json:"type"`
Credentials map[string]interface{} `json:"credentials"`
Extra map[string]interface{} `json:"extra"`
ProxyID *int64 `json:"proxy_id"`
Concurrency int `json:"concurrency"`
Priority int `json:"priority"`
GroupIDs []int64 `json:"group_ids"`
}
// UpdateAccountRequest 更新账号请求
type UpdateAccountRequest struct {
Name *string `json:"name"`
Credentials *map[string]interface{} `json:"credentials"`
Extra *map[string]interface{} `json:"extra"`
ProxyID *int64 `json:"proxy_id"`
Concurrency *int `json:"concurrency"`
Priority *int `json:"priority"`
Status *string `json:"status"`
GroupIDs *[]int64 `json:"group_ids"`
}
// AccountService 账号管理服务
type AccountService struct {
accountRepo *repository.AccountRepository
groupRepo *repository.GroupRepository
}
// NewAccountService 创建账号服务实例
func NewAccountService(accountRepo *repository.AccountRepository, groupRepo *repository.GroupRepository) *AccountService {
return &AccountService{
accountRepo: accountRepo,
groupRepo: groupRepo,
}
}
// Create 创建账号
func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) (*model.Account, error) {
// 验证分组是否存在(如果指定了分组)
if len(req.GroupIDs) > 0 {
for _, groupID := range req.GroupIDs {
_, err := s.groupRepo.GetByID(ctx, groupID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, fmt.Errorf("group %d not found", groupID)
}
return nil, fmt.Errorf("get group: %w", err)
}
}
}
// 创建账号
account := &model.Account{
Name: req.Name,
Platform: req.Platform,
Type: req.Type,
Credentials: req.Credentials,
Extra: req.Extra,
ProxyID: req.ProxyID,
Concurrency: req.Concurrency,
Priority: req.Priority,
Status: model.StatusActive,
}
if err := s.accountRepo.Create(ctx, account); err != nil {
return nil, fmt.Errorf("create account: %w", err)
}
// 绑定分组
if len(req.GroupIDs) > 0 {
if err := s.accountRepo.BindGroups(ctx, account.ID, req.GroupIDs); err != nil {
return nil, fmt.Errorf("bind groups: %w", err)
}
}
return account, nil
}
// GetByID 根据ID获取账号
func (s *AccountService) GetByID(ctx context.Context, id int64) (*model.Account, error) {
account, err := s.accountRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrAccountNotFound
}
return nil, fmt.Errorf("get account: %w", err)
}
return account, nil
}
// List 获取账号列表
func (s *AccountService) List(ctx context.Context, params repository.PaginationParams) ([]model.Account, *repository.PaginationResult, error) {
accounts, pagination, err := s.accountRepo.List(ctx, params)
if err != nil {
return nil, nil, fmt.Errorf("list accounts: %w", err)
}
return accounts, pagination, nil
}
// ListByPlatform 根据平台获取账号列表
func (s *AccountService) ListByPlatform(ctx context.Context, platform string) ([]model.Account, error) {
accounts, err := s.accountRepo.ListByPlatform(ctx, platform)
if err != nil {
return nil, fmt.Errorf("list accounts by platform: %w", err)
}
return accounts, nil
}
// ListByGroup 根据分组获取账号列表
func (s *AccountService) ListByGroup(ctx context.Context, groupID int64) ([]model.Account, error) {
accounts, err := s.accountRepo.ListByGroup(ctx, groupID)
if err != nil {
return nil, fmt.Errorf("list accounts by group: %w", err)
}
return accounts, nil
}
// Update 更新账号
func (s *AccountService) Update(ctx context.Context, id int64, req UpdateAccountRequest) (*model.Account, error) {
account, err := s.accountRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrAccountNotFound
}
return nil, fmt.Errorf("get account: %w", err)
}
// 更新字段
if req.Name != nil {
account.Name = *req.Name
}
if req.Credentials != nil {
account.Credentials = *req.Credentials
}
if req.Extra != nil {
account.Extra = *req.Extra
}
if req.ProxyID != nil {
account.ProxyID = req.ProxyID
}
if req.Concurrency != nil {
account.Concurrency = *req.Concurrency
}
if req.Priority != nil {
account.Priority = *req.Priority
}
if req.Status != nil {
account.Status = *req.Status
}
if err := s.accountRepo.Update(ctx, account); err != nil {
return nil, fmt.Errorf("update account: %w", err)
}
// 更新分组绑定
if req.GroupIDs != nil {
// 验证分组是否存在
for _, groupID := range *req.GroupIDs {
_, err := s.groupRepo.GetByID(ctx, groupID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, fmt.Errorf("group %d not found", groupID)
}
return nil, fmt.Errorf("get group: %w", err)
}
}
if err := s.accountRepo.BindGroups(ctx, account.ID, *req.GroupIDs); err != nil {
return nil, fmt.Errorf("bind groups: %w", err)
}
}
return account, nil
}
// Delete 删除账号
func (s *AccountService) Delete(ctx context.Context, id int64) error {
// 检查账号是否存在
_, err := s.accountRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrAccountNotFound
}
return fmt.Errorf("get account: %w", err)
}
if err := s.accountRepo.Delete(ctx, id); err != nil {
return fmt.Errorf("delete account: %w", err)
}
return nil
}
// UpdateStatus 更新账号状态
func (s *AccountService) UpdateStatus(ctx context.Context, id int64, status string, errorMessage string) error {
account, err := s.accountRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrAccountNotFound
}
return fmt.Errorf("get account: %w", err)
}
account.Status = status
account.ErrorMessage = errorMessage
if err := s.accountRepo.Update(ctx, account); err != nil {
return fmt.Errorf("update account: %w", err)
}
return nil
}
// UpdateLastUsed 更新最后使用时间
func (s *AccountService) UpdateLastUsed(ctx context.Context, id int64) error {
if err := s.accountRepo.UpdateLastUsed(ctx, id); err != nil {
return fmt.Errorf("update last used: %w", err)
}
return nil
}
// GetCredential 获取账号凭证(安全访问)
func (s *AccountService) GetCredential(ctx context.Context, id int64, key string) (string, error) {
account, err := s.accountRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return "", ErrAccountNotFound
}
return "", fmt.Errorf("get account: %w", err)
}
return account.GetCredential(key), nil
}
// TestCredentials 测试账号凭证是否有效(需要实现具体平台的测试逻辑)
func (s *AccountService) TestCredentials(ctx context.Context, id int64) error {
account, err := s.accountRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrAccountNotFound
}
return fmt.Errorf("get account: %w", err)
}
// 根据平台执行不同的测试逻辑
switch account.Platform {
case model.PlatformAnthropic:
// TODO: 测试Anthropic API凭证
return nil
case model.PlatformOpenAI:
// TODO: 测试OpenAI API凭证
return nil
case model.PlatformGemini:
// TODO: 测试Gemini API凭证
return nil
default:
return fmt.Errorf("unsupported platform: %s", account.Platform)
}
}

View File

@@ -0,0 +1,314 @@
package service
import (
"bufio"
"bytes"
"crypto/rand"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"sub2api/internal/repository"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
const (
testClaudeAPIURL = "https://api.anthropic.com/v1/messages"
testModel = "claude-sonnet-4-5-20250929"
)
// TestEvent represents a SSE event for account testing
type TestEvent struct {
Type string `json:"type"`
Text string `json:"text,omitempty"`
Model string `json:"model,omitempty"`
Success bool `json:"success,omitempty"`
Error string `json:"error,omitempty"`
}
// AccountTestService handles account testing operations
type AccountTestService struct {
repos *repository.Repositories
oauthService *OAuthService
httpClient *http.Client
}
// NewAccountTestService creates a new AccountTestService
func NewAccountTestService(repos *repository.Repositories, oauthService *OAuthService) *AccountTestService {
return &AccountTestService{
repos: repos,
oauthService: oauthService,
httpClient: &http.Client{
Timeout: 60 * time.Second,
},
}
}
// generateSessionString generates a Claude Code style session string
func generateSessionString() string {
bytes := make([]byte, 32)
rand.Read(bytes)
hex64 := hex.EncodeToString(bytes)
sessionUUID := uuid.New().String()
return fmt.Sprintf("user_%s_account__session_%s", hex64, sessionUUID)
}
// createTestPayload creates a minimal test request payload for OAuth/Setup Token accounts
func createTestPayload() map[string]interface{} {
return map[string]interface{}{
"model": testModel,
"messages": []map[string]interface{}{
{
"role": "user",
"content": []map[string]interface{}{
{
"type": "text",
"text": "hi",
"cache_control": map[string]string{
"type": "ephemeral",
},
},
},
},
},
"system": []map[string]interface{}{
{
"type": "text",
"text": "You are Claude Code, Anthropic's official CLI for Claude.",
"cache_control": map[string]string{
"type": "ephemeral",
},
},
},
"metadata": map[string]string{
"user_id": generateSessionString(),
},
"max_tokens": 1024,
"temperature": 1,
"stream": true,
}
}
// createApiKeyTestPayload creates a simpler test request payload for API Key accounts
func createApiKeyTestPayload(model string) map[string]interface{} {
return map[string]interface{}{
"model": model,
"messages": []map[string]interface{}{
{
"role": "user",
"content": "hi",
},
},
"max_tokens": 1024,
"stream": true,
}
}
// TestAccountConnection tests an account's connection by sending a test request
func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int64) error {
ctx := c.Request.Context()
// Get account
account, err := s.repos.Account.GetByID(ctx, accountID)
if err != nil {
return s.sendErrorAndEnd(c, "Account not found")
}
// Determine authentication method based on account type
var authToken string
var authType string // "bearer" for OAuth, "apikey" for API Key
var apiURL string
if account.IsOAuth() {
// OAuth or Setup Token account
authType = "bearer"
apiURL = testClaudeAPIURL
authToken = account.GetCredential("access_token")
if authToken == "" {
return s.sendErrorAndEnd(c, "No access token available")
}
// Check if token needs refresh
needRefresh := false
if expiresAtStr := account.GetCredential("expires_at"); expiresAtStr != "" {
expiresAt, err := strconv.ParseInt(expiresAtStr, 10, 64)
if err == nil && time.Now().Unix()+300 > expiresAt { // 5 minute buffer
needRefresh = true
}
}
if needRefresh && s.oauthService != nil {
tokenInfo, err := s.oauthService.RefreshAccountToken(ctx, account)
if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to refresh token: %s", err.Error()))
}
authToken = tokenInfo.AccessToken
}
} else if account.Type == "apikey" {
// API Key account
authType = "apikey"
authToken = account.GetCredential("api_key")
if authToken == "" {
return s.sendErrorAndEnd(c, "No API key available")
}
// Get base URL (use default if not set)
apiURL = account.GetBaseURL()
if apiURL == "" {
apiURL = "https://api.anthropic.com"
}
// Append /v1/messages endpoint
apiURL = strings.TrimSuffix(apiURL, "/") + "/v1/messages"
} else {
return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type))
}
// Set SSE headers
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("X-Accel-Buffering", "no")
c.Writer.Flush()
// Create test request payload
var payload map[string]interface{}
var actualModel string
if authType == "apikey" {
// Use simpler payload for API Key (without Claude Code specific fields)
// Apply model mapping if configured
actualModel = account.GetMappedModel(testModel)
payload = createApiKeyTestPayload(actualModel)
} else {
actualModel = testModel
payload = createTestPayload()
}
payloadBytes, _ := json.Marshal(payload)
// Send test_start event with model info
s.sendEvent(c, TestEvent{Type: "test_start", Model: actualModel})
req, err := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewReader(payloadBytes))
if err != nil {
return s.sendErrorAndEnd(c, "Failed to create request")
}
// Set headers based on auth type
req.Header.Set("Content-Type", "application/json")
req.Header.Set("anthropic-version", "2023-06-01")
if authType == "bearer" {
req.Header.Set("Authorization", "Bearer "+authToken)
req.Header.Set("anthropic-beta", "prompt-caching-2024-07-31,interleaved-thinking-2025-05-14,output-128k-2025-02-19")
} else {
// API Key uses x-api-key header
req.Header.Set("x-api-key", authToken)
}
// Configure proxy if account has one
transport := http.DefaultTransport.(*http.Transport).Clone()
if account.ProxyID != nil && account.Proxy != nil {
proxyURL := account.Proxy.URL()
if proxyURL != "" {
if parsedURL, err := url.Parse(proxyURL); err == nil {
transport.Proxy = http.ProxyURL(parsedURL)
}
}
}
client := &http.Client{
Transport: transport,
Timeout: 60 * time.Second,
}
resp, err := client.Do(req)
if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return s.sendErrorAndEnd(c, fmt.Sprintf("API returned %d: %s", resp.StatusCode, string(body)))
}
// Process SSE stream
return s.processStream(c, resp.Body)
}
// processStream processes the SSE stream from Claude API
func (s *AccountTestService) processStream(c *gin.Context, body io.Reader) error {
reader := bufio.NewReader(body)
for {
line, err := reader.ReadString('\n')
if err != nil {
if err == io.EOF {
// Stream ended, send complete event
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
return nil
}
return s.sendErrorAndEnd(c, fmt.Sprintf("Stream read error: %s", err.Error()))
}
line = strings.TrimSpace(line)
if line == "" || !strings.HasPrefix(line, "data: ") {
continue
}
jsonStr := strings.TrimPrefix(line, "data: ")
if jsonStr == "[DONE]" {
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
return nil
}
var data map[string]interface{}
if err := json.Unmarshal([]byte(jsonStr), &data); err != nil {
continue
}
eventType, _ := data["type"].(string)
switch eventType {
case "content_block_delta":
if delta, ok := data["delta"].(map[string]interface{}); ok {
if text, ok := delta["text"].(string); ok {
s.sendEvent(c, TestEvent{Type: "content", Text: text})
}
}
case "message_stop":
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
return nil
case "error":
errorMsg := "Unknown error"
if errData, ok := data["error"].(map[string]interface{}); ok {
if msg, ok := errData["message"].(string); ok {
errorMsg = msg
}
}
return s.sendErrorAndEnd(c, errorMsg)
}
}
}
// sendEvent sends a SSE event to the client
func (s *AccountTestService) sendEvent(c *gin.Context, event TestEvent) {
eventJSON, _ := json.Marshal(event)
fmt.Fprintf(c.Writer, "data: %s\n\n", eventJSON)
c.Writer.Flush()
}
// sendErrorAndEnd sends an error event and ends the stream
func (s *AccountTestService) sendErrorAndEnd(c *gin.Context, errorMsg string) error {
log.Printf("Account test error: %s", errorMsg)
s.sendEvent(c, TestEvent{Type: "error", Error: errorMsg})
return fmt.Errorf(errorMsg)
}

View File

@@ -0,0 +1,345 @@
package service
import (
"context"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"net/url"
"sync"
"time"
"sub2api/internal/model"
"sub2api/internal/repository"
)
// usageCache 用于缓存usage数据
type usageCache struct {
data *UsageInfo
timestamp time.Time
}
var (
usageCacheMap = sync.Map{}
cacheTTL = 10 * time.Minute
)
// WindowStats 窗口期统计
type WindowStats struct {
Requests int64 `json:"requests"`
Tokens int64 `json:"tokens"`
Cost float64 `json:"cost"`
}
// UsageProgress 使用量进度
type UsageProgress struct {
Utilization float64 `json:"utilization"` // 使用率百分比 (0-100+100表示100%)
ResetsAt *time.Time `json:"resets_at"` // 重置时间
RemainingSeconds int `json:"remaining_seconds"` // 距重置剩余秒数
WindowStats *WindowStats `json:"window_stats,omitempty"` // 窗口期统计(从窗口开始到当前的使用量)
}
// UsageInfo 账号使用量信息
type UsageInfo struct {
UpdatedAt *time.Time `json:"updated_at,omitempty"` // 更新时间
FiveHour *UsageProgress `json:"five_hour"` // 5小时窗口
SevenDay *UsageProgress `json:"seven_day,omitempty"` // 7天窗口
SevenDaySonnet *UsageProgress `json:"seven_day_sonnet,omitempty"` // 7天Sonnet窗口
}
// ClaudeUsageResponse Anthropic API返回的usage结构
type ClaudeUsageResponse struct {
FiveHour struct {
Utilization float64 `json:"utilization"`
ResetsAt string `json:"resets_at"`
} `json:"five_hour"`
SevenDay struct {
Utilization float64 `json:"utilization"`
ResetsAt string `json:"resets_at"`
} `json:"seven_day"`
SevenDaySonnet struct {
Utilization float64 `json:"utilization"`
ResetsAt string `json:"resets_at"`
} `json:"seven_day_sonnet"`
}
// AccountUsageService 账号使用量查询服务
type AccountUsageService struct {
repos *repository.Repositories
oauthService *OAuthService
httpClient *http.Client
}
// NewAccountUsageService 创建AccountUsageService实例
func NewAccountUsageService(repos *repository.Repositories, oauthService *OAuthService) *AccountUsageService {
return &AccountUsageService{
repos: repos,
oauthService: oauthService,
httpClient: &http.Client{
Timeout: 30 * time.Second,
},
}
}
// GetUsage 获取账号使用量
// OAuth账号: 调用Anthropic API获取真实数据需要profile scope缓存10分钟
// Setup Token账号: 根据session_window推算5h窗口7d数据不可用没有profile scope
// API Key账号: 不支持usage查询
func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*UsageInfo, error) {
account, err := s.repos.Account.GetByID(ctx, accountID)
if err != nil {
return nil, fmt.Errorf("get account failed: %w", err)
}
// 只有oauth类型账号可以通过API获取usage有profile scope
if account.CanGetUsage() {
// 检查缓存
if cached, ok := usageCacheMap.Load(accountID); ok {
cache := cached.(*usageCache)
if time.Since(cache.timestamp) < cacheTTL {
return cache.data, nil
}
}
// 从API获取数据
usage, err := s.fetchOAuthUsage(ctx, account)
if err != nil {
return nil, err
}
// 添加5h窗口统计数据
s.addWindowStats(ctx, account, usage)
// 缓存结果
usageCacheMap.Store(accountID, &usageCache{
data: usage,
timestamp: time.Now(),
})
return usage, nil
}
// Setup Token账号根据session_window推算没有profile scope无法调用usage API
if account.Type == model.AccountTypeSetupToken {
usage := s.estimateSetupTokenUsage(account)
// 添加窗口统计
s.addWindowStats(ctx, account, usage)
return usage, nil
}
// API Key账号不支持usage查询
return nil, fmt.Errorf("account type %s does not support usage query", account.Type)
}
// addWindowStats 为usage数据添加窗口期统计
func (s *AccountUsageService) addWindowStats(ctx context.Context, account *model.Account, usage *UsageInfo) {
if usage.FiveHour == nil {
return
}
// 使用session_window_start作为统计起始时间
var startTime time.Time
if account.SessionWindowStart != nil {
startTime = *account.SessionWindowStart
} else {
// 如果没有窗口信息使用5小时前作为默认
startTime = time.Now().Add(-5 * time.Hour)
}
stats, err := s.repos.UsageLog.GetAccountWindowStats(ctx, account.ID, startTime)
if err != nil {
log.Printf("Failed to get window stats for account %d: %v", account.ID, err)
return
}
usage.FiveHour.WindowStats = &WindowStats{
Requests: stats.Requests,
Tokens: stats.Tokens,
Cost: stats.Cost,
}
}
// GetTodayStats 获取账号今日统计
func (s *AccountUsageService) GetTodayStats(ctx context.Context, accountID int64) (*WindowStats, error) {
stats, err := s.repos.UsageLog.GetAccountTodayStats(ctx, accountID)
if err != nil {
return nil, fmt.Errorf("get today stats failed: %w", err)
}
return &WindowStats{
Requests: stats.Requests,
Tokens: stats.Tokens,
Cost: stats.Cost,
}, nil
}
// fetchOAuthUsage 从Anthropic API获取OAuth账号的使用量
func (s *AccountUsageService) fetchOAuthUsage(ctx context.Context, account *model.Account) (*UsageInfo, error) {
// 获取access token从credentials中获取
accessToken := account.GetCredential("access_token")
if accessToken == "" {
return nil, fmt.Errorf("no access token available")
}
// 获取代理配置
transport := http.DefaultTransport.(*http.Transport).Clone()
if account.ProxyID != nil && account.Proxy != nil {
proxyURL := account.Proxy.URL()
if proxyURL != "" {
if parsedURL, err := url.Parse(proxyURL); err == nil {
transport.Proxy = http.ProxyURL(parsedURL)
}
}
}
client := &http.Client{
Transport: transport,
Timeout: 30 * time.Second,
}
// 构建请求
req, err := http.NewRequestWithContext(ctx, "GET", "https://api.anthropic.com/api/oauth/usage", nil)
if err != nil {
return nil, fmt.Errorf("create request failed: %w", err)
}
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("anthropic-beta", "oauth-2025-04-20")
// 发送请求
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("API returned status %d: %s", resp.StatusCode, string(body))
}
// 解析响应
var usageResp ClaudeUsageResponse
if err := json.NewDecoder(resp.Body).Decode(&usageResp); err != nil {
return nil, fmt.Errorf("decode response failed: %w", err)
}
// 转换为UsageInfo
now := time.Now()
return s.buildUsageInfo(&usageResp, &now), nil
}
// parseTime 尝试多种格式解析时间
func parseTime(s string) (time.Time, error) {
formats := []string{
time.RFC3339,
time.RFC3339Nano,
"2006-01-02T15:04:05Z",
"2006-01-02T15:04:05.000Z",
}
for _, format := range formats {
if t, err := time.Parse(format, s); err == nil {
return t, nil
}
}
return time.Time{}, fmt.Errorf("unable to parse time: %s", s)
}
// buildUsageInfo 构建UsageInfo
func (s *AccountUsageService) buildUsageInfo(resp *ClaudeUsageResponse, updatedAt *time.Time) *UsageInfo {
info := &UsageInfo{
UpdatedAt: updatedAt,
}
// 5小时窗口
if resp.FiveHour.ResetsAt != "" {
if fiveHourReset, err := parseTime(resp.FiveHour.ResetsAt); err == nil {
info.FiveHour = &UsageProgress{
Utilization: resp.FiveHour.Utilization,
ResetsAt: &fiveHourReset,
RemainingSeconds: int(time.Until(fiveHourReset).Seconds()),
}
} else {
log.Printf("Failed to parse FiveHour.ResetsAt: %s, error: %v", resp.FiveHour.ResetsAt, err)
// 即使解析失败也返回utilization
info.FiveHour = &UsageProgress{
Utilization: resp.FiveHour.Utilization,
}
}
}
// 7天窗口
if resp.SevenDay.ResetsAt != "" {
if sevenDayReset, err := parseTime(resp.SevenDay.ResetsAt); err == nil {
info.SevenDay = &UsageProgress{
Utilization: resp.SevenDay.Utilization,
ResetsAt: &sevenDayReset,
RemainingSeconds: int(time.Until(sevenDayReset).Seconds()),
}
} else {
log.Printf("Failed to parse SevenDay.ResetsAt: %s, error: %v", resp.SevenDay.ResetsAt, err)
info.SevenDay = &UsageProgress{
Utilization: resp.SevenDay.Utilization,
}
}
}
// 7天Sonnet窗口
if resp.SevenDaySonnet.ResetsAt != "" {
if sonnetReset, err := parseTime(resp.SevenDaySonnet.ResetsAt); err == nil {
info.SevenDaySonnet = &UsageProgress{
Utilization: resp.SevenDaySonnet.Utilization,
ResetsAt: &sonnetReset,
RemainingSeconds: int(time.Until(sonnetReset).Seconds()),
}
} else {
log.Printf("Failed to parse SevenDaySonnet.ResetsAt: %s, error: %v", resp.SevenDaySonnet.ResetsAt, err)
info.SevenDaySonnet = &UsageProgress{
Utilization: resp.SevenDaySonnet.Utilization,
}
}
}
return info
}
// estimateSetupTokenUsage 根据session_window推算Setup Token账号的使用量
func (s *AccountUsageService) estimateSetupTokenUsage(account *model.Account) *UsageInfo {
info := &UsageInfo{}
// 如果有session_window信息
if account.SessionWindowEnd != nil {
remaining := int(time.Until(*account.SessionWindowEnd).Seconds())
if remaining < 0 {
remaining = 0
}
// 根据状态估算使用率 (百分比形式100 = 100%)
var utilization float64
switch account.SessionWindowStatus {
case "rejected":
utilization = 100.0
case "allowed_warning":
utilization = 80.0
default:
utilization = 0.0
}
info.FiveHour = &UsageProgress{
Utilization: utilization,
ResetsAt: account.SessionWindowEnd,
RemainingSeconds: remaining,
}
} else {
// 没有窗口信息,返回空数据
info.FiveHour = &UsageProgress{
Utilization: 0,
RemainingSeconds: 0,
}
}
// Setup Token无法获取7d数据
return info
}

View File

@@ -0,0 +1,989 @@
package service
import (
"context"
"crypto/tls"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
"time"
"sub2api/internal/model"
"sub2api/internal/repository"
"golang.org/x/net/proxy"
"gorm.io/gorm"
)
// AdminService interface defines admin management operations
type AdminService interface {
// User management
ListUsers(ctx context.Context, page, pageSize int, status, role, search string) ([]model.User, int64, error)
GetUser(ctx context.Context, id int64) (*model.User, error)
CreateUser(ctx context.Context, input *CreateUserInput) (*model.User, error)
UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*model.User, error)
DeleteUser(ctx context.Context, id int64) error
UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string) (*model.User, error)
GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]model.ApiKey, int64, error)
GetUserUsageStats(ctx context.Context, userID int64, period string) (interface{}, error)
// Group management
ListGroups(ctx context.Context, page, pageSize int, platform, status string, isExclusive *bool) ([]model.Group, int64, error)
GetAllGroups(ctx context.Context) ([]model.Group, error)
GetAllGroupsByPlatform(ctx context.Context, platform string) ([]model.Group, error)
GetGroup(ctx context.Context, id int64) (*model.Group, error)
CreateGroup(ctx context.Context, input *CreateGroupInput) (*model.Group, error)
UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*model.Group, error)
DeleteGroup(ctx context.Context, id int64) error
GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]model.ApiKey, int64, error)
// Account management
ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]model.Account, int64, error)
GetAccount(ctx context.Context, id int64) (*model.Account, error)
CreateAccount(ctx context.Context, input *CreateAccountInput) (*model.Account, error)
UpdateAccount(ctx context.Context, id int64, input *UpdateAccountInput) (*model.Account, error)
DeleteAccount(ctx context.Context, id int64) error
RefreshAccountCredentials(ctx context.Context, id int64) (*model.Account, error)
ClearAccountError(ctx context.Context, id int64) (*model.Account, error)
SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*model.Account, error)
// Proxy management
ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]model.Proxy, int64, error)
GetAllProxies(ctx context.Context) ([]model.Proxy, error)
GetAllProxiesWithAccountCount(ctx context.Context) ([]model.ProxyWithAccountCount, error)
GetProxy(ctx context.Context, id int64) (*model.Proxy, error)
CreateProxy(ctx context.Context, input *CreateProxyInput) (*model.Proxy, error)
UpdateProxy(ctx context.Context, id int64, input *UpdateProxyInput) (*model.Proxy, error)
DeleteProxy(ctx context.Context, id int64) error
GetProxyAccounts(ctx context.Context, proxyID int64, page, pageSize int) ([]model.Account, int64, error)
CheckProxyExists(ctx context.Context, host string, port int, username, password string) (bool, error)
TestProxy(ctx context.Context, id int64) (*ProxyTestResult, error)
// Redeem code management
ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]model.RedeemCode, int64, error)
GetRedeemCode(ctx context.Context, id int64) (*model.RedeemCode, error)
GenerateRedeemCodes(ctx context.Context, input *GenerateRedeemCodesInput) ([]model.RedeemCode, error)
DeleteRedeemCode(ctx context.Context, id int64) error
BatchDeleteRedeemCodes(ctx context.Context, ids []int64) (int64, error)
ExpireRedeemCode(ctx context.Context, id int64) (*model.RedeemCode, error)
}
// Input types for admin operations
type CreateUserInput struct {
Email string
Password string
Balance float64
Concurrency int
AllowedGroups []int64
}
type UpdateUserInput struct {
Email string
Password string
Balance *float64 // 使用指针区分"未提供"和"设置为0"
Concurrency *int // 使用指针区分"未提供"和"设置为0"
Status string
AllowedGroups *[]int64 // 使用指针区分"未提供"和"设置为空数组"
}
type CreateGroupInput struct {
Name string
Description string
Platform string
RateMultiplier float64
IsExclusive bool
SubscriptionType string // standard/subscription
DailyLimitUSD *float64 // 日限额 (USD)
WeeklyLimitUSD *float64 // 周限额 (USD)
MonthlyLimitUSD *float64 // 月限额 (USD)
}
type UpdateGroupInput struct {
Name string
Description string
Platform string
RateMultiplier *float64 // 使用指针以支持设置为0
IsExclusive *bool
Status string
SubscriptionType string // standard/subscription
DailyLimitUSD *float64 // 日限额 (USD)
WeeklyLimitUSD *float64 // 周限额 (USD)
MonthlyLimitUSD *float64 // 月限额 (USD)
}
type CreateAccountInput struct {
Name string
Platform string
Type string
Credentials map[string]interface{}
Extra map[string]interface{}
ProxyID *int64
Concurrency int
Priority int
GroupIDs []int64
}
type UpdateAccountInput struct {
Name string
Type string // Account type: oauth, setup-token, apikey
Credentials map[string]interface{}
Extra map[string]interface{}
ProxyID *int64
Concurrency *int // 使用指针区分"未提供"和"设置为0"
Priority *int // 使用指针区分"未提供"和"设置为0"
Status string
GroupIDs *[]int64
}
type CreateProxyInput struct {
Name string
Protocol string
Host string
Port int
Username string
Password string
}
type UpdateProxyInput struct {
Name string
Protocol string
Host string
Port int
Username string
Password string
Status string
}
type GenerateRedeemCodesInput struct {
Count int
Type string
Value float64
GroupID *int64 // 订阅类型专用关联的分组ID
ValidityDays int // 订阅类型专用:有效天数
}
// ProxyTestResult represents the result of testing a proxy
type ProxyTestResult struct {
Success bool `json:"success"`
Message string `json:"message"`
LatencyMs int64 `json:"latency_ms,omitempty"`
IPAddress string `json:"ip_address,omitempty"`
City string `json:"city,omitempty"`
Region string `json:"region,omitempty"`
Country string `json:"country,omitempty"`
}
// adminServiceImpl implements AdminService
type adminServiceImpl struct {
userRepo *repository.UserRepository
groupRepo *repository.GroupRepository
accountRepo *repository.AccountRepository
proxyRepo *repository.ProxyRepository
apiKeyRepo *repository.ApiKeyRepository
redeemCodeRepo *repository.RedeemCodeRepository
usageLogRepo *repository.UsageLogRepository
userSubRepo *repository.UserSubscriptionRepository
billingCacheService *BillingCacheService
}
// NewAdminService creates a new AdminService
func NewAdminService(repos *repository.Repositories) AdminService {
return &adminServiceImpl{
userRepo: repos.User,
groupRepo: repos.Group,
accountRepo: repos.Account,
proxyRepo: repos.Proxy,
apiKeyRepo: repos.ApiKey,
redeemCodeRepo: repos.RedeemCode,
usageLogRepo: repos.UsageLog,
userSubRepo: repos.UserSubscription,
}
}
// SetBillingCacheService 设置计费缓存服务(用于缓存失效)
// 注意AdminService是接口需要类型断言
func SetAdminServiceBillingCache(adminService AdminService, billingCacheService *BillingCacheService) {
if impl, ok := adminService.(*adminServiceImpl); ok {
impl.billingCacheService = billingCacheService
}
}
// User management implementations
func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, status, role, search string) ([]model.User, int64, error) {
params := repository.PaginationParams{Page: page, PageSize: pageSize}
users, result, err := s.userRepo.ListWithFilters(ctx, params, status, role, search)
if err != nil {
return nil, 0, err
}
return users, result.Total, nil
}
func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*model.User, error) {
return s.userRepo.GetByID(ctx, id)
}
func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInput) (*model.User, error) {
user := &model.User{
Email: input.Email,
Role: "user", // Always create as regular user, never admin
Balance: input.Balance,
Concurrency: input.Concurrency,
Status: model.StatusActive,
}
if err := user.SetPassword(input.Password); err != nil {
return nil, err
}
if err := s.userRepo.Create(ctx, user); err != nil {
return nil, err
}
return user, nil
}
func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*model.User, error) {
user, err := s.userRepo.GetByID(ctx, id)
if err != nil {
return nil, err
}
// Protect admin users: cannot disable admin accounts
if user.Role == "admin" && input.Status == "disabled" {
return nil, errors.New("cannot disable admin user")
}
// Track balance and concurrency changes for logging
oldBalance := user.Balance
oldConcurrency := user.Concurrency
if input.Email != "" {
user.Email = input.Email
}
if input.Password != "" {
if err := user.SetPassword(input.Password); err != nil {
return nil, err
}
}
// Role is not allowed to be changed via API to prevent privilege escalation
if input.Status != "" {
user.Status = input.Status
}
// 只在指针非 nil 时更新 Balance支持设置为 0
if input.Balance != nil {
user.Balance = *input.Balance
}
// 只在指针非 nil 时更新 Concurrency支持设置为任意值
if input.Concurrency != nil {
user.Concurrency = *input.Concurrency
}
// 只在指针非 nil 时更新 AllowedGroups
if input.AllowedGroups != nil {
user.AllowedGroups = *input.AllowedGroups
}
if err := s.userRepo.Update(ctx, user); err != nil {
return nil, err
}
// 余额变化时失效缓存
if input.Balance != nil && *input.Balance != oldBalance {
if s.billingCacheService != nil {
go func() {
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
s.billingCacheService.InvalidateUserBalance(cacheCtx, id)
}()
}
}
// Create adjustment records for balance/concurrency changes
balanceDiff := user.Balance - oldBalance
if balanceDiff != 0 {
adjustmentRecord := &model.RedeemCode{
Code: model.GenerateRedeemCode(),
Type: model.AdjustmentTypeAdminBalance,
Value: balanceDiff,
Status: model.StatusUsed,
UsedBy: &user.ID,
}
now := time.Now()
adjustmentRecord.UsedAt = &now
if err := s.redeemCodeRepo.Create(ctx, adjustmentRecord); err != nil {
// Log error but don't fail the update
// The user update has already succeeded
}
}
concurrencyDiff := user.Concurrency - oldConcurrency
if concurrencyDiff != 0 {
adjustmentRecord := &model.RedeemCode{
Code: model.GenerateRedeemCode(),
Type: model.AdjustmentTypeAdminConcurrency,
Value: float64(concurrencyDiff),
Status: model.StatusUsed,
UsedBy: &user.ID,
}
now := time.Now()
adjustmentRecord.UsedAt = &now
if err := s.redeemCodeRepo.Create(ctx, adjustmentRecord); err != nil {
// Log error but don't fail the update
// The user update has already succeeded
}
}
return user, nil
}
func (s *adminServiceImpl) DeleteUser(ctx context.Context, id int64) error {
// Protect admin users: cannot delete admin accounts
user, err := s.userRepo.GetByID(ctx, id)
if err != nil {
return err
}
if user.Role == "admin" {
return errors.New("cannot delete admin user")
}
return s.userRepo.Delete(ctx, id)
}
func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string) (*model.User, error) {
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return nil, err
}
switch operation {
case "set":
user.Balance = balance
case "add":
user.Balance += balance
case "subtract":
user.Balance -= balance
}
if err := s.userRepo.Update(ctx, user); err != nil {
return nil, err
}
// 失效余额缓存
if s.billingCacheService != nil {
go func() {
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
s.billingCacheService.InvalidateUserBalance(cacheCtx, userID)
}()
}
return user, nil
}
func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]model.ApiKey, int64, error) {
params := repository.PaginationParams{Page: page, PageSize: pageSize}
keys, result, err := s.apiKeyRepo.ListByUserID(ctx, userID, params)
if err != nil {
return nil, 0, err
}
return keys, result.Total, nil
}
func (s *adminServiceImpl) GetUserUsageStats(ctx context.Context, userID int64, period string) (interface{}, error) {
// Return mock data for now
return map[string]interface{}{
"period": period,
"total_requests": 0,
"total_cost": 0.0,
"total_tokens": 0,
"avg_duration_ms": 0,
}, nil
}
// Group management implementations
func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, platform, status string, isExclusive *bool) ([]model.Group, int64, error) {
params := repository.PaginationParams{Page: page, PageSize: pageSize}
groups, result, err := s.groupRepo.ListWithFilters(ctx, params, platform, status, isExclusive)
if err != nil {
return nil, 0, err
}
return groups, result.Total, nil
}
func (s *adminServiceImpl) GetAllGroups(ctx context.Context) ([]model.Group, error) {
return s.groupRepo.ListActive(ctx)
}
func (s *adminServiceImpl) GetAllGroupsByPlatform(ctx context.Context, platform string) ([]model.Group, error) {
return s.groupRepo.ListActiveByPlatform(ctx, platform)
}
func (s *adminServiceImpl) GetGroup(ctx context.Context, id int64) (*model.Group, error) {
return s.groupRepo.GetByID(ctx, id)
}
func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupInput) (*model.Group, error) {
platform := input.Platform
if platform == "" {
platform = model.PlatformAnthropic
}
subscriptionType := input.SubscriptionType
if subscriptionType == "" {
subscriptionType = model.SubscriptionTypeStandard
}
group := &model.Group{
Name: input.Name,
Description: input.Description,
Platform: platform,
RateMultiplier: input.RateMultiplier,
IsExclusive: input.IsExclusive,
Status: model.StatusActive,
SubscriptionType: subscriptionType,
DailyLimitUSD: input.DailyLimitUSD,
WeeklyLimitUSD: input.WeeklyLimitUSD,
MonthlyLimitUSD: input.MonthlyLimitUSD,
}
if err := s.groupRepo.Create(ctx, group); err != nil {
return nil, err
}
return group, nil
}
func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*model.Group, error) {
group, err := s.groupRepo.GetByID(ctx, id)
if err != nil {
return nil, err
}
if input.Name != "" {
group.Name = input.Name
}
if input.Description != "" {
group.Description = input.Description
}
if input.Platform != "" {
group.Platform = input.Platform
}
if input.RateMultiplier != nil {
group.RateMultiplier = *input.RateMultiplier
}
if input.IsExclusive != nil {
group.IsExclusive = *input.IsExclusive
}
if input.Status != "" {
group.Status = input.Status
}
// 订阅相关字段
if input.SubscriptionType != "" {
group.SubscriptionType = input.SubscriptionType
}
// 限额字段支持设置为nil清除限额或具体值
if input.DailyLimitUSD != nil {
group.DailyLimitUSD = input.DailyLimitUSD
}
if input.WeeklyLimitUSD != nil {
group.WeeklyLimitUSD = input.WeeklyLimitUSD
}
if input.MonthlyLimitUSD != nil {
group.MonthlyLimitUSD = input.MonthlyLimitUSD
}
if err := s.groupRepo.Update(ctx, group); err != nil {
return nil, err
}
return group, nil
}
func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error {
// 先获取分组信息,检查是否存在
group, err := s.groupRepo.GetByID(ctx, id)
if err != nil {
return fmt.Errorf("group not found: %w", err)
}
// 订阅类型分组先获取受影响的用户ID列表用于事务后失效缓存
var affectedUserIDs []int64
if group.IsSubscriptionType() && s.billingCacheService != nil {
var subscriptions []model.UserSubscription
if err := s.groupRepo.DB().WithContext(ctx).
Where("group_id = ?", id).
Select("user_id").
Find(&subscriptions).Error; err == nil {
for _, sub := range subscriptions {
affectedUserIDs = append(affectedUserIDs, sub.UserID)
}
}
}
// 使用事务处理所有级联删除
db := s.groupRepo.DB()
err = db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
// 1. 如果是订阅类型分组,删除 user_subscriptions 中的相关记录
if group.IsSubscriptionType() {
if err := tx.Where("group_id = ?", id).Delete(&model.UserSubscription{}).Error; err != nil {
return fmt.Errorf("delete user subscriptions: %w", err)
}
}
// 2. 将 api_keys 中绑定该分组的 group_id 设为 nil任何类型的分组都需要
if err := tx.Model(&model.ApiKey{}).Where("group_id = ?", id).Update("group_id", nil).Error; err != nil {
return fmt.Errorf("clear api key group_id: %w", err)
}
// 3. 从 users.allowed_groups 数组中移除该分组 ID
if err := tx.Model(&model.User{}).
Where("? = ANY(allowed_groups)", id).
Update("allowed_groups", gorm.Expr("array_remove(allowed_groups, ?)", id)).Error; err != nil {
return fmt.Errorf("remove from allowed_groups: %w", err)
}
// 4. 删除 account_groups 中间表的数据
if err := tx.Where("group_id = ?", id).Delete(&model.AccountGroup{}).Error; err != nil {
return fmt.Errorf("delete account groups: %w", err)
}
// 5. 删除分组本身
if err := tx.Delete(&model.Group{}, id).Error; err != nil {
return fmt.Errorf("delete group: %w", err)
}
return nil
})
if err != nil {
return err
}
// 事务成功后,异步失效受影响用户的订阅缓存
if len(affectedUserIDs) > 0 && s.billingCacheService != nil {
groupID := id
go func() {
cacheCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
for _, userID := range affectedUserIDs {
s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
}
}()
}
return nil
}
func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]model.ApiKey, int64, error) {
params := repository.PaginationParams{Page: page, PageSize: pageSize}
keys, result, err := s.apiKeyRepo.ListByGroupID(ctx, groupID, params)
if err != nil {
return nil, 0, err
}
return keys, result.Total, nil
}
// Account management implementations
func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]model.Account, int64, error) {
params := repository.PaginationParams{Page: page, PageSize: pageSize}
accounts, result, err := s.accountRepo.ListWithFilters(ctx, params, platform, accountType, status, search)
if err != nil {
return nil, 0, err
}
return accounts, result.Total, nil
}
func (s *adminServiceImpl) GetAccount(ctx context.Context, id int64) (*model.Account, error) {
return s.accountRepo.GetByID(ctx, id)
}
func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccountInput) (*model.Account, error) {
account := &model.Account{
Name: input.Name,
Platform: input.Platform,
Type: input.Type,
Credentials: model.JSONB(input.Credentials),
Extra: model.JSONB(input.Extra),
ProxyID: input.ProxyID,
Concurrency: input.Concurrency,
Priority: input.Priority,
Status: model.StatusActive,
}
if err := s.accountRepo.Create(ctx, account); err != nil {
return nil, err
}
// 绑定分组
if len(input.GroupIDs) > 0 {
if err := s.accountRepo.BindGroups(ctx, account.ID, input.GroupIDs); err != nil {
return nil, err
}
}
return account, nil
}
func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *UpdateAccountInput) (*model.Account, error) {
account, err := s.accountRepo.GetByID(ctx, id)
if err != nil {
return nil, err
}
if input.Name != "" {
account.Name = input.Name
}
if input.Type != "" {
account.Type = input.Type
}
if input.Credentials != nil && len(input.Credentials) > 0 {
account.Credentials = model.JSONB(input.Credentials)
}
if input.Extra != nil && len(input.Extra) > 0 {
account.Extra = model.JSONB(input.Extra)
}
if input.ProxyID != nil {
account.ProxyID = input.ProxyID
}
// 只在指针非 nil 时更新 Concurrency支持设置为 0
if input.Concurrency != nil {
account.Concurrency = *input.Concurrency
}
// 只在指针非 nil 时更新 Priority支持设置为 0
if input.Priority != nil {
account.Priority = *input.Priority
}
if input.Status != "" {
account.Status = input.Status
}
if err := s.accountRepo.Update(ctx, account); err != nil {
return nil, err
}
// 更新分组绑定
if input.GroupIDs != nil {
if err := s.accountRepo.BindGroups(ctx, account.ID, *input.GroupIDs); err != nil {
return nil, err
}
}
return account, nil
}
func (s *adminServiceImpl) DeleteAccount(ctx context.Context, id int64) error {
return s.accountRepo.Delete(ctx, id)
}
func (s *adminServiceImpl) RefreshAccountCredentials(ctx context.Context, id int64) (*model.Account, error) {
account, err := s.accountRepo.GetByID(ctx, id)
if err != nil {
return nil, err
}
// TODO: Implement refresh logic
return account, nil
}
func (s *adminServiceImpl) ClearAccountError(ctx context.Context, id int64) (*model.Account, error) {
account, err := s.accountRepo.GetByID(ctx, id)
if err != nil {
return nil, err
}
account.Status = model.StatusActive
account.ErrorMessage = ""
if err := s.accountRepo.Update(ctx, account); err != nil {
return nil, err
}
return account, nil
}
func (s *adminServiceImpl) SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*model.Account, error) {
if err := s.accountRepo.SetSchedulable(ctx, id, schedulable); err != nil {
return nil, err
}
return s.accountRepo.GetByID(ctx, id)
}
// Proxy management implementations
func (s *adminServiceImpl) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]model.Proxy, int64, error) {
params := repository.PaginationParams{Page: page, PageSize: pageSize}
proxies, result, err := s.proxyRepo.ListWithFilters(ctx, params, protocol, status, search)
if err != nil {
return nil, 0, err
}
return proxies, result.Total, nil
}
func (s *adminServiceImpl) GetAllProxies(ctx context.Context) ([]model.Proxy, error) {
return s.proxyRepo.ListActive(ctx)
}
func (s *adminServiceImpl) GetAllProxiesWithAccountCount(ctx context.Context) ([]model.ProxyWithAccountCount, error) {
return s.proxyRepo.ListActiveWithAccountCount(ctx)
}
func (s *adminServiceImpl) GetProxy(ctx context.Context, id int64) (*model.Proxy, error) {
return s.proxyRepo.GetByID(ctx, id)
}
func (s *adminServiceImpl) CreateProxy(ctx context.Context, input *CreateProxyInput) (*model.Proxy, error) {
proxy := &model.Proxy{
Name: input.Name,
Protocol: input.Protocol,
Host: input.Host,
Port: input.Port,
Username: input.Username,
Password: input.Password,
Status: model.StatusActive,
}
if err := s.proxyRepo.Create(ctx, proxy); err != nil {
return nil, err
}
return proxy, nil
}
func (s *adminServiceImpl) UpdateProxy(ctx context.Context, id int64, input *UpdateProxyInput) (*model.Proxy, error) {
proxy, err := s.proxyRepo.GetByID(ctx, id)
if err != nil {
return nil, err
}
if input.Name != "" {
proxy.Name = input.Name
}
if input.Protocol != "" {
proxy.Protocol = input.Protocol
}
if input.Host != "" {
proxy.Host = input.Host
}
if input.Port != 0 {
proxy.Port = input.Port
}
if input.Username != "" {
proxy.Username = input.Username
}
if input.Password != "" {
proxy.Password = input.Password
}
if input.Status != "" {
proxy.Status = input.Status
}
if err := s.proxyRepo.Update(ctx, proxy); err != nil {
return nil, err
}
return proxy, nil
}
func (s *adminServiceImpl) DeleteProxy(ctx context.Context, id int64) error {
return s.proxyRepo.Delete(ctx, id)
}
func (s *adminServiceImpl) GetProxyAccounts(ctx context.Context, proxyID int64, page, pageSize int) ([]model.Account, int64, error) {
// Return mock data for now - would need a dedicated repository method
return []model.Account{}, 0, nil
}
func (s *adminServiceImpl) CheckProxyExists(ctx context.Context, host string, port int, username, password string) (bool, error) {
return s.proxyRepo.ExistsByHostPortAuth(ctx, host, port, username, password)
}
// Redeem code management implementations
func (s *adminServiceImpl) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]model.RedeemCode, int64, error) {
params := repository.PaginationParams{Page: page, PageSize: pageSize}
codes, result, err := s.redeemCodeRepo.ListWithFilters(ctx, params, codeType, status, search)
if err != nil {
return nil, 0, err
}
return codes, result.Total, nil
}
func (s *adminServiceImpl) GetRedeemCode(ctx context.Context, id int64) (*model.RedeemCode, error) {
return s.redeemCodeRepo.GetByID(ctx, id)
}
func (s *adminServiceImpl) GenerateRedeemCodes(ctx context.Context, input *GenerateRedeemCodesInput) ([]model.RedeemCode, error) {
// 如果是订阅类型,验证必须有 GroupID
if input.Type == model.RedeemTypeSubscription {
if input.GroupID == nil {
return nil, errors.New("group_id is required for subscription type")
}
// 验证分组存在且为订阅类型
group, err := s.groupRepo.GetByID(ctx, *input.GroupID)
if err != nil {
return nil, fmt.Errorf("group not found: %w", err)
}
if !group.IsSubscriptionType() {
return nil, errors.New("group must be subscription type")
}
}
codes := make([]model.RedeemCode, 0, input.Count)
for i := 0; i < input.Count; i++ {
code := model.RedeemCode{
Code: model.GenerateRedeemCode(),
Type: input.Type,
Value: input.Value,
Status: model.StatusUnused,
}
// 订阅类型专用字段
if input.Type == model.RedeemTypeSubscription {
code.GroupID = input.GroupID
code.ValidityDays = input.ValidityDays
if code.ValidityDays <= 0 {
code.ValidityDays = 30 // 默认30天
}
}
if err := s.redeemCodeRepo.Create(ctx, &code); err != nil {
return nil, err
}
codes = append(codes, code)
}
return codes, nil
}
func (s *adminServiceImpl) DeleteRedeemCode(ctx context.Context, id int64) error {
return s.redeemCodeRepo.Delete(ctx, id)
}
func (s *adminServiceImpl) BatchDeleteRedeemCodes(ctx context.Context, ids []int64) (int64, error) {
var deleted int64
for _, id := range ids {
if err := s.redeemCodeRepo.Delete(ctx, id); err == nil {
deleted++
}
}
return deleted, nil
}
func (s *adminServiceImpl) ExpireRedeemCode(ctx context.Context, id int64) (*model.RedeemCode, error) {
code, err := s.redeemCodeRepo.GetByID(ctx, id)
if err != nil {
return nil, err
}
code.Status = model.StatusExpired
if err := s.redeemCodeRepo.Update(ctx, code); err != nil {
return nil, err
}
return code, nil
}
func (s *adminServiceImpl) TestProxy(ctx context.Context, id int64) (*ProxyTestResult, error) {
proxy, err := s.proxyRepo.GetByID(ctx, id)
if err != nil {
return nil, err
}
return testProxyConnection(ctx, proxy)
}
// testProxyConnection tests proxy connectivity by requesting ipinfo.io/json
func testProxyConnection(ctx context.Context, proxy *model.Proxy) (*ProxyTestResult, error) {
proxyURL := proxy.URL()
// Create HTTP client with proxy
transport, err := createProxyTransport(proxyURL)
if err != nil {
return &ProxyTestResult{
Success: false,
Message: fmt.Sprintf("Failed to create proxy transport: %v", err),
}, nil
}
client := &http.Client{
Transport: transport,
Timeout: 15 * time.Second,
}
// Measure latency
startTime := time.Now()
req, err := http.NewRequestWithContext(ctx, "GET", "https://ipinfo.io/json", nil)
if err != nil {
return &ProxyTestResult{
Success: false,
Message: fmt.Sprintf("Failed to create request: %v", err),
}, nil
}
resp, err := client.Do(req)
if err != nil {
return &ProxyTestResult{
Success: false,
Message: fmt.Sprintf("Proxy connection failed: %v", err),
}, nil
}
defer resp.Body.Close()
latencyMs := time.Since(startTime).Milliseconds()
if resp.StatusCode != http.StatusOK {
return &ProxyTestResult{
Success: false,
Message: fmt.Sprintf("Request failed with status: %d", resp.StatusCode),
LatencyMs: latencyMs,
}, nil
}
// Parse ipinfo.io response
var ipInfo struct {
IP string `json:"ip"`
City string `json:"city"`
Region string `json:"region"`
Country string `json:"country"`
}
body, err := io.ReadAll(resp.Body)
if err != nil {
return &ProxyTestResult{
Success: true,
Message: "Proxy is accessible but failed to read response",
LatencyMs: latencyMs,
}, nil
}
if err := json.Unmarshal(body, &ipInfo); err != nil {
return &ProxyTestResult{
Success: true,
Message: "Proxy is accessible but failed to parse response",
LatencyMs: latencyMs,
}, nil
}
return &ProxyTestResult{
Success: true,
Message: "Proxy is accessible",
LatencyMs: latencyMs,
IPAddress: ipInfo.IP,
City: ipInfo.City,
Region: ipInfo.Region,
Country: ipInfo.Country,
}, nil
}
// createProxyTransport creates an HTTP transport with the given proxy URL
func createProxyTransport(proxyURL string) (*http.Transport, error) {
parsedURL, err := url.Parse(proxyURL)
if err != nil {
return nil, fmt.Errorf("invalid proxy URL: %w", err)
}
transport := &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
switch parsedURL.Scheme {
case "http", "https":
transport.Proxy = http.ProxyURL(parsedURL)
case "socks5":
dialer, err := proxy.FromURL(parsedURL, proxy.Direct)
if err != nil {
return nil, fmt.Errorf("failed to create socks5 dialer: %w", err)
}
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.Dial(network, addr)
}
default:
return nil, fmt.Errorf("unsupported proxy protocol: %s", parsedURL.Scheme)
}
return transport, nil
}

View File

@@ -0,0 +1,464 @@
package service
import (
"context"
"crypto/rand"
"encoding/hex"
"errors"
"fmt"
"sub2api/internal/config"
"sub2api/internal/model"
"sub2api/internal/pkg/timezone"
"sub2api/internal/repository"
"time"
"github.com/redis/go-redis/v9"
"gorm.io/gorm"
)
var (
ErrApiKeyNotFound = errors.New("api key not found")
ErrGroupNotAllowed = errors.New("user is not allowed to bind this group")
ErrApiKeyExists = errors.New("api key already exists")
ErrApiKeyTooShort = errors.New("api key must be at least 16 characters")
ErrApiKeyInvalidChars = errors.New("api key can only contain letters, numbers, underscores, and hyphens")
ErrApiKeyRateLimited = errors.New("too many failed attempts, please try again later")
)
const (
apiKeyRateLimitKeyPrefix = "apikey:create_rate_limit:"
apiKeyMaxErrorsPerHour = 20
apiKeyRateLimitDuration = time.Hour
)
// CreateApiKeyRequest 创建API Key请求
type CreateApiKeyRequest struct {
Name string `json:"name"`
GroupID *int64 `json:"group_id"`
CustomKey *string `json:"custom_key"` // 可选的自定义key
}
// UpdateApiKeyRequest 更新API Key请求
type UpdateApiKeyRequest struct {
Name *string `json:"name"`
GroupID *int64 `json:"group_id"`
Status *string `json:"status"`
}
// ApiKeyService API Key服务
type ApiKeyService struct {
apiKeyRepo *repository.ApiKeyRepository
userRepo *repository.UserRepository
groupRepo *repository.GroupRepository
userSubRepo *repository.UserSubscriptionRepository
rdb *redis.Client
cfg *config.Config
}
// NewApiKeyService 创建API Key服务实例
func NewApiKeyService(
apiKeyRepo *repository.ApiKeyRepository,
userRepo *repository.UserRepository,
groupRepo *repository.GroupRepository,
userSubRepo *repository.UserSubscriptionRepository,
rdb *redis.Client,
cfg *config.Config,
) *ApiKeyService {
return &ApiKeyService{
apiKeyRepo: apiKeyRepo,
userRepo: userRepo,
groupRepo: groupRepo,
userSubRepo: userSubRepo,
rdb: rdb,
cfg: cfg,
}
}
// GenerateKey 生成随机API Key
func (s *ApiKeyService) GenerateKey() (string, error) {
// 生成32字节随机数据
bytes := make([]byte, 32)
if _, err := rand.Read(bytes); err != nil {
return "", fmt.Errorf("generate random bytes: %w", err)
}
// 转换为十六进制字符串并添加前缀
prefix := s.cfg.Default.ApiKeyPrefix
if prefix == "" {
prefix = "sk-"
}
key := prefix + hex.EncodeToString(bytes)
return key, nil
}
// ValidateCustomKey 验证自定义API Key格式
func (s *ApiKeyService) ValidateCustomKey(key string) error {
// 检查长度
if len(key) < 16 {
return ErrApiKeyTooShort
}
// 检查字符:只允许字母、数字、下划线、连字符
for _, c := range key {
if !((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') ||
(c >= '0' && c <= '9') || c == '_' || c == '-') {
return ErrApiKeyInvalidChars
}
}
return nil
}
// checkApiKeyRateLimit 检查用户创建自定义Key的错误次数是否超限
func (s *ApiKeyService) checkApiKeyRateLimit(ctx context.Context, userID int64) error {
if s.rdb == nil {
return nil
}
key := fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
count, err := s.rdb.Get(ctx, key).Int()
if err != nil && !errors.Is(err, redis.Nil) {
// Redis 出错时不阻止用户操作
return nil
}
if count >= apiKeyMaxErrorsPerHour {
return ErrApiKeyRateLimited
}
return nil
}
// incrementApiKeyErrorCount 增加用户创建自定义Key的错误计数
func (s *ApiKeyService) incrementApiKeyErrorCount(ctx context.Context, userID int64) {
if s.rdb == nil {
return
}
key := fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
pipe := s.rdb.Pipeline()
pipe.Incr(ctx, key)
pipe.Expire(ctx, key, apiKeyRateLimitDuration)
_, _ = pipe.Exec(ctx)
}
// canUserBindGroup 检查用户是否可以绑定指定分组
// 对于订阅类型分组:检查用户是否有有效订阅
// 对于标准类型分组:使用原有的 AllowedGroups 和 IsExclusive 逻辑
func (s *ApiKeyService) canUserBindGroup(ctx context.Context, user *model.User, group *model.Group) bool {
// 订阅类型分组:需要有效订阅
if group.IsSubscriptionType() {
_, err := s.userSubRepo.GetActiveByUserIDAndGroupID(ctx, user.ID, group.ID)
return err == nil // 有有效订阅则允许
}
// 标准类型分组:使用原有逻辑
return user.CanBindGroup(group.ID, group.IsExclusive)
}
// Create 创建API Key
func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiKeyRequest) (*model.ApiKey, error) {
// 验证用户存在
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrUserNotFound
}
return nil, fmt.Errorf("get user: %w", err)
}
// 验证分组权限(如果指定了分组)
if req.GroupID != nil {
group, err := s.groupRepo.GetByID(ctx, *req.GroupID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errors.New("group not found")
}
return nil, fmt.Errorf("get group: %w", err)
}
// 检查用户是否可以绑定该分组
if !s.canUserBindGroup(ctx, user, group) {
return nil, ErrGroupNotAllowed
}
}
var key string
// 判断是否使用自定义Key
if req.CustomKey != nil && *req.CustomKey != "" {
// 检查限流仅对自定义key进行限流
if err := s.checkApiKeyRateLimit(ctx, userID); err != nil {
return nil, err
}
// 验证自定义Key格式
if err := s.ValidateCustomKey(*req.CustomKey); err != nil {
return nil, err
}
// 检查Key是否已存在
exists, err := s.apiKeyRepo.ExistsByKey(ctx, *req.CustomKey)
if err != nil {
return nil, fmt.Errorf("check key exists: %w", err)
}
if exists {
// Key已存在增加错误计数
s.incrementApiKeyErrorCount(ctx, userID)
return nil, ErrApiKeyExists
}
key = *req.CustomKey
} else {
// 生成随机API Key
var err error
key, err = s.GenerateKey()
if err != nil {
return nil, fmt.Errorf("generate key: %w", err)
}
}
// 创建API Key记录
apiKey := &model.ApiKey{
UserID: userID,
Key: key,
Name: req.Name,
GroupID: req.GroupID,
Status: model.StatusActive,
}
if err := s.apiKeyRepo.Create(ctx, apiKey); err != nil {
return nil, fmt.Errorf("create api key: %w", err)
}
return apiKey, nil
}
// List 获取用户的API Key列表
func (s *ApiKeyService) List(ctx context.Context, userID int64, params repository.PaginationParams) ([]model.ApiKey, *repository.PaginationResult, error) {
keys, pagination, err := s.apiKeyRepo.ListByUserID(ctx, userID, params)
if err != nil {
return nil, nil, fmt.Errorf("list api keys: %w", err)
}
return keys, pagination, nil
}
// GetByID 根据ID获取API Key
func (s *ApiKeyService) GetByID(ctx context.Context, id int64) (*model.ApiKey, error) {
apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrApiKeyNotFound
}
return nil, fmt.Errorf("get api key: %w", err)
}
return apiKey, nil
}
// GetByKey 根据Key字符串获取API Key用于认证
func (s *ApiKeyService) GetByKey(ctx context.Context, key string) (*model.ApiKey, error) {
// 尝试从Redis缓存获取
cacheKey := fmt.Sprintf("apikey:%s", key)
// 这里可以添加Redis缓存逻辑暂时直接查询数据库
apiKey, err := s.apiKeyRepo.GetByKey(ctx, key)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrApiKeyNotFound
}
return nil, fmt.Errorf("get api key: %w", err)
}
// 缓存到Redis可选TTL设置为5分钟
if s.rdb != nil {
// 这里可以序列化并缓存API Key
_ = cacheKey // 使用变量避免未使用错误
}
return apiKey, nil
}
// Update 更新API Key
func (s *ApiKeyService) Update(ctx context.Context, id int64, userID int64, req UpdateApiKeyRequest) (*model.ApiKey, error) {
apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrApiKeyNotFound
}
return nil, fmt.Errorf("get api key: %w", err)
}
// 验证所有权
if apiKey.UserID != userID {
return nil, ErrInsufficientPerms
}
// 更新字段
if req.Name != nil {
apiKey.Name = *req.Name
}
if req.GroupID != nil {
// 验证分组权限
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return nil, fmt.Errorf("get user: %w", err)
}
group, err := s.groupRepo.GetByID(ctx, *req.GroupID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errors.New("group not found")
}
return nil, fmt.Errorf("get group: %w", err)
}
if !s.canUserBindGroup(ctx, user, group) {
return nil, ErrGroupNotAllowed
}
apiKey.GroupID = req.GroupID
}
if req.Status != nil {
apiKey.Status = *req.Status
// 如果状态改变清除Redis缓存
if s.rdb != nil {
cacheKey := fmt.Sprintf("apikey:%s", apiKey.Key)
_ = s.rdb.Del(ctx, cacheKey)
}
}
if err := s.apiKeyRepo.Update(ctx, apiKey); err != nil {
return nil, fmt.Errorf("update api key: %w", err)
}
return apiKey, nil
}
// Delete 删除API Key
func (s *ApiKeyService) Delete(ctx context.Context, id int64, userID int64) error {
apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrApiKeyNotFound
}
return fmt.Errorf("get api key: %w", err)
}
// 验证所有权
if apiKey.UserID != userID {
return ErrInsufficientPerms
}
// 清除Redis缓存
if s.rdb != nil {
cacheKey := fmt.Sprintf("apikey:%s", apiKey.Key)
_ = s.rdb.Del(ctx, cacheKey)
}
if err := s.apiKeyRepo.Delete(ctx, id); err != nil {
return fmt.Errorf("delete api key: %w", err)
}
return nil
}
// ValidateKey 验证API Key是否有效用于认证中间件
func (s *ApiKeyService) ValidateKey(ctx context.Context, key string) (*model.ApiKey, *model.User, error) {
// 获取API Key
apiKey, err := s.GetByKey(ctx, key)
if err != nil {
return nil, nil, err
}
// 检查API Key状态
if !apiKey.IsActive() {
return nil, nil, errors.New("api key is not active")
}
// 获取用户信息
user, err := s.userRepo.GetByID(ctx, apiKey.UserID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil, ErrUserNotFound
}
return nil, nil, fmt.Errorf("get user: %w", err)
}
// 检查用户状态
if !user.IsActive() {
return nil, nil, ErrUserNotActive
}
return apiKey, user, nil
}
// IncrementUsage 增加API Key使用次数可选用于统计
func (s *ApiKeyService) IncrementUsage(ctx context.Context, keyID int64) error {
// 使用Redis计数器
if s.rdb != nil {
cacheKey := fmt.Sprintf("apikey:usage:%d:%s", keyID, timezone.Now().Format("2006-01-02"))
if err := s.rdb.Incr(ctx, cacheKey).Err(); err != nil {
return fmt.Errorf("increment usage: %w", err)
}
// 设置24小时过期
_ = s.rdb.Expire(ctx, cacheKey, 24*time.Hour)
}
return nil
}
// GetAvailableGroups 获取用户有权限绑定的分组列表
// 返回用户可以选择的分组:
// - 标准类型分组:公开的(非专属)或用户被明确允许的
// - 订阅类型分组:用户有有效订阅的
func (s *ApiKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([]model.Group, error) {
// 获取用户信息
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrUserNotFound
}
return nil, fmt.Errorf("get user: %w", err)
}
// 获取所有活跃分组
allGroups, err := s.groupRepo.ListActive(ctx)
if err != nil {
return nil, fmt.Errorf("list active groups: %w", err)
}
// 获取用户的所有有效订阅
activeSubscriptions, err := s.userSubRepo.ListActiveByUserID(ctx, userID)
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return nil, fmt.Errorf("list active subscriptions: %w", err)
}
// 构建订阅分组 ID 集合
subscribedGroupIDs := make(map[int64]bool)
for _, sub := range activeSubscriptions {
subscribedGroupIDs[sub.GroupID] = true
}
// 过滤出用户有权限的分组
availableGroups := make([]model.Group, 0)
for _, group := range allGroups {
if s.canUserBindGroupInternal(user, &group, subscribedGroupIDs) {
availableGroups = append(availableGroups, group)
}
}
return availableGroups, nil
}
// canUserBindGroupInternal 内部方法,检查用户是否可以绑定分组(使用预加载的订阅数据)
func (s *ApiKeyService) canUserBindGroupInternal(user *model.User, group *model.Group, subscribedGroupIDs map[int64]bool) bool {
// 订阅类型分组:需要有效订阅
if group.IsSubscriptionType() {
return subscribedGroupIDs[group.ID]
}
// 标准类型分组:使用原有逻辑
return user.CanBindGroup(group.ID, group.IsExclusive)
}

View File

@@ -0,0 +1,376 @@
package service
import (
"context"
"errors"
"fmt"
"log"
"sub2api/internal/config"
"sub2api/internal/model"
"sub2api/internal/repository"
"time"
"github.com/golang-jwt/jwt/v5"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
)
var (
ErrInvalidCredentials = errors.New("invalid email or password")
ErrUserNotActive = errors.New("user is not active")
ErrEmailExists = errors.New("email already exists")
ErrInvalidToken = errors.New("invalid token")
ErrTokenExpired = errors.New("token has expired")
ErrEmailVerifyRequired = errors.New("email verification is required")
ErrRegDisabled = errors.New("registration is currently disabled")
)
// JWTClaims JWT载荷数据
type JWTClaims struct {
UserID int64 `json:"user_id"`
Email string `json:"email"`
Role string `json:"role"`
jwt.RegisteredClaims
}
// AuthService 认证服务
type AuthService struct {
userRepo *repository.UserRepository
cfg *config.Config
settingService *SettingService
emailService *EmailService
turnstileService *TurnstileService
emailQueueService *EmailQueueService
}
// NewAuthService 创建认证服务实例
func NewAuthService(userRepo *repository.UserRepository, cfg *config.Config) *AuthService {
return &AuthService{
userRepo: userRepo,
cfg: cfg,
}
}
// SetSettingService 设置系统设置服务(用于检查注册开关和邮件验证)
func (s *AuthService) SetSettingService(settingService *SettingService) {
s.settingService = settingService
}
// SetEmailService 设置邮件服务(用于邮件验证)
func (s *AuthService) SetEmailService(emailService *EmailService) {
s.emailService = emailService
}
// SetTurnstileService 设置Turnstile服务用于验证码校验
func (s *AuthService) SetTurnstileService(turnstileService *TurnstileService) {
s.turnstileService = turnstileService
}
// SetEmailQueueService 设置邮件队列服务(用于异步发送邮件)
func (s *AuthService) SetEmailQueueService(emailQueueService *EmailQueueService) {
s.emailQueueService = emailQueueService
}
// Register 用户注册返回token和用户
func (s *AuthService) Register(ctx context.Context, email, password string) (string, *model.User, error) {
return s.RegisterWithVerification(ctx, email, password, "")
}
// RegisterWithVerification 用户注册支持邮件验证返回token和用户
func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode string) (string, *model.User, error) {
// 检查是否开放注册
if s.settingService != nil && !s.settingService.IsRegistrationEnabled(ctx) {
return "", nil, ErrRegDisabled
}
// 检查是否需要邮件验证
if s.settingService != nil && s.settingService.IsEmailVerifyEnabled(ctx) {
if verifyCode == "" {
return "", nil, ErrEmailVerifyRequired
}
// 验证邮箱验证码
if s.emailService != nil {
if err := s.emailService.VerifyCode(ctx, email, verifyCode); err != nil {
return "", nil, fmt.Errorf("verify code: %w", err)
}
}
}
// 检查邮箱是否已存在
existsEmail, err := s.userRepo.ExistsByEmail(ctx, email)
if err != nil {
return "", nil, fmt.Errorf("check email exists: %w", err)
}
if existsEmail {
return "", nil, ErrEmailExists
}
// 密码哈希
hashedPassword, err := s.HashPassword(password)
if err != nil {
return "", nil, fmt.Errorf("hash password: %w", err)
}
// 获取默认配置
defaultBalance := s.cfg.Default.UserBalance
defaultConcurrency := s.cfg.Default.UserConcurrency
if s.settingService != nil {
defaultBalance = s.settingService.GetDefaultBalance(ctx)
defaultConcurrency = s.settingService.GetDefaultConcurrency(ctx)
}
// 创建用户
user := &model.User{
Email: email,
PasswordHash: hashedPassword,
Role: model.RoleUser,
Balance: defaultBalance,
Concurrency: defaultConcurrency,
Status: model.StatusActive,
}
if err := s.userRepo.Create(ctx, user); err != nil {
return "", nil, fmt.Errorf("create user: %w", err)
}
// 生成token
token, err := s.GenerateToken(user)
if err != nil {
return "", nil, fmt.Errorf("generate token: %w", err)
}
return token, user, nil
}
// SendVerifyCodeResult 发送验证码返回结果
type SendVerifyCodeResult struct {
Countdown int `json:"countdown"` // 倒计时秒数
}
// SendVerifyCode 发送邮箱验证码(同步方式)
func (s *AuthService) SendVerifyCode(ctx context.Context, email string) error {
// 检查是否开放注册
if s.settingService != nil && !s.settingService.IsRegistrationEnabled(ctx) {
return ErrRegDisabled
}
// 检查邮箱是否已存在
existsEmail, err := s.userRepo.ExistsByEmail(ctx, email)
if err != nil {
return fmt.Errorf("check email exists: %w", err)
}
if existsEmail {
return ErrEmailExists
}
// 发送验证码
if s.emailService == nil {
return errors.New("email service not configured")
}
// 获取网站名称
siteName := "Sub2API"
if s.settingService != nil {
siteName = s.settingService.GetSiteName(ctx)
}
return s.emailService.SendVerifyCode(ctx, email, siteName)
}
// SendVerifyCodeAsync 异步发送邮箱验证码并返回倒计时
func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*SendVerifyCodeResult, error) {
log.Printf("[Auth] SendVerifyCodeAsync called for email: %s", email)
// 检查是否开放注册
if s.settingService != nil && !s.settingService.IsRegistrationEnabled(ctx) {
log.Println("[Auth] Registration is disabled")
return nil, ErrRegDisabled
}
// 检查邮箱是否已存在
existsEmail, err := s.userRepo.ExistsByEmail(ctx, email)
if err != nil {
log.Printf("[Auth] Error checking email exists: %v", err)
return nil, fmt.Errorf("check email exists: %w", err)
}
if existsEmail {
log.Printf("[Auth] Email already exists: %s", email)
return nil, ErrEmailExists
}
// 检查邮件队列服务是否配置
if s.emailQueueService == nil {
log.Println("[Auth] Email queue service not configured")
return nil, errors.New("email queue service not configured")
}
// 获取网站名称
siteName := "Sub2API"
if s.settingService != nil {
siteName = s.settingService.GetSiteName(ctx)
}
// 异步发送
log.Printf("[Auth] Enqueueing verify code for: %s", email)
if err := s.emailQueueService.EnqueueVerifyCode(email, siteName); err != nil {
log.Printf("[Auth] Failed to enqueue: %v", err)
return nil, fmt.Errorf("enqueue verify code: %w", err)
}
log.Printf("[Auth] Verify code enqueued successfully for: %s", email)
return &SendVerifyCodeResult{
Countdown: 60, // 60秒倒计时
}, nil
}
// VerifyTurnstile 验证Turnstile token
func (s *AuthService) VerifyTurnstile(ctx context.Context, token string, remoteIP string) error {
if s.turnstileService == nil {
return nil // 服务未配置则跳过验证
}
return s.turnstileService.VerifyToken(ctx, token, remoteIP)
}
// IsTurnstileEnabled 检查是否启用Turnstile验证
func (s *AuthService) IsTurnstileEnabled(ctx context.Context) bool {
if s.turnstileService == nil {
return false
}
return s.turnstileService.IsEnabled(ctx)
}
// IsRegistrationEnabled 检查是否开放注册
func (s *AuthService) IsRegistrationEnabled(ctx context.Context) bool {
if s.settingService == nil {
return true
}
return s.settingService.IsRegistrationEnabled(ctx)
}
// IsEmailVerifyEnabled 检查是否开启邮件验证
func (s *AuthService) IsEmailVerifyEnabled(ctx context.Context) bool {
if s.settingService == nil {
return false
}
return s.settingService.IsEmailVerifyEnabled(ctx)
}
// Login 用户登录返回JWT token
func (s *AuthService) Login(ctx context.Context, email, password string) (string, *model.User, error) {
// 查找用户
user, err := s.userRepo.GetByEmail(ctx, email)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return "", nil, ErrInvalidCredentials
}
return "", nil, fmt.Errorf("get user by email: %w", err)
}
// 验证密码
if !s.CheckPassword(password, user.PasswordHash) {
return "", nil, ErrInvalidCredentials
}
// 检查用户状态
if !user.IsActive() {
return "", nil, ErrUserNotActive
}
// 生成JWT token
token, err := s.GenerateToken(user)
if err != nil {
return "", nil, fmt.Errorf("generate token: %w", err)
}
return token, user, nil
}
// ValidateToken 验证JWT token并返回用户声明
func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) {
token, err := jwt.ParseWithClaims(tokenString, &JWTClaims{}, func(token *jwt.Token) (interface{}, error) {
// 验证签名方法
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
return []byte(s.cfg.JWT.Secret), nil
})
if err != nil {
if errors.Is(err, jwt.ErrTokenExpired) {
return nil, ErrTokenExpired
}
return nil, ErrInvalidToken
}
if claims, ok := token.Claims.(*JWTClaims); ok && token.Valid {
return claims, nil
}
return nil, ErrInvalidToken
}
// GenerateToken 生成JWT token
func (s *AuthService) GenerateToken(user *model.User) (string, error) {
now := time.Now()
expiresAt := now.Add(time.Duration(s.cfg.JWT.ExpireHour) * time.Hour)
claims := &JWTClaims{
UserID: user.ID,
Email: user.Email,
Role: user.Role,
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(expiresAt),
IssuedAt: jwt.NewNumericDate(now),
NotBefore: jwt.NewNumericDate(now),
},
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
tokenString, err := token.SignedString([]byte(s.cfg.JWT.Secret))
if err != nil {
return "", fmt.Errorf("sign token: %w", err)
}
return tokenString, nil
}
// HashPassword 使用bcrypt加密密码
func (s *AuthService) HashPassword(password string) (string, error) {
hashedBytes, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return "", err
}
return string(hashedBytes), nil
}
// CheckPassword 验证密码是否匹配
func (s *AuthService) CheckPassword(password, hashedPassword string) bool {
err := bcrypt.CompareHashAndPassword([]byte(hashedPassword), []byte(password))
return err == nil
}
// RefreshToken 刷新token
func (s *AuthService) RefreshToken(ctx context.Context, oldTokenString string) (string, error) {
// 验证旧token即使过期也允许用于刷新
claims, err := s.ValidateToken(oldTokenString)
if err != nil && !errors.Is(err, ErrTokenExpired) {
return "", err
}
// 获取最新的用户信息
user, err := s.userRepo.GetByID(ctx, claims.UserID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return "", ErrInvalidToken
}
return "", fmt.Errorf("get user: %w", err)
}
// 检查用户状态
if !user.IsActive() {
return "", ErrUserNotActive
}
// 生成新token
return s.GenerateToken(user)
}

View File

@@ -0,0 +1,422 @@
package service
import (
"context"
"errors"
"fmt"
"log"
"strconv"
"time"
"sub2api/internal/model"
"sub2api/internal/repository"
"github.com/redis/go-redis/v9"
)
// 缓存Key前缀和TTL
const (
billingBalanceKeyPrefix = "billing:balance:"
billingSubKeyPrefix = "billing:sub:"
billingCacheTTL = 5 * time.Minute
)
// 订阅缓存Hash字段
const (
subFieldStatus = "status"
subFieldExpiresAt = "expires_at"
subFieldDailyUsage = "daily_usage"
subFieldWeeklyUsage = "weekly_usage"
subFieldMonthlyUsage = "monthly_usage"
subFieldVersion = "version"
)
// 错误定义
// 注ErrInsufficientBalance在redeem_service.go中定义
// 注ErrDailyLimitExceeded/ErrWeeklyLimitExceeded/ErrMonthlyLimitExceeded在subscription_service.go中定义
var (
ErrSubscriptionInvalid = errors.New("subscription is invalid or expired")
)
// 预编译的Lua脚本
var (
// deductBalanceScript: 扣减余额缓存key不存在则忽略
deductBalanceScript = redis.NewScript(`
local current = redis.call('GET', KEYS[1])
if current == false then
return 0
end
local newVal = tonumber(current) - tonumber(ARGV[1])
redis.call('SET', KEYS[1], newVal)
redis.call('EXPIRE', KEYS[1], ARGV[2])
return 1
`)
// updateSubUsageScript: 更新订阅用量缓存key不存在则忽略
updateSubUsageScript = redis.NewScript(`
local exists = redis.call('EXISTS', KEYS[1])
if exists == 0 then
return 0
end
local cost = tonumber(ARGV[1])
redis.call('HINCRBYFLOAT', KEYS[1], 'daily_usage', cost)
redis.call('HINCRBYFLOAT', KEYS[1], 'weekly_usage', cost)
redis.call('HINCRBYFLOAT', KEYS[1], 'monthly_usage', cost)
redis.call('EXPIRE', KEYS[1], ARGV[2])
return 1
`)
)
// subscriptionCacheData 订阅缓存数据结构(内部使用)
type subscriptionCacheData struct {
Status string
ExpiresAt time.Time
DailyUsage float64
WeeklyUsage float64
MonthlyUsage float64
Version int64
}
// BillingCacheService 计费缓存服务
// 负责余额和订阅数据的缓存管理,提供高性能的计费资格检查
type BillingCacheService struct {
rdb *redis.Client
userRepo *repository.UserRepository
subRepo *repository.UserSubscriptionRepository
}
// NewBillingCacheService 创建计费缓存服务
func NewBillingCacheService(rdb *redis.Client, userRepo *repository.UserRepository, subRepo *repository.UserSubscriptionRepository) *BillingCacheService {
return &BillingCacheService{
rdb: rdb,
userRepo: userRepo,
subRepo: subRepo,
}
}
// ============================================
// 余额缓存方法
// ============================================
// GetUserBalance 获取用户余额(优先从缓存读取)
func (s *BillingCacheService) GetUserBalance(ctx context.Context, userID int64) (float64, error) {
if s.rdb == nil {
// Redis不可用直接查询数据库
return s.getUserBalanceFromDB(ctx, userID)
}
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
// 尝试从缓存读取
val, err := s.rdb.Get(ctx, key).Result()
if err == nil {
balance, parseErr := strconv.ParseFloat(val, 64)
if parseErr == nil {
return balance, nil
}
}
// 缓存未命中或解析错误,从数据库读取
balance, err := s.getUserBalanceFromDB(ctx, userID)
if err != nil {
return 0, err
}
// 异步建立缓存
go func() {
cacheCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
s.setBalanceCache(cacheCtx, userID, balance)
}()
return balance, nil
}
// getUserBalanceFromDB 从数据库获取用户余额
func (s *BillingCacheService) getUserBalanceFromDB(ctx context.Context, userID int64) (float64, error) {
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return 0, fmt.Errorf("get user balance: %w", err)
}
return user.Balance, nil
}
// setBalanceCache 设置余额缓存
func (s *BillingCacheService) setBalanceCache(ctx context.Context, userID int64, balance float64) {
if s.rdb == nil {
return
}
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
if err := s.rdb.Set(ctx, key, balance, billingCacheTTL).Err(); err != nil {
log.Printf("Warning: set balance cache failed for user %d: %v", userID, err)
}
}
// DeductBalanceCache 扣减余额缓存(异步调用,用于扣费后更新缓存)
func (s *BillingCacheService) DeductBalanceCache(ctx context.Context, userID int64, amount float64) error {
if s.rdb == nil {
return nil
}
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
// 使用预编译的Lua脚本原子性扣减如果key不存在则忽略
_, err := deductBalanceScript.Run(ctx, s.rdb, []string{key}, amount, int(billingCacheTTL.Seconds())).Result()
if err != nil && !errors.Is(err, redis.Nil) {
log.Printf("Warning: deduct balance cache failed for user %d: %v", userID, err)
}
return nil
}
// InvalidateUserBalance 失效用户余额缓存
func (s *BillingCacheService) InvalidateUserBalance(ctx context.Context, userID int64) error {
if s.rdb == nil {
return nil
}
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
if err := s.rdb.Del(ctx, key).Err(); err != nil {
log.Printf("Warning: invalidate balance cache failed for user %d: %v", userID, err)
return err
}
return nil
}
// ============================================
// 订阅缓存方法
// ============================================
// GetSubscriptionStatus 获取订阅状态(优先从缓存读取)
func (s *BillingCacheService) GetSubscriptionStatus(ctx context.Context, userID, groupID int64) (*subscriptionCacheData, error) {
if s.rdb == nil {
return s.getSubscriptionFromDB(ctx, userID, groupID)
}
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
// 尝试从缓存读取
result, err := s.rdb.HGetAll(ctx, key).Result()
if err == nil && len(result) > 0 {
data, parseErr := s.parseSubscriptionCache(result)
if parseErr == nil {
return data, nil
}
}
// 缓存未命中,从数据库读取
data, err := s.getSubscriptionFromDB(ctx, userID, groupID)
if err != nil {
return nil, err
}
// 异步建立缓存
go func() {
cacheCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
s.setSubscriptionCache(cacheCtx, userID, groupID, data)
}()
return data, nil
}
// getSubscriptionFromDB 从数据库获取订阅数据
func (s *BillingCacheService) getSubscriptionFromDB(ctx context.Context, userID, groupID int64) (*subscriptionCacheData, error) {
sub, err := s.subRepo.GetActiveByUserIDAndGroupID(ctx, userID, groupID)
if err != nil {
return nil, fmt.Errorf("get subscription: %w", err)
}
return &subscriptionCacheData{
Status: sub.Status,
ExpiresAt: sub.ExpiresAt,
DailyUsage: sub.DailyUsageUSD,
WeeklyUsage: sub.WeeklyUsageUSD,
MonthlyUsage: sub.MonthlyUsageUSD,
Version: sub.UpdatedAt.Unix(),
}, nil
}
// parseSubscriptionCache 解析订阅缓存数据
func (s *BillingCacheService) parseSubscriptionCache(data map[string]string) (*subscriptionCacheData, error) {
result := &subscriptionCacheData{}
result.Status = data[subFieldStatus]
if result.Status == "" {
return nil, errors.New("invalid cache: missing status")
}
if expiresStr, ok := data[subFieldExpiresAt]; ok {
expiresAt, err := strconv.ParseInt(expiresStr, 10, 64)
if err == nil {
result.ExpiresAt = time.Unix(expiresAt, 0)
}
}
if dailyStr, ok := data[subFieldDailyUsage]; ok {
result.DailyUsage, _ = strconv.ParseFloat(dailyStr, 64)
}
if weeklyStr, ok := data[subFieldWeeklyUsage]; ok {
result.WeeklyUsage, _ = strconv.ParseFloat(weeklyStr, 64)
}
if monthlyStr, ok := data[subFieldMonthlyUsage]; ok {
result.MonthlyUsage, _ = strconv.ParseFloat(monthlyStr, 64)
}
if versionStr, ok := data[subFieldVersion]; ok {
result.Version, _ = strconv.ParseInt(versionStr, 10, 64)
}
return result, nil
}
// setSubscriptionCache 设置订阅缓存
func (s *BillingCacheService) setSubscriptionCache(ctx context.Context, userID, groupID int64, data *subscriptionCacheData) {
if s.rdb == nil || data == nil {
return
}
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
fields := map[string]interface{}{
subFieldStatus: data.Status,
subFieldExpiresAt: data.ExpiresAt.Unix(),
subFieldDailyUsage: data.DailyUsage,
subFieldWeeklyUsage: data.WeeklyUsage,
subFieldMonthlyUsage: data.MonthlyUsage,
subFieldVersion: data.Version,
}
pipe := s.rdb.Pipeline()
pipe.HSet(ctx, key, fields)
pipe.Expire(ctx, key, billingCacheTTL)
if _, err := pipe.Exec(ctx); err != nil {
log.Printf("Warning: set subscription cache failed for user %d group %d: %v", userID, groupID, err)
}
}
// UpdateSubscriptionUsage 更新订阅用量缓存(异步调用,用于扣费后更新缓存)
func (s *BillingCacheService) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, costUSD float64) error {
if s.rdb == nil {
return nil
}
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
// 使用预编译的Lua脚本原子性增加用量如果key不存在则忽略
_, err := updateSubUsageScript.Run(ctx, s.rdb, []string{key}, costUSD, int(billingCacheTTL.Seconds())).Result()
if err != nil && !errors.Is(err, redis.Nil) {
log.Printf("Warning: update subscription usage cache failed for user %d group %d: %v", userID, groupID, err)
}
return nil
}
// InvalidateSubscription 失效指定订阅缓存
func (s *BillingCacheService) InvalidateSubscription(ctx context.Context, userID, groupID int64) error {
if s.rdb == nil {
return nil
}
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
if err := s.rdb.Del(ctx, key).Err(); err != nil {
log.Printf("Warning: invalidate subscription cache failed for user %d group %d: %v", userID, groupID, err)
return err
}
return nil
}
// ============================================
// 统一检查方法
// ============================================
// CheckBillingEligibility 检查用户是否有资格发起请求
// 余额模式:检查缓存余额 > 0
// 订阅模式检查缓存用量未超过限额Group限额从参数传入
func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user *model.User, apiKey *model.ApiKey, group *model.Group, subscription *model.UserSubscription) error {
// 判断计费模式
isSubscriptionMode := group != nil && group.IsSubscriptionType() && subscription != nil
if isSubscriptionMode {
return s.checkSubscriptionEligibility(ctx, user.ID, group, subscription)
}
return s.checkBalanceEligibility(ctx, user.ID)
}
// checkBalanceEligibility 检查余额模式资格
func (s *BillingCacheService) checkBalanceEligibility(ctx context.Context, userID int64) error {
balance, err := s.GetUserBalance(ctx, userID)
if err != nil {
// 缓存/数据库错误,允许通过(降级处理)
log.Printf("Warning: get user balance failed, allowing request: %v", err)
return nil
}
if balance <= 0 {
return ErrInsufficientBalance
}
return nil
}
// checkSubscriptionEligibility 检查订阅模式资格
func (s *BillingCacheService) checkSubscriptionEligibility(ctx context.Context, userID int64, group *model.Group, subscription *model.UserSubscription) error {
// 获取订阅缓存数据
subData, err := s.GetSubscriptionStatus(ctx, userID, group.ID)
if err != nil {
// 缓存/数据库错误降级使用传入的subscription进行检查
log.Printf("Warning: get subscription cache failed, using fallback: %v", err)
return s.checkSubscriptionLimitsFallback(subscription, group)
}
// 检查订阅状态
if subData.Status != model.SubscriptionStatusActive {
return ErrSubscriptionInvalid
}
// 检查是否过期
if time.Now().After(subData.ExpiresAt) {
return ErrSubscriptionInvalid
}
// 检查限额使用传入的Group限额配置
if group.HasDailyLimit() && subData.DailyUsage >= *group.DailyLimitUSD {
return ErrDailyLimitExceeded
}
if group.HasWeeklyLimit() && subData.WeeklyUsage >= *group.WeeklyLimitUSD {
return ErrWeeklyLimitExceeded
}
if group.HasMonthlyLimit() && subData.MonthlyUsage >= *group.MonthlyLimitUSD {
return ErrMonthlyLimitExceeded
}
return nil
}
// checkSubscriptionLimitsFallback 降级检查订阅限额
func (s *BillingCacheService) checkSubscriptionLimitsFallback(subscription *model.UserSubscription, group *model.Group) error {
if subscription == nil {
return ErrSubscriptionInvalid
}
if !subscription.IsActive() {
return ErrSubscriptionInvalid
}
if !subscription.CheckDailyLimit(group, 0) {
return ErrDailyLimitExceeded
}
if !subscription.CheckWeeklyLimit(group, 0) {
return ErrWeeklyLimitExceeded
}
if !subscription.CheckMonthlyLimit(group, 0) {
return ErrMonthlyLimitExceeded
}
return nil
}

View File

@@ -0,0 +1,279 @@
package service
import (
"fmt"
"log"
"strings"
"sub2api/internal/config"
)
// ModelPricing 模型价格配置per-token价格与LiteLLM格式一致
type ModelPricing struct {
InputPricePerToken float64 // 每token输入价格 (USD)
OutputPricePerToken float64 // 每token输出价格 (USD)
CacheCreationPricePerToken float64 // 缓存创建每token价格 (USD)
CacheReadPricePerToken float64 // 缓存读取每token价格 (USD)
CacheCreation5mPrice float64 // 5分钟缓存创建价格每百万token- 仅用于硬编码回退
CacheCreation1hPrice float64 // 1小时缓存创建价格每百万token- 仅用于硬编码回退
SupportsCacheBreakdown bool // 是否支持详细的缓存分类
}
// UsageTokens 使用的token数量
type UsageTokens struct {
InputTokens int
OutputTokens int
CacheCreationTokens int
CacheReadTokens int
CacheCreation5mTokens int
CacheCreation1hTokens int
}
// CostBreakdown 费用明细
type CostBreakdown struct {
InputCost float64
OutputCost float64
CacheCreationCost float64
CacheReadCost float64
TotalCost float64
ActualCost float64 // 应用倍率后的实际费用
}
// BillingService 计费服务
type BillingService struct {
cfg *config.Config
pricingService *PricingService
fallbackPrices map[string]*ModelPricing // 硬编码回退价格
}
// NewBillingService 创建计费服务实例
func NewBillingService(cfg *config.Config, pricingService *PricingService) *BillingService {
s := &BillingService{
cfg: cfg,
pricingService: pricingService,
fallbackPrices: make(map[string]*ModelPricing),
}
// 初始化硬编码回退价格(当动态价格不可用时使用)
s.initFallbackPricing()
return s
}
// initFallbackPricing 初始化硬编码回退价格(当动态价格不可用时使用)
// 价格单位USD per token与LiteLLM格式一致
func (s *BillingService) initFallbackPricing() {
// Claude 4.5 Opus
s.fallbackPrices["claude-opus-4.5"] = &ModelPricing{
InputPricePerToken: 5e-6, // $5 per MTok
OutputPricePerToken: 25e-6, // $25 per MTok
CacheCreationPricePerToken: 6.25e-6, // $6.25 per MTok
CacheReadPricePerToken: 0.5e-6, // $0.50 per MTok
SupportsCacheBreakdown: false,
}
// Claude 4 Sonnet
s.fallbackPrices["claude-sonnet-4"] = &ModelPricing{
InputPricePerToken: 3e-6, // $3 per MTok
OutputPricePerToken: 15e-6, // $15 per MTok
CacheCreationPricePerToken: 3.75e-6, // $3.75 per MTok
CacheReadPricePerToken: 0.3e-6, // $0.30 per MTok
SupportsCacheBreakdown: false,
}
// Claude 3.5 Sonnet
s.fallbackPrices["claude-3-5-sonnet"] = &ModelPricing{
InputPricePerToken: 3e-6, // $3 per MTok
OutputPricePerToken: 15e-6, // $15 per MTok
CacheCreationPricePerToken: 3.75e-6, // $3.75 per MTok
CacheReadPricePerToken: 0.3e-6, // $0.30 per MTok
SupportsCacheBreakdown: false,
}
// Claude 3.5 Haiku
s.fallbackPrices["claude-3-5-haiku"] = &ModelPricing{
InputPricePerToken: 1e-6, // $1 per MTok
OutputPricePerToken: 5e-6, // $5 per MTok
CacheCreationPricePerToken: 1.25e-6, // $1.25 per MTok
CacheReadPricePerToken: 0.1e-6, // $0.10 per MTok
SupportsCacheBreakdown: false,
}
// Claude 3 Opus
s.fallbackPrices["claude-3-opus"] = &ModelPricing{
InputPricePerToken: 15e-6, // $15 per MTok
OutputPricePerToken: 75e-6, // $75 per MTok
CacheCreationPricePerToken: 18.75e-6, // $18.75 per MTok
CacheReadPricePerToken: 1.5e-6, // $1.50 per MTok
SupportsCacheBreakdown: false,
}
// Claude 3 Haiku
s.fallbackPrices["claude-3-haiku"] = &ModelPricing{
InputPricePerToken: 0.25e-6, // $0.25 per MTok
OutputPricePerToken: 1.25e-6, // $1.25 per MTok
CacheCreationPricePerToken: 0.3e-6, // $0.30 per MTok
CacheReadPricePerToken: 0.03e-6, // $0.03 per MTok
SupportsCacheBreakdown: false,
}
}
// getFallbackPricing 根据模型系列获取回退价格
func (s *BillingService) getFallbackPricing(model string) *ModelPricing {
modelLower := strings.ToLower(model)
// 按模型系列匹配
if strings.Contains(modelLower, "opus") {
if strings.Contains(modelLower, "4.5") || strings.Contains(modelLower, "4-5") {
return s.fallbackPrices["claude-opus-4.5"]
}
return s.fallbackPrices["claude-3-opus"]
}
if strings.Contains(modelLower, "sonnet") {
if strings.Contains(modelLower, "4") && !strings.Contains(modelLower, "3") {
return s.fallbackPrices["claude-sonnet-4"]
}
return s.fallbackPrices["claude-3-5-sonnet"]
}
if strings.Contains(modelLower, "haiku") {
if strings.Contains(modelLower, "3-5") || strings.Contains(modelLower, "3.5") {
return s.fallbackPrices["claude-3-5-haiku"]
}
return s.fallbackPrices["claude-3-haiku"]
}
// 默认使用Sonnet价格
return s.fallbackPrices["claude-sonnet-4"]
}
// GetModelPricing 获取模型价格配置
func (s *BillingService) GetModelPricing(model string) (*ModelPricing, error) {
// 标准化模型名称(转小写)
model = strings.ToLower(model)
// 1. 优先从动态价格服务获取
if s.pricingService != nil {
litellmPricing := s.pricingService.GetModelPricing(model)
if litellmPricing != nil {
return &ModelPricing{
InputPricePerToken: litellmPricing.InputCostPerToken,
OutputPricePerToken: litellmPricing.OutputCostPerToken,
CacheCreationPricePerToken: litellmPricing.CacheCreationInputTokenCost,
CacheReadPricePerToken: litellmPricing.CacheReadInputTokenCost,
SupportsCacheBreakdown: false,
}, nil
}
}
// 2. 使用硬编码回退价格
fallback := s.getFallbackPricing(model)
if fallback != nil {
log.Printf("[Billing] Using fallback pricing for model: %s", model)
return fallback, nil
}
return nil, fmt.Errorf("pricing not found for model: %s", model)
}
// CalculateCost 计算使用费用
func (s *BillingService) CalculateCost(model string, tokens UsageTokens, rateMultiplier float64) (*CostBreakdown, error) {
pricing, err := s.GetModelPricing(model)
if err != nil {
return nil, err
}
breakdown := &CostBreakdown{}
// 计算输入token费用使用per-token价格
breakdown.InputCost = float64(tokens.InputTokens) * pricing.InputPricePerToken
// 计算输出token费用
breakdown.OutputCost = float64(tokens.OutputTokens) * pricing.OutputPricePerToken
// 计算缓存费用
if pricing.SupportsCacheBreakdown && (pricing.CacheCreation5mPrice > 0 || pricing.CacheCreation1hPrice > 0) {
// 支持详细缓存分类的模型5分钟/1小时缓存
breakdown.CacheCreationCost = float64(tokens.CacheCreation5mTokens)/1_000_000*pricing.CacheCreation5mPrice +
float64(tokens.CacheCreation1hTokens)/1_000_000*pricing.CacheCreation1hPrice
} else {
// 标准缓存创建价格per-token
breakdown.CacheCreationCost = float64(tokens.CacheCreationTokens) * pricing.CacheCreationPricePerToken
}
breakdown.CacheReadCost = float64(tokens.CacheReadTokens) * pricing.CacheReadPricePerToken
// 计算总费用
breakdown.TotalCost = breakdown.InputCost + breakdown.OutputCost +
breakdown.CacheCreationCost + breakdown.CacheReadCost
// 应用倍率计算实际费用
if rateMultiplier <= 0 {
rateMultiplier = 1.0
}
breakdown.ActualCost = breakdown.TotalCost * rateMultiplier
return breakdown, nil
}
// CalculateCostWithConfig 使用配置中的默认倍率计算费用
func (s *BillingService) CalculateCostWithConfig(model string, tokens UsageTokens) (*CostBreakdown, error) {
multiplier := s.cfg.Default.RateMultiplier
if multiplier <= 0 {
multiplier = 1.0
}
return s.CalculateCost(model, tokens, multiplier)
}
// ListSupportedModels 列出所有支持的模型现在总是返回true因为有模糊匹配
func (s *BillingService) ListSupportedModels() []string {
models := make([]string, 0)
// 返回回退价格支持的模型系列
for model := range s.fallbackPrices {
models = append(models, model)
}
return models
}
// IsModelSupported 检查模型是否支持现在总是返回true因为有模糊匹配回退
func (s *BillingService) IsModelSupported(model string) bool {
// 所有Claude模型都有回退价格支持
modelLower := strings.ToLower(model)
return strings.Contains(modelLower, "claude") ||
strings.Contains(modelLower, "opus") ||
strings.Contains(modelLower, "sonnet") ||
strings.Contains(modelLower, "haiku")
}
// GetEstimatedCost 估算费用(用于前端展示)
func (s *BillingService) GetEstimatedCost(model string, estimatedInputTokens, estimatedOutputTokens int) (float64, error) {
tokens := UsageTokens{
InputTokens: estimatedInputTokens,
OutputTokens: estimatedOutputTokens,
}
breakdown, err := s.CalculateCostWithConfig(model, tokens)
if err != nil {
return 0, err
}
return breakdown.ActualCost, nil
}
// GetPricingServiceStatus 获取价格服务状态
func (s *BillingService) GetPricingServiceStatus() map[string]interface{} {
if s.pricingService != nil {
return s.pricingService.GetStatus()
}
return map[string]interface{}{
"model_count": len(s.fallbackPrices),
"last_updated": "using fallback",
"local_hash": "N/A",
}
}
// ForceUpdatePricing 强制更新价格数据
func (s *BillingService) ForceUpdatePricing() error {
if s.pricingService != nil {
return s.pricingService.ForceUpdate()
}
return fmt.Errorf("pricing service not initialized")
}

View File

@@ -0,0 +1,251 @@
package service
import (
"context"
"fmt"
"log"
"time"
"github.com/redis/go-redis/v9"
)
const (
// Redis key prefixes
accountConcurrencyKey = "concurrency:account:"
userConcurrencyKey = "concurrency:user:"
userWaitCountKey = "concurrency:wait:"
// TTL for concurrency keys (auto-release safety net)
concurrencyKeyTTL = 10 * time.Minute
// Wait polling interval
waitPollInterval = 100 * time.Millisecond
// Default max wait time
defaultMaxWait = 60 * time.Second
// Default extra wait slots beyond concurrency limit
defaultExtraWaitSlots = 20
)
// Pre-compiled Lua scripts for better performance
var (
// acquireScript: increment counter if below max, return 1 if successful
acquireScript = redis.NewScript(`
local current = redis.call('GET', KEYS[1])
if current == false then
current = 0
else
current = tonumber(current)
end
if current < tonumber(ARGV[1]) then
redis.call('INCR', KEYS[1])
redis.call('EXPIRE', KEYS[1], ARGV[2])
return 1
end
return 0
`)
// releaseScript: decrement counter, but don't go below 0
releaseScript = redis.NewScript(`
local current = redis.call('GET', KEYS[1])
if current ~= false and tonumber(current) > 0 then
redis.call('DECR', KEYS[1])
end
return 1
`)
// incrementWaitScript: increment wait counter if below max, return 1 if successful
incrementWaitScript = redis.NewScript(`
local waitKey = KEYS[1]
local maxWait = tonumber(ARGV[1])
local ttl = tonumber(ARGV[2])
local current = redis.call('GET', waitKey)
if current == false then
current = 0
else
current = tonumber(current)
end
if current >= maxWait then
return 0
end
redis.call('INCR', waitKey)
redis.call('EXPIRE', waitKey, ttl)
return 1
`)
// decrementWaitScript: decrement wait counter, but don't go below 0
decrementWaitScript = redis.NewScript(`
local current = redis.call('GET', KEYS[1])
if current ~= false and tonumber(current) > 0 then
redis.call('DECR', KEYS[1])
end
return 1
`)
)
// ConcurrencyService manages concurrent request limiting for accounts and users
type ConcurrencyService struct {
rdb *redis.Client
}
// NewConcurrencyService creates a new ConcurrencyService
func NewConcurrencyService(rdb *redis.Client) *ConcurrencyService {
return &ConcurrencyService{rdb: rdb}
}
// AcquireResult represents the result of acquiring a concurrency slot
type AcquireResult struct {
Acquired bool
ReleaseFunc func() // Must be called when done (typically via defer)
}
// AcquireAccountSlot attempts to acquire a concurrency slot for an account.
// If the account is at max concurrency, it waits until a slot is available or timeout.
// Returns a release function that MUST be called when the request completes.
func (s *ConcurrencyService) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (*AcquireResult, error) {
key := fmt.Sprintf("%s%d", accountConcurrencyKey, accountID)
return s.acquireSlot(ctx, key, maxConcurrency)
}
// AcquireUserSlot attempts to acquire a concurrency slot for a user.
// If the user is at max concurrency, it waits until a slot is available or timeout.
// Returns a release function that MUST be called when the request completes.
func (s *ConcurrencyService) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int) (*AcquireResult, error) {
key := fmt.Sprintf("%s%d", userConcurrencyKey, userID)
return s.acquireSlot(ctx, key, maxConcurrency)
}
// acquireSlot is the core implementation for acquiring a concurrency slot
func (s *ConcurrencyService) acquireSlot(ctx context.Context, key string, maxConcurrency int) (*AcquireResult, error) {
// If maxConcurrency is 0 or negative, no limit
if maxConcurrency <= 0 {
return &AcquireResult{
Acquired: true,
ReleaseFunc: func() {}, // no-op
}, nil
}
// Try to acquire immediately
acquired, err := s.tryAcquire(ctx, key, maxConcurrency)
if err != nil {
return nil, err
}
if acquired {
return &AcquireResult{
Acquired: true,
ReleaseFunc: s.makeReleaseFunc(key),
}, nil
}
// Not acquired, return with Acquired=false
// The caller (gateway handler) will handle waiting with ping support
return &AcquireResult{
Acquired: false,
ReleaseFunc: nil,
}, nil
}
// tryAcquire attempts to increment the counter if below max
// Uses pre-compiled Lua script for atomicity and performance
func (s *ConcurrencyService) tryAcquire(ctx context.Context, key string, maxConcurrency int) (bool, error) {
result, err := acquireScript.Run(ctx, s.rdb, []string{key}, maxConcurrency, int(concurrencyKeyTTL.Seconds())).Int()
if err != nil {
return false, fmt.Errorf("acquire slot failed: %w", err)
}
return result == 1, nil
}
// makeReleaseFunc creates a function to release a concurrency slot
func (s *ConcurrencyService) makeReleaseFunc(key string) func() {
return func() {
// Use background context to ensure release even if original context is cancelled
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := releaseScript.Run(ctx, s.rdb, []string{key}).Err(); err != nil {
// Log error but don't panic - TTL will eventually clean up
log.Printf("Warning: failed to release concurrency slot for %s: %v", key, err)
}
}
}
// GetCurrentCount returns the current concurrency count for debugging/monitoring
func (s *ConcurrencyService) GetCurrentCount(ctx context.Context, key string) (int, error) {
val, err := s.rdb.Get(ctx, key).Int()
if err == redis.Nil {
return 0, nil
}
if err != nil {
return 0, err
}
return val, nil
}
// GetAccountCurrentCount returns current concurrency count for an account
func (s *ConcurrencyService) GetAccountCurrentCount(ctx context.Context, accountID int64) (int, error) {
key := fmt.Sprintf("%s%d", accountConcurrencyKey, accountID)
return s.GetCurrentCount(ctx, key)
}
// GetUserCurrentCount returns current concurrency count for a user
func (s *ConcurrencyService) GetUserCurrentCount(ctx context.Context, userID int64) (int, error) {
key := fmt.Sprintf("%s%d", userConcurrencyKey, userID)
return s.GetCurrentCount(ctx, key)
}
// ============================================
// Wait Queue Count Methods
// ============================================
// IncrementWaitCount attempts to increment the wait queue counter for a user.
// Returns true if successful, false if the wait queue is full.
// maxWait should be user.Concurrency + defaultExtraWaitSlots
func (s *ConcurrencyService) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
if s.rdb == nil {
// Redis not available, allow request
return true, nil
}
key := fmt.Sprintf("%s%d", userWaitCountKey, userID)
result, err := incrementWaitScript.Run(ctx, s.rdb, []string{key}, maxWait, int(concurrencyKeyTTL.Seconds())).Int()
if err != nil {
// On error, allow the request to proceed (fail open)
log.Printf("Warning: increment wait count failed for user %d: %v", userID, err)
return true, nil
}
return result == 1, nil
}
// DecrementWaitCount decrements the wait queue counter for a user.
// Should be called when a request completes or exits the wait queue.
func (s *ConcurrencyService) DecrementWaitCount(ctx context.Context, userID int64) {
if s.rdb == nil {
return
}
key := fmt.Sprintf("%s%d", userWaitCountKey, userID)
// Use background context to ensure decrement even if original context is cancelled
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := decrementWaitScript.Run(bgCtx, s.rdb, []string{key}).Err(); err != nil {
log.Printf("Warning: decrement wait count failed for user %d: %v", userID, err)
}
}
// GetUserWaitCount returns current wait queue count for a user
func (s *ConcurrencyService) GetUserWaitCount(ctx context.Context, userID int64) (int, error) {
key := fmt.Sprintf("%s%d", userWaitCountKey, userID)
return s.GetCurrentCount(ctx, key)
}
// CalculateMaxWait calculates the maximum wait queue size for a user
// maxWait = userConcurrency + defaultExtraWaitSlots
func CalculateMaxWait(userConcurrency int) int {
if userConcurrency <= 0 {
userConcurrency = 1
}
return userConcurrency + defaultExtraWaitSlots
}

View File

@@ -0,0 +1,109 @@
package service
import (
"context"
"fmt"
"log"
"sync"
"time"
)
// EmailTask 邮件发送任务
type EmailTask struct {
Email string
SiteName string
TaskType string // "verify_code"
}
// EmailQueueService 异步邮件队列服务
type EmailQueueService struct {
emailService *EmailService
taskChan chan EmailTask
wg sync.WaitGroup
stopChan chan struct{}
workers int
}
// NewEmailQueueService 创建邮件队列服务
func NewEmailQueueService(emailService *EmailService, workers int) *EmailQueueService {
if workers <= 0 {
workers = 3 // 默认3个工作协程
}
service := &EmailQueueService{
emailService: emailService,
taskChan: make(chan EmailTask, 100), // 缓冲100个任务
stopChan: make(chan struct{}),
workers: workers,
}
// 启动工作协程
service.start()
return service
}
// start 启动工作协程
func (s *EmailQueueService) start() {
for i := 0; i < s.workers; i++ {
s.wg.Add(1)
go s.worker(i)
}
log.Printf("[EmailQueue] Started %d workers", s.workers)
}
// worker 工作协程
func (s *EmailQueueService) worker(id int) {
defer s.wg.Done()
for {
select {
case task := <-s.taskChan:
s.processTask(id, task)
case <-s.stopChan:
log.Printf("[EmailQueue] Worker %d stopping", id)
return
}
}
}
// processTask 处理任务
func (s *EmailQueueService) processTask(workerID int, task EmailTask) {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
switch task.TaskType {
case "verify_code":
if err := s.emailService.SendVerifyCode(ctx, task.Email, task.SiteName); err != nil {
log.Printf("[EmailQueue] Worker %d failed to send verify code to %s: %v", workerID, task.Email, err)
} else {
log.Printf("[EmailQueue] Worker %d sent verify code to %s", workerID, task.Email)
}
default:
log.Printf("[EmailQueue] Worker %d unknown task type: %s", workerID, task.TaskType)
}
}
// EnqueueVerifyCode 将验证码发送任务加入队列
func (s *EmailQueueService) EnqueueVerifyCode(email, siteName string) error {
task := EmailTask{
Email: email,
SiteName: siteName,
TaskType: "verify_code",
}
select {
case s.taskChan <- task:
log.Printf("[EmailQueue] Enqueued verify code task for %s", email)
return nil
default:
return fmt.Errorf("email queue is full")
}
}
// Stop 停止队列服务
func (s *EmailQueueService) Stop() {
close(s.stopChan)
s.wg.Wait()
log.Println("[EmailQueue] All workers stopped")
}

View File

@@ -0,0 +1,372 @@
package service
import (
"context"
"crypto/rand"
"crypto/tls"
"encoding/json"
"errors"
"fmt"
"math/big"
"net/smtp"
"strconv"
"sub2api/internal/model"
"sub2api/internal/repository"
"time"
"github.com/redis/go-redis/v9"
)
var (
ErrEmailNotConfigured = errors.New("email service not configured")
ErrInvalidVerifyCode = errors.New("invalid or expired verification code")
ErrVerifyCodeTooFrequent = errors.New("please wait before requesting a new code")
ErrVerifyCodeMaxAttempts = errors.New("too many failed attempts, please request a new code")
)
const (
verifyCodeKeyPrefix = "email_verify:"
verifyCodeTTL = 15 * time.Minute
verifyCodeCooldown = 1 * time.Minute
maxVerifyCodeAttempts = 5
)
// verifyCodeData Redis 中存储的验证码数据
type verifyCodeData struct {
Code string `json:"code"`
Attempts int `json:"attempts"`
CreatedAt time.Time `json:"created_at"`
}
// SmtpConfig SMTP配置
type SmtpConfig struct {
Host string
Port int
Username string
Password string
From string
FromName string
UseTLS bool
}
// EmailService 邮件服务
type EmailService struct {
settingRepo *repository.SettingRepository
rdb *redis.Client
}
// NewEmailService 创建邮件服务实例
func NewEmailService(settingRepo *repository.SettingRepository, rdb *redis.Client) *EmailService {
return &EmailService{
settingRepo: settingRepo,
rdb: rdb,
}
}
// GetSmtpConfig 从数据库获取SMTP配置
func (s *EmailService) GetSmtpConfig(ctx context.Context) (*SmtpConfig, error) {
keys := []string{
model.SettingKeySmtpHost,
model.SettingKeySmtpPort,
model.SettingKeySmtpUsername,
model.SettingKeySmtpPassword,
model.SettingKeySmtpFrom,
model.SettingKeySmtpFromName,
model.SettingKeySmtpUseTLS,
}
settings, err := s.settingRepo.GetMultiple(ctx, keys)
if err != nil {
return nil, fmt.Errorf("get smtp settings: %w", err)
}
host := settings[model.SettingKeySmtpHost]
if host == "" {
return nil, ErrEmailNotConfigured
}
port := 587 // 默认端口
if portStr := settings[model.SettingKeySmtpPort]; portStr != "" {
if p, err := strconv.Atoi(portStr); err == nil {
port = p
}
}
useTLS := settings[model.SettingKeySmtpUseTLS] == "true"
return &SmtpConfig{
Host: host,
Port: port,
Username: settings[model.SettingKeySmtpUsername],
Password: settings[model.SettingKeySmtpPassword],
From: settings[model.SettingKeySmtpFrom],
FromName: settings[model.SettingKeySmtpFromName],
UseTLS: useTLS,
}, nil
}
// SendEmail 发送邮件(使用数据库中保存的配置)
func (s *EmailService) SendEmail(ctx context.Context, to, subject, body string) error {
config, err := s.GetSmtpConfig(ctx)
if err != nil {
return err
}
return s.SendEmailWithConfig(config, to, subject, body)
}
// SendEmailWithConfig 使用指定配置发送邮件
func (s *EmailService) SendEmailWithConfig(config *SmtpConfig, to, subject, body string) error {
from := config.From
if config.FromName != "" {
from = fmt.Sprintf("%s <%s>", config.FromName, config.From)
}
msg := fmt.Sprintf("From: %s\r\nTo: %s\r\nSubject: %s\r\nMIME-Version: 1.0\r\nContent-Type: text/html; charset=UTF-8\r\n\r\n%s",
from, to, subject, body)
addr := fmt.Sprintf("%s:%d", config.Host, config.Port)
auth := smtp.PlainAuth("", config.Username, config.Password, config.Host)
if config.UseTLS {
return s.sendMailTLS(addr, auth, config.From, to, []byte(msg), config.Host)
}
return smtp.SendMail(addr, auth, config.From, []string{to}, []byte(msg))
}
// sendMailTLS 使用TLS发送邮件
func (s *EmailService) sendMailTLS(addr string, auth smtp.Auth, from, to string, msg []byte, host string) error {
tlsConfig := &tls.Config{
ServerName: host,
}
conn, err := tls.Dial("tcp", addr, tlsConfig)
if err != nil {
return fmt.Errorf("tls dial: %w", err)
}
defer conn.Close()
client, err := smtp.NewClient(conn, host)
if err != nil {
return fmt.Errorf("new smtp client: %w", err)
}
defer client.Close()
if err = client.Auth(auth); err != nil {
return fmt.Errorf("smtp auth: %w", err)
}
if err = client.Mail(from); err != nil {
return fmt.Errorf("smtp mail: %w", err)
}
if err = client.Rcpt(to); err != nil {
return fmt.Errorf("smtp rcpt: %w", err)
}
w, err := client.Data()
if err != nil {
return fmt.Errorf("smtp data: %w", err)
}
_, err = w.Write(msg)
if err != nil {
return fmt.Errorf("write msg: %w", err)
}
err = w.Close()
if err != nil {
return fmt.Errorf("close writer: %w", err)
}
// Email is sent successfully after w.Close(), ignore Quit errors
// Some SMTP servers return non-standard responses on QUIT
_ = client.Quit()
return nil
}
// GenerateVerifyCode 生成6位数字验证码
func (s *EmailService) GenerateVerifyCode() (string, error) {
const digits = "0123456789"
code := make([]byte, 6)
for i := range code {
num, err := rand.Int(rand.Reader, big.NewInt(int64(len(digits))))
if err != nil {
return "", err
}
code[i] = digits[num.Int64()]
}
return string(code), nil
}
// SendVerifyCode 发送验证码邮件
func (s *EmailService) SendVerifyCode(ctx context.Context, email, siteName string) error {
key := verifyCodeKeyPrefix + email
// 检查是否在冷却期内
existing, err := s.getVerifyCodeData(ctx, key)
if err == nil && existing != nil {
if time.Since(existing.CreatedAt) < verifyCodeCooldown {
return ErrVerifyCodeTooFrequent
}
}
// 生成验证码
code, err := s.GenerateVerifyCode()
if err != nil {
return fmt.Errorf("generate code: %w", err)
}
// 保存验证码到 Redis
data := &verifyCodeData{
Code: code,
Attempts: 0,
CreatedAt: time.Now(),
}
if err := s.setVerifyCodeData(ctx, key, data); err != nil {
return fmt.Errorf("save verify code: %w", err)
}
// 构建邮件内容
subject := fmt.Sprintf("[%s] Email Verification Code", siteName)
body := s.buildVerifyCodeEmailBody(code, siteName)
// 发送邮件
if err := s.SendEmail(ctx, email, subject, body); err != nil {
return fmt.Errorf("send email: %w", err)
}
return nil
}
// VerifyCode 验证验证码
func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error {
key := verifyCodeKeyPrefix + email
data, err := s.getVerifyCodeData(ctx, key)
if err != nil || data == nil {
return ErrInvalidVerifyCode
}
// 检查是否已达到最大尝试次数
if data.Attempts >= maxVerifyCodeAttempts {
return ErrVerifyCodeMaxAttempts
}
// 验证码不匹配
if data.Code != code {
data.Attempts++
_ = s.setVerifyCodeData(ctx, key, data)
if data.Attempts >= maxVerifyCodeAttempts {
return ErrVerifyCodeMaxAttempts
}
return ErrInvalidVerifyCode
}
// 验证成功,删除验证码
s.rdb.Del(ctx, key)
return nil
}
// getVerifyCodeData 从 Redis 获取验证码数据
func (s *EmailService) getVerifyCodeData(ctx context.Context, key string) (*verifyCodeData, error) {
val, err := s.rdb.Get(ctx, key).Result()
if err != nil {
return nil, err
}
var data verifyCodeData
if err := json.Unmarshal([]byte(val), &data); err != nil {
return nil, err
}
return &data, nil
}
// setVerifyCodeData 保存验证码数据到 Redis
func (s *EmailService) setVerifyCodeData(ctx context.Context, key string, data *verifyCodeData) error {
val, err := json.Marshal(data)
if err != nil {
return err
}
return s.rdb.Set(ctx, key, val, verifyCodeTTL).Err()
}
// buildVerifyCodeEmailBody 构建验证码邮件HTML内容
func (s *EmailService) buildVerifyCodeEmailBody(code, siteName string) string {
return fmt.Sprintf(`
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<style>
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, sans-serif; background-color: #f5f5f5; margin: 0; padding: 20px; }
.container { max-width: 600px; margin: 0 auto; background-color: #ffffff; border-radius: 8px; overflow: hidden; box-shadow: 0 2px 8px rgba(0,0,0,0.1); }
.header { background: linear-gradient(135deg, #667eea 0%%, #764ba2 100%%); color: white; padding: 30px; text-align: center; }
.header h1 { margin: 0; font-size: 24px; }
.content { padding: 40px 30px; text-align: center; }
.code { font-size: 36px; font-weight: bold; letter-spacing: 8px; color: #333; background-color: #f8f9fa; padding: 20px 30px; border-radius: 8px; display: inline-block; margin: 20px 0; font-family: monospace; }
.info { color: #666; font-size: 14px; line-height: 1.6; margin-top: 20px; }
.footer { background-color: #f8f9fa; padding: 20px; text-align: center; color: #999; font-size: 12px; }
</style>
</head>
<body>
<div class="container">
<div class="header">
<h1>%s</h1>
</div>
<div class="content">
<p style="font-size: 18px; color: #333;">Your verification code is:</p>
<div class="code">%s</div>
<div class="info">
<p>This code will expire in <strong>15 minutes</strong>.</p>
<p>If you did not request this code, please ignore this email.</p>
</div>
</div>
<div class="footer">
<p>This is an automated message, please do not reply.</p>
</div>
</div>
</body>
</html>
`, siteName, code)
}
// TestSmtpConnectionWithConfig 使用指定配置测试SMTP连接
func (s *EmailService) TestSmtpConnectionWithConfig(config *SmtpConfig) error {
addr := fmt.Sprintf("%s:%d", config.Host, config.Port)
if config.UseTLS {
tlsConfig := &tls.Config{ServerName: config.Host}
conn, err := tls.Dial("tcp", addr, tlsConfig)
if err != nil {
return fmt.Errorf("tls connection failed: %w", err)
}
defer conn.Close()
client, err := smtp.NewClient(conn, config.Host)
if err != nil {
return fmt.Errorf("smtp client creation failed: %w", err)
}
defer client.Close()
auth := smtp.PlainAuth("", config.Username, config.Password, config.Host)
if err = client.Auth(auth); err != nil {
return fmt.Errorf("smtp authentication failed: %w", err)
}
return client.Quit()
}
// 非TLS连接测试
client, err := smtp.Dial(addr)
if err != nil {
return fmt.Errorf("smtp connection failed: %w", err)
}
defer client.Close()
auth := smtp.PlainAuth("", config.Username, config.Password, config.Host)
if err = client.Auth(auth); err != nil {
return fmt.Errorf("smtp authentication failed: %w", err)
}
return client.Quit()
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,194 @@
package service
import (
"context"
"errors"
"fmt"
"sub2api/internal/model"
"sub2api/internal/repository"
"gorm.io/gorm"
)
var (
ErrGroupNotFound = errors.New("group not found")
ErrGroupExists = errors.New("group name already exists")
)
// CreateGroupRequest 创建分组请求
type CreateGroupRequest struct {
Name string `json:"name"`
Description string `json:"description"`
RateMultiplier float64 `json:"rate_multiplier"`
IsExclusive bool `json:"is_exclusive"`
}
// UpdateGroupRequest 更新分组请求
type UpdateGroupRequest struct {
Name *string `json:"name"`
Description *string `json:"description"`
RateMultiplier *float64 `json:"rate_multiplier"`
IsExclusive *bool `json:"is_exclusive"`
Status *string `json:"status"`
}
// GroupService 分组管理服务
type GroupService struct {
groupRepo *repository.GroupRepository
}
// NewGroupService 创建分组服务实例
func NewGroupService(groupRepo *repository.GroupRepository) *GroupService {
return &GroupService{
groupRepo: groupRepo,
}
}
// Create 创建分组
func (s *GroupService) Create(ctx context.Context, req CreateGroupRequest) (*model.Group, error) {
// 检查名称是否已存在
exists, err := s.groupRepo.ExistsByName(ctx, req.Name)
if err != nil {
return nil, fmt.Errorf("check group exists: %w", err)
}
if exists {
return nil, ErrGroupExists
}
// 创建分组
group := &model.Group{
Name: req.Name,
Description: req.Description,
RateMultiplier: req.RateMultiplier,
IsExclusive: req.IsExclusive,
Status: model.StatusActive,
}
if err := s.groupRepo.Create(ctx, group); err != nil {
return nil, fmt.Errorf("create group: %w", err)
}
return group, nil
}
// GetByID 根据ID获取分组
func (s *GroupService) GetByID(ctx context.Context, id int64) (*model.Group, error) {
group, err := s.groupRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrGroupNotFound
}
return nil, fmt.Errorf("get group: %w", err)
}
return group, nil
}
// List 获取分组列表
func (s *GroupService) List(ctx context.Context, params repository.PaginationParams) ([]model.Group, *repository.PaginationResult, error) {
groups, pagination, err := s.groupRepo.List(ctx, params)
if err != nil {
return nil, nil, fmt.Errorf("list groups: %w", err)
}
return groups, pagination, nil
}
// ListActive 获取活跃分组列表
func (s *GroupService) ListActive(ctx context.Context) ([]model.Group, error) {
groups, err := s.groupRepo.ListActive(ctx)
if err != nil {
return nil, fmt.Errorf("list active groups: %w", err)
}
return groups, nil
}
// Update 更新分组
func (s *GroupService) Update(ctx context.Context, id int64, req UpdateGroupRequest) (*model.Group, error) {
group, err := s.groupRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrGroupNotFound
}
return nil, fmt.Errorf("get group: %w", err)
}
// 更新字段
if req.Name != nil && *req.Name != group.Name {
// 检查新名称是否已存在
exists, err := s.groupRepo.ExistsByName(ctx, *req.Name)
if err != nil {
return nil, fmt.Errorf("check group exists: %w", err)
}
if exists {
return nil, ErrGroupExists
}
group.Name = *req.Name
}
if req.Description != nil {
group.Description = *req.Description
}
if req.RateMultiplier != nil {
group.RateMultiplier = *req.RateMultiplier
}
if req.IsExclusive != nil {
group.IsExclusive = *req.IsExclusive
}
if req.Status != nil {
group.Status = *req.Status
}
if err := s.groupRepo.Update(ctx, group); err != nil {
return nil, fmt.Errorf("update group: %w", err)
}
return group, nil
}
// Delete 删除分组
func (s *GroupService) Delete(ctx context.Context, id int64) error {
// 检查分组是否存在
_, err := s.groupRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrGroupNotFound
}
return fmt.Errorf("get group: %w", err)
}
if err := s.groupRepo.Delete(ctx, id); err != nil {
return fmt.Errorf("delete group: %w", err)
}
return nil
}
// GetStats 获取分组统计信息
func (s *GroupService) GetStats(ctx context.Context, id int64) (map[string]interface{}, error) {
group, err := s.groupRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrGroupNotFound
}
return nil, fmt.Errorf("get group: %w", err)
}
// 获取账号数量
accountCount, err := s.groupRepo.GetAccountCount(ctx, id)
if err != nil {
return nil, fmt.Errorf("get account count: %w", err)
}
stats := map[string]interface{}{
"id": group.ID,
"name": group.Name,
"rate_multiplier": group.RateMultiplier,
"is_exclusive": group.IsExclusive,
"status": group.Status,
"account_count": accountCount,
}
return stats, nil
}

View File

@@ -0,0 +1,282 @@
package service
import (
"context"
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"log"
"net/http"
"regexp"
"strconv"
"time"
"github.com/redis/go-redis/v9"
)
const (
// Redis key prefix
identityFingerprintKey = "identity:fingerprint:"
)
// 预编译正则表达式(避免每次调用重新编译)
var (
// 匹配 user_id 格式: user_{64位hex}_account__session_{uuid}
userIDRegex = regexp.MustCompile(`^user_[a-f0-9]{64}_account__session_([a-f0-9-]{36})$`)
// 匹配 User-Agent 版本号: xxx/x.y.z
userAgentVersionRegex = regexp.MustCompile(`/(\d+)\.(\d+)\.(\d+)`)
)
// Fingerprint 存储的指纹数据结构
type Fingerprint struct {
ClientID string `json:"client_id"` // 64位hex客户端ID首次随机生成
UserAgent string `json:"user_agent"` // User-Agent
StainlessLang string `json:"x_stainless_lang"` // x-stainless-lang
StainlessPackageVersion string `json:"x_stainless_package_version"` // x-stainless-package-version
StainlessOS string `json:"x_stainless_os"` // x-stainless-os
StainlessArch string `json:"x_stainless_arch"` // x-stainless-arch
StainlessRuntime string `json:"x_stainless_runtime"` // x-stainless-runtime
StainlessRuntimeVersion string `json:"x_stainless_runtime_version"` // x-stainless-runtime-version
}
// 默认指纹值(当客户端未提供时使用)
var defaultFingerprint = Fingerprint{
UserAgent: "claude-cli/2.0.62 (external, cli)",
StainlessLang: "js",
StainlessPackageVersion: "0.52.0",
StainlessOS: "Linux",
StainlessArch: "x64",
StainlessRuntime: "node",
StainlessRuntimeVersion: "v22.14.0",
}
// IdentityService 管理OAuth账号的请求身份指纹
type IdentityService struct {
rdb *redis.Client
}
// NewIdentityService 创建新的IdentityService
func NewIdentityService(rdb *redis.Client) *IdentityService {
return &IdentityService{rdb: rdb}
}
// GetOrCreateFingerprint 获取或创建账号的指纹
// 如果缓存存在检测user-agent版本新版本则更新
// 如果缓存不存在生成随机ClientID并从请求头创建指纹然后缓存
func (s *IdentityService) GetOrCreateFingerprint(ctx context.Context, accountID int64, headers http.Header) (*Fingerprint, error) {
key := identityFingerprintKey + strconv.FormatInt(accountID, 10)
// 尝试从Redis获取缓存的指纹
data, err := s.rdb.Get(ctx, key).Bytes()
if err == nil && len(data) > 0 {
// 缓存存在,解析指纹
var cached Fingerprint
if err := json.Unmarshal(data, &cached); err == nil {
// 检查客户端的user-agent是否是更新版本
clientUA := headers.Get("User-Agent")
if clientUA != "" && isNewerVersion(clientUA, cached.UserAgent) {
// 更新user-agent
cached.UserAgent = clientUA
// 保存更新后的指纹
if newData, err := json.Marshal(cached); err == nil {
s.rdb.Set(ctx, key, newData, 0) // 永不过期
}
log.Printf("Updated fingerprint user-agent for account %d: %s", accountID, clientUA)
}
return &cached, nil
}
}
// 缓存不存在或解析失败,创建新指纹
fp := s.createFingerprintFromHeaders(headers)
// 生成随机ClientID
fp.ClientID = generateClientID()
// 保存到Redis永不过期
if data, err := json.Marshal(fp); err == nil {
if err := s.rdb.Set(ctx, key, data, 0).Err(); err != nil {
log.Printf("Warning: failed to cache fingerprint for account %d: %v", accountID, err)
}
}
log.Printf("Created new fingerprint for account %d with client_id: %s", accountID, fp.ClientID)
return fp, nil
}
// createFingerprintFromHeaders 从请求头创建指纹
func (s *IdentityService) createFingerprintFromHeaders(headers http.Header) *Fingerprint {
fp := &Fingerprint{}
// 获取User-Agent
if ua := headers.Get("User-Agent"); ua != "" {
fp.UserAgent = ua
} else {
fp.UserAgent = defaultFingerprint.UserAgent
}
// 获取x-stainless-*头,如果没有则使用默认值
fp.StainlessLang = getHeaderOrDefault(headers, "X-Stainless-Lang", defaultFingerprint.StainlessLang)
fp.StainlessPackageVersion = getHeaderOrDefault(headers, "X-Stainless-Package-Version", defaultFingerprint.StainlessPackageVersion)
fp.StainlessOS = getHeaderOrDefault(headers, "X-Stainless-OS", defaultFingerprint.StainlessOS)
fp.StainlessArch = getHeaderOrDefault(headers, "X-Stainless-Arch", defaultFingerprint.StainlessArch)
fp.StainlessRuntime = getHeaderOrDefault(headers, "X-Stainless-Runtime", defaultFingerprint.StainlessRuntime)
fp.StainlessRuntimeVersion = getHeaderOrDefault(headers, "X-Stainless-Runtime-Version", defaultFingerprint.StainlessRuntimeVersion)
return fp
}
// getHeaderOrDefault 获取header值如果不存在则返回默认值
func getHeaderOrDefault(headers http.Header, key, defaultValue string) string {
if v := headers.Get(key); v != "" {
return v
}
return defaultValue
}
// ApplyFingerprint 将指纹应用到请求头覆盖原有的x-stainless-*头)
func (s *IdentityService) ApplyFingerprint(req *http.Request, fp *Fingerprint) {
if fp == nil {
return
}
// 设置User-Agent
if fp.UserAgent != "" {
req.Header.Set("User-Agent", fp.UserAgent)
}
// 设置x-stainless-*头(使用正确的大小写)
if fp.StainlessLang != "" {
req.Header.Set("X-Stainless-Lang", fp.StainlessLang)
}
if fp.StainlessPackageVersion != "" {
req.Header.Set("X-Stainless-Package-Version", fp.StainlessPackageVersion)
}
if fp.StainlessOS != "" {
req.Header.Set("X-Stainless-OS", fp.StainlessOS)
}
if fp.StainlessArch != "" {
req.Header.Set("X-Stainless-Arch", fp.StainlessArch)
}
if fp.StainlessRuntime != "" {
req.Header.Set("X-Stainless-Runtime", fp.StainlessRuntime)
}
if fp.StainlessRuntimeVersion != "" {
req.Header.Set("X-Stainless-Runtime-Version", fp.StainlessRuntimeVersion)
}
}
// RewriteUserID 重写body中的metadata.user_id
// 输入格式user_{clientId}_account__session_{sessionUUID}
// 输出格式user_{cachedClientID}_account_{accountUUID}_session_{newHash}
func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUID, cachedClientID string) ([]byte, error) {
if len(body) == 0 || accountUUID == "" || cachedClientID == "" {
return body, nil
}
// 解析JSON
var reqMap map[string]interface{}
if err := json.Unmarshal(body, &reqMap); err != nil {
return body, nil
}
metadata, ok := reqMap["metadata"].(map[string]interface{})
if !ok {
return body, nil
}
userID, ok := metadata["user_id"].(string)
if !ok || userID == "" {
return body, nil
}
// 匹配格式: user_{64位hex}_account__session_{uuid}
matches := userIDRegex.FindStringSubmatch(userID)
if matches == nil {
return body, nil
}
sessionTail := matches[1] // 原始session UUID
// 生成新的session hash: SHA256(accountID::sessionTail) -> UUID格式
seed := fmt.Sprintf("%d::%s", accountID, sessionTail)
newSessionHash := generateUUIDFromSeed(seed)
// 构建新的user_id
// 格式: user_{cachedClientID}_account_{account_uuid}_session_{newSessionHash}
newUserID := fmt.Sprintf("user_%s_account_%s_session_%s", cachedClientID, accountUUID, newSessionHash)
metadata["user_id"] = newUserID
reqMap["metadata"] = metadata
return json.Marshal(reqMap)
}
// generateClientID 生成64位十六进制客户端ID32字节随机数
func generateClientID() string {
b := make([]byte, 32)
if _, err := rand.Read(b); err != nil {
// 极罕见的情况,使用时间戳+固定值作为fallback
log.Printf("Warning: crypto/rand.Read failed: %v, using fallback", err)
// 使用SHA256(当前纳秒时间)作为fallback
h := sha256.Sum256([]byte(fmt.Sprintf("%d", time.Now().UnixNano())))
return hex.EncodeToString(h[:])
}
return hex.EncodeToString(b)
}
// generateUUIDFromSeed 从种子生成确定性UUID v4格式字符串
func generateUUIDFromSeed(seed string) string {
hash := sha256.Sum256([]byte(seed))
bytes := hash[:16]
// 设置UUID v4版本和变体位
bytes[6] = (bytes[6] & 0x0f) | 0x40
bytes[8] = (bytes[8] & 0x3f) | 0x80
return fmt.Sprintf("%x-%x-%x-%x-%x",
bytes[0:4], bytes[4:6], bytes[6:8], bytes[8:10], bytes[10:16])
}
// parseUserAgentVersion 解析user-agent版本号
// 例如claude-cli/2.0.62 -> (2, 0, 62)
func parseUserAgentVersion(ua string) (major, minor, patch int, ok bool) {
// 匹配 xxx/x.y.z 格式
matches := userAgentVersionRegex.FindStringSubmatch(ua)
if len(matches) != 4 {
return 0, 0, 0, false
}
major, _ = strconv.Atoi(matches[1])
minor, _ = strconv.Atoi(matches[2])
patch, _ = strconv.Atoi(matches[3])
return major, minor, patch, true
}
// isNewerVersion 比较版本号判断newUA是否比cachedUA更新
func isNewerVersion(newUA, cachedUA string) bool {
newMajor, newMinor, newPatch, newOk := parseUserAgentVersion(newUA)
cachedMajor, cachedMinor, cachedPatch, cachedOk := parseUserAgentVersion(cachedUA)
if !newOk || !cachedOk {
return false
}
// 比较版本号
if newMajor > cachedMajor {
return true
}
if newMajor < cachedMajor {
return false
}
if newMinor > cachedMinor {
return true
}
if newMinor < cachedMinor {
return false
}
return newPatch > cachedPatch
}

View File

@@ -0,0 +1,471 @@
package service
import (
"context"
"encoding/json"
"fmt"
"log"
"net/http"
"net/url"
"strings"
"time"
"sub2api/internal/model"
"sub2api/internal/pkg/oauth"
"sub2api/internal/repository"
"github.com/imroc/req/v3"
)
// OAuthService handles OAuth authentication flows
type OAuthService struct {
sessionStore *oauth.SessionStore
proxyRepo *repository.ProxyRepository
}
// NewOAuthService creates a new OAuth service
func NewOAuthService(proxyRepo *repository.ProxyRepository) *OAuthService {
return &OAuthService{
sessionStore: oauth.NewSessionStore(),
proxyRepo: proxyRepo,
}
}
// GenerateAuthURLResult contains the authorization URL and session info
type GenerateAuthURLResult struct {
AuthURL string `json:"auth_url"`
SessionID string `json:"session_id"`
}
// GenerateAuthURL generates an OAuth authorization URL with full scope
func (s *OAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64) (*GenerateAuthURLResult, error) {
scope := fmt.Sprintf("%s %s", oauth.ScopeProfile, oauth.ScopeInference)
return s.generateAuthURLWithScope(ctx, scope, proxyID)
}
// GenerateSetupTokenURL generates an OAuth authorization URL for setup token (inference only)
func (s *OAuthService) GenerateSetupTokenURL(ctx context.Context, proxyID *int64) (*GenerateAuthURLResult, error) {
scope := oauth.ScopeInference
return s.generateAuthURLWithScope(ctx, scope, proxyID)
}
func (s *OAuthService) generateAuthURLWithScope(ctx context.Context, scope string, proxyID *int64) (*GenerateAuthURLResult, error) {
// Generate PKCE values
state, err := oauth.GenerateState()
if err != nil {
return nil, fmt.Errorf("failed to generate state: %w", err)
}
codeVerifier, err := oauth.GenerateCodeVerifier()
if err != nil {
return nil, fmt.Errorf("failed to generate code verifier: %w", err)
}
codeChallenge := oauth.GenerateCodeChallenge(codeVerifier)
// Generate session ID
sessionID, err := oauth.GenerateSessionID()
if err != nil {
return nil, fmt.Errorf("failed to generate session ID: %w", err)
}
// Get proxy URL if specified
var proxyURL string
if proxyID != nil {
proxy, err := s.proxyRepo.GetByID(ctx, *proxyID)
if err == nil && proxy != nil {
proxyURL = proxy.URL()
}
}
// Store session
session := &oauth.OAuthSession{
State: state,
CodeVerifier: codeVerifier,
Scope: scope,
ProxyURL: proxyURL,
CreatedAt: time.Now(),
}
s.sessionStore.Set(sessionID, session)
// Build authorization URL
authURL := oauth.BuildAuthorizationURL(state, codeChallenge, scope)
return &GenerateAuthURLResult{
AuthURL: authURL,
SessionID: sessionID,
}, nil
}
// ExchangeCodeInput represents the input for code exchange
type ExchangeCodeInput struct {
SessionID string
Code string
ProxyID *int64
}
// TokenInfo represents the token information stored in credentials
type TokenInfo struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
ExpiresIn int64 `json:"expires_in"`
ExpiresAt int64 `json:"expires_at"`
RefreshToken string `json:"refresh_token,omitempty"`
Scope string `json:"scope,omitempty"`
OrgUUID string `json:"org_uuid,omitempty"`
AccountUUID string `json:"account_uuid,omitempty"`
}
// ExchangeCode exchanges authorization code for tokens
func (s *OAuthService) ExchangeCode(ctx context.Context, input *ExchangeCodeInput) (*TokenInfo, error) {
// Get session
session, ok := s.sessionStore.Get(input.SessionID)
if !ok {
return nil, fmt.Errorf("session not found or expired")
}
// Get proxy URL
proxyURL := session.ProxyURL
if input.ProxyID != nil {
proxy, err := s.proxyRepo.GetByID(ctx, *input.ProxyID)
if err == nil && proxy != nil {
proxyURL = proxy.URL()
}
}
// Exchange code for token
tokenInfo, err := s.exchangeCodeForToken(ctx, input.Code, session.CodeVerifier, session.State, proxyURL)
if err != nil {
return nil, err
}
// Delete session after successful exchange
s.sessionStore.Delete(input.SessionID)
return tokenInfo, nil
}
// CookieAuthInput represents the input for cookie-based authentication
type CookieAuthInput struct {
SessionKey string
ProxyID *int64
Scope string // "full" or "inference"
}
// CookieAuth performs OAuth using sessionKey (cookie-based auto-auth)
func (s *OAuthService) CookieAuth(ctx context.Context, input *CookieAuthInput) (*TokenInfo, error) {
// Get proxy URL if specified
var proxyURL string
if input.ProxyID != nil {
proxy, err := s.proxyRepo.GetByID(ctx, *input.ProxyID)
if err == nil && proxy != nil {
proxyURL = proxy.URL()
}
}
// Determine scope
scope := fmt.Sprintf("%s %s", oauth.ScopeProfile, oauth.ScopeInference)
if input.Scope == "inference" {
scope = oauth.ScopeInference
}
// Step 1: Get organization info using sessionKey
orgUUID, err := s.getOrganizationUUID(ctx, input.SessionKey, proxyURL)
if err != nil {
return nil, fmt.Errorf("failed to get organization info: %w", err)
}
// Step 2: Generate PKCE values
codeVerifier, err := oauth.GenerateCodeVerifier()
if err != nil {
return nil, fmt.Errorf("failed to generate code verifier: %w", err)
}
codeChallenge := oauth.GenerateCodeChallenge(codeVerifier)
state, err := oauth.GenerateState()
if err != nil {
return nil, fmt.Errorf("failed to generate state: %w", err)
}
// Step 3: Get authorization code using cookie
authCode, err := s.getAuthorizationCode(ctx, input.SessionKey, orgUUID, scope, codeChallenge, state, proxyURL)
if err != nil {
return nil, fmt.Errorf("failed to get authorization code: %w", err)
}
// Step 4: Exchange code for token
tokenInfo, err := s.exchangeCodeForToken(ctx, authCode, codeVerifier, state, proxyURL)
if err != nil {
return nil, fmt.Errorf("failed to exchange code: %w", err)
}
// Ensure org_uuid is set (from step 1 if not from token response)
if tokenInfo.OrgUUID == "" && orgUUID != "" {
tokenInfo.OrgUUID = orgUUID
log.Printf("[OAuth] Set org_uuid from cookie auth: %s", orgUUID)
}
return tokenInfo, nil
}
// getOrganizationUUID gets the organization UUID from claude.ai using sessionKey
func (s *OAuthService) getOrganizationUUID(ctx context.Context, sessionKey, proxyURL string) (string, error) {
client := s.createReqClient(proxyURL)
var orgs []struct {
UUID string `json:"uuid"`
}
targetURL := "https://claude.ai/api/organizations"
log.Printf("[OAuth] Step 1: Getting organization UUID from %s", targetURL)
resp, err := client.R().
SetContext(ctx).
SetCookies(&http.Cookie{
Name: "sessionKey",
Value: sessionKey,
}).
SetSuccessResult(&orgs).
Get(targetURL)
if err != nil {
log.Printf("[OAuth] Step 1 FAILED - Request error: %v", err)
return "", fmt.Errorf("request failed: %w", err)
}
log.Printf("[OAuth] Step 1 Response - Status: %d, Body: %s", resp.StatusCode, resp.String())
if !resp.IsSuccessState() {
return "", fmt.Errorf("failed to get organizations: status %d, body: %s", resp.StatusCode, resp.String())
}
if len(orgs) == 0 {
return "", fmt.Errorf("no organizations found")
}
log.Printf("[OAuth] Step 1 SUCCESS - Got org UUID: %s", orgs[0].UUID)
return orgs[0].UUID, nil
}
// getAuthorizationCode gets the authorization code using sessionKey
func (s *OAuthService) getAuthorizationCode(ctx context.Context, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL string) (string, error) {
client := s.createReqClient(proxyURL)
authURL := fmt.Sprintf("https://claude.ai/v1/oauth/%s/authorize", orgUUID)
// Build request body - must include organization_uuid as per CRS
reqBody := map[string]interface{}{
"response_type": "code",
"client_id": oauth.ClientID,
"organization_uuid": orgUUID, // Required field!
"redirect_uri": oauth.RedirectURI,
"scope": scope,
"state": state,
"code_challenge": codeChallenge,
"code_challenge_method": "S256",
}
reqBodyJSON, _ := json.Marshal(reqBody)
log.Printf("[OAuth] Step 2: Getting authorization code from %s", authURL)
log.Printf("[OAuth] Step 2 Request Body: %s", string(reqBodyJSON))
// Response contains redirect_uri with code, not direct code field
var result struct {
RedirectURI string `json:"redirect_uri"`
}
resp, err := client.R().
SetContext(ctx).
SetCookies(&http.Cookie{
Name: "sessionKey",
Value: sessionKey,
}).
SetHeader("Accept", "application/json").
SetHeader("Accept-Language", "en-US,en;q=0.9").
SetHeader("Cache-Control", "no-cache").
SetHeader("Origin", "https://claude.ai").
SetHeader("Referer", "https://claude.ai/new").
SetHeader("Content-Type", "application/json").
SetBody(reqBody).
SetSuccessResult(&result).
Post(authURL)
if err != nil {
log.Printf("[OAuth] Step 2 FAILED - Request error: %v", err)
return "", fmt.Errorf("request failed: %w", err)
}
log.Printf("[OAuth] Step 2 Response - Status: %d, Body: %s", resp.StatusCode, resp.String())
if !resp.IsSuccessState() {
return "", fmt.Errorf("failed to get authorization code: status %d, body: %s", resp.StatusCode, resp.String())
}
if result.RedirectURI == "" {
return "", fmt.Errorf("no redirect_uri in response")
}
// Parse redirect_uri to extract code and state
parsedURL, err := url.Parse(result.RedirectURI)
if err != nil {
return "", fmt.Errorf("failed to parse redirect_uri: %w", err)
}
queryParams := parsedURL.Query()
authCode := queryParams.Get("code")
responseState := queryParams.Get("state")
if authCode == "" {
return "", fmt.Errorf("no authorization code in redirect_uri")
}
// Combine code with state if present (as CRS does)
fullCode := authCode
if responseState != "" {
fullCode = authCode + "#" + responseState
}
log.Printf("[OAuth] Step 2 SUCCESS - Got authorization code: %s...", authCode[:20])
return fullCode, nil
}
// exchangeCodeForToken exchanges authorization code for tokens
func (s *OAuthService) exchangeCodeForToken(ctx context.Context, code, codeVerifier, state, proxyURL string) (*TokenInfo, error) {
client := s.createReqClient(proxyURL)
// Parse code#state format if present
authCode := code
codeState := ""
if parts := strings.Split(code, "#"); len(parts) > 1 {
authCode = parts[0]
codeState = parts[1]
}
// Build JSON body as CRS does (not form data!)
reqBody := map[string]interface{}{
"code": authCode,
"grant_type": "authorization_code",
"client_id": oauth.ClientID,
"redirect_uri": oauth.RedirectURI,
"code_verifier": codeVerifier,
}
// Add state if present
if codeState != "" {
reqBody["state"] = codeState
}
reqBodyJSON, _ := json.Marshal(reqBody)
log.Printf("[OAuth] Step 3: Exchanging code for token at %s", oauth.TokenURL)
log.Printf("[OAuth] Step 3 Request Body: %s", string(reqBodyJSON))
var tokenResp oauth.TokenResponse
resp, err := client.R().
SetContext(ctx).
SetHeader("Content-Type", "application/json").
SetBody(reqBody).
SetSuccessResult(&tokenResp).
Post(oauth.TokenURL)
if err != nil {
log.Printf("[OAuth] Step 3 FAILED - Request error: %v", err)
return nil, fmt.Errorf("request failed: %w", err)
}
log.Printf("[OAuth] Step 3 Response - Status: %d, Body: %s", resp.StatusCode, resp.String())
if !resp.IsSuccessState() {
return nil, fmt.Errorf("token exchange failed: status %d, body: %s", resp.StatusCode, resp.String())
}
log.Printf("[OAuth] Step 3 SUCCESS - Got access token")
tokenInfo := &TokenInfo{
AccessToken: tokenResp.AccessToken,
TokenType: tokenResp.TokenType,
ExpiresIn: tokenResp.ExpiresIn,
ExpiresAt: time.Now().Unix() + tokenResp.ExpiresIn,
RefreshToken: tokenResp.RefreshToken,
Scope: tokenResp.Scope,
}
// Extract org_uuid and account_uuid from response
if tokenResp.Organization != nil && tokenResp.Organization.UUID != "" {
tokenInfo.OrgUUID = tokenResp.Organization.UUID
log.Printf("[OAuth] Got org_uuid: %s", tokenInfo.OrgUUID)
}
if tokenResp.Account != nil && tokenResp.Account.UUID != "" {
tokenInfo.AccountUUID = tokenResp.Account.UUID
log.Printf("[OAuth] Got account_uuid: %s", tokenInfo.AccountUUID)
}
return tokenInfo, nil
}
// RefreshToken refreshes an OAuth token
func (s *OAuthService) RefreshToken(ctx context.Context, refreshToken string, proxyURL string) (*TokenInfo, error) {
client := s.createReqClient(proxyURL)
formData := url.Values{}
formData.Set("grant_type", "refresh_token")
formData.Set("refresh_token", refreshToken)
formData.Set("client_id", oauth.ClientID)
var tokenResp oauth.TokenResponse
resp, err := client.R().
SetContext(ctx).
SetFormDataFromValues(formData).
SetSuccessResult(&tokenResp).
Post(oauth.TokenURL)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
if !resp.IsSuccessState() {
return nil, fmt.Errorf("token refresh failed: status %d, body: %s", resp.StatusCode, resp.String())
}
return &TokenInfo{
AccessToken: tokenResp.AccessToken,
TokenType: tokenResp.TokenType,
ExpiresIn: tokenResp.ExpiresIn,
ExpiresAt: time.Now().Unix() + tokenResp.ExpiresIn,
RefreshToken: tokenResp.RefreshToken,
Scope: tokenResp.Scope,
}, nil
}
// RefreshAccountToken refreshes token for an account
func (s *OAuthService) RefreshAccountToken(ctx context.Context, account *model.Account) (*TokenInfo, error) {
refreshToken := account.GetCredential("refresh_token")
if refreshToken == "" {
return nil, fmt.Errorf("no refresh token available")
}
var proxyURL string
if account.ProxyID != nil {
proxy, err := s.proxyRepo.GetByID(ctx, *account.ProxyID)
if err == nil && proxy != nil {
proxyURL = proxy.URL()
}
}
return s.RefreshToken(ctx, refreshToken, proxyURL)
}
// createReqClient creates a req client with Chrome impersonation and optional proxy
func (s *OAuthService) createReqClient(proxyURL string) *req.Client {
client := req.C().
ImpersonateChrome(). // Impersonate Chrome browser to bypass Cloudflare
SetTimeout(60 * time.Second)
// Set proxy if specified
if proxyURL != "" {
client.SetProxyURL(proxyURL)
}
return client
}

View File

@@ -0,0 +1,572 @@
package service
import (
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"os"
"path/filepath"
"strings"
"sync"
"time"
"sub2api/internal/config"
)
// LiteLLMModelPricing LiteLLM价格数据结构
// 只保留我们需要的字段,使用指针来处理可能缺失的值
type LiteLLMModelPricing struct {
InputCostPerToken float64 `json:"input_cost_per_token"`
OutputCostPerToken float64 `json:"output_cost_per_token"`
CacheCreationInputTokenCost float64 `json:"cache_creation_input_token_cost"`
CacheReadInputTokenCost float64 `json:"cache_read_input_token_cost"`
LiteLLMProvider string `json:"litellm_provider"`
Mode string `json:"mode"`
SupportsPromptCaching bool `json:"supports_prompt_caching"`
}
// LiteLLMRawEntry 用于解析原始JSON数据
type LiteLLMRawEntry struct {
InputCostPerToken *float64 `json:"input_cost_per_token"`
OutputCostPerToken *float64 `json:"output_cost_per_token"`
CacheCreationInputTokenCost *float64 `json:"cache_creation_input_token_cost"`
CacheReadInputTokenCost *float64 `json:"cache_read_input_token_cost"`
LiteLLMProvider string `json:"litellm_provider"`
Mode string `json:"mode"`
SupportsPromptCaching bool `json:"supports_prompt_caching"`
}
// PricingService 动态价格服务
type PricingService struct {
cfg *config.Config
mu sync.RWMutex
pricingData map[string]*LiteLLMModelPricing
lastUpdated time.Time
localHash string
// 停止信号
stopCh chan struct{}
wg sync.WaitGroup
}
// NewPricingService 创建价格服务
func NewPricingService(cfg *config.Config) *PricingService {
s := &PricingService{
cfg: cfg,
pricingData: make(map[string]*LiteLLMModelPricing),
stopCh: make(chan struct{}),
}
return s
}
// Initialize 初始化价格服务
func (s *PricingService) Initialize() error {
// 确保数据目录存在
if err := os.MkdirAll(s.cfg.Pricing.DataDir, 0755); err != nil {
log.Printf("[Pricing] Failed to create data directory: %v", err)
}
// 首次加载价格数据
if err := s.checkAndUpdatePricing(); err != nil {
log.Printf("[Pricing] Initial load failed, using fallback: %v", err)
if err := s.useFallbackPricing(); err != nil {
return fmt.Errorf("failed to load pricing data: %w", err)
}
}
// 启动定时更新
s.startUpdateScheduler()
log.Printf("[Pricing] Service initialized with %d models", len(s.pricingData))
return nil
}
// Stop 停止价格服务
func (s *PricingService) Stop() {
close(s.stopCh)
s.wg.Wait()
log.Println("[Pricing] Service stopped")
}
// startUpdateScheduler 启动定时更新调度器
func (s *PricingService) startUpdateScheduler() {
// 定期检查哈希更新
hashInterval := time.Duration(s.cfg.Pricing.HashCheckIntervalMinutes) * time.Minute
if hashInterval < time.Minute {
hashInterval = 10 * time.Minute
}
s.wg.Add(1)
go func() {
defer s.wg.Done()
ticker := time.NewTicker(hashInterval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
if err := s.syncWithRemote(); err != nil {
log.Printf("[Pricing] Sync failed: %v", err)
}
case <-s.stopCh:
return
}
}
}()
log.Printf("[Pricing] Update scheduler started (check every %v)", hashInterval)
}
// checkAndUpdatePricing 检查并更新价格数据
func (s *PricingService) checkAndUpdatePricing() error {
pricingFile := s.getPricingFilePath()
// 检查本地文件是否存在
if _, err := os.Stat(pricingFile); os.IsNotExist(err) {
log.Println("[Pricing] Local pricing file not found, downloading...")
return s.downloadPricingData()
}
// 检查文件是否过期
info, err := os.Stat(pricingFile)
if err != nil {
return s.downloadPricingData()
}
fileAge := time.Since(info.ModTime())
maxAge := time.Duration(s.cfg.Pricing.UpdateIntervalHours) * time.Hour
if fileAge > maxAge {
log.Printf("[Pricing] Local file is %v old, updating...", fileAge.Round(time.Hour))
if err := s.downloadPricingData(); err != nil {
log.Printf("[Pricing] Download failed, using existing file: %v", err)
}
}
// 加载本地文件
return s.loadPricingData(pricingFile)
}
// syncWithRemote 与远程同步(基于哈希校验)
func (s *PricingService) syncWithRemote() error {
pricingFile := s.getPricingFilePath()
// 计算本地文件哈希
localHash, err := s.computeFileHash(pricingFile)
if err != nil {
log.Printf("[Pricing] Failed to compute local hash: %v", err)
return s.downloadPricingData()
}
// 如果配置了哈希URL从远程获取哈希进行比对
if s.cfg.Pricing.HashURL != "" {
remoteHash, err := s.fetchRemoteHash()
if err != nil {
log.Printf("[Pricing] Failed to fetch remote hash: %v", err)
return nil // 哈希获取失败不影响正常使用
}
if remoteHash != localHash {
log.Println("[Pricing] Remote hash differs, downloading new version...")
return s.downloadPricingData()
}
log.Println("[Pricing] Hash check passed, no update needed")
return nil
}
// 没有哈希URL时基于时间检查
info, err := os.Stat(pricingFile)
if err != nil {
return s.downloadPricingData()
}
fileAge := time.Since(info.ModTime())
maxAge := time.Duration(s.cfg.Pricing.UpdateIntervalHours) * time.Hour
if fileAge > maxAge {
log.Printf("[Pricing] File is %v old, downloading...", fileAge.Round(time.Hour))
return s.downloadPricingData()
}
return nil
}
// downloadPricingData 从远程下载价格数据
func (s *PricingService) downloadPricingData() error {
log.Printf("[Pricing] Downloading from %s", s.cfg.Pricing.RemoteURL)
client := &http.Client{Timeout: 30 * time.Second}
resp, err := client.Get(s.cfg.Pricing.RemoteURL)
if err != nil {
return fmt.Errorf("download failed: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("download failed: HTTP %d", resp.StatusCode)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("read response failed: %w", err)
}
// 解析JSON数据使用灵活的解析方式
data, err := s.parsePricingData(body)
if err != nil {
return fmt.Errorf("parse pricing data: %w", err)
}
// 保存到本地文件
pricingFile := s.getPricingFilePath()
if err := os.WriteFile(pricingFile, body, 0644); err != nil {
log.Printf("[Pricing] Failed to save file: %v", err)
}
// 保存哈希
hash := sha256.Sum256(body)
hashStr := hex.EncodeToString(hash[:])
hashFile := s.getHashFilePath()
if err := os.WriteFile(hashFile, []byte(hashStr+"\n"), 0644); err != nil {
log.Printf("[Pricing] Failed to save hash: %v", err)
}
// 更新内存数据
s.mu.Lock()
s.pricingData = data
s.lastUpdated = time.Now()
s.localHash = hashStr
s.mu.Unlock()
log.Printf("[Pricing] Downloaded %d models successfully", len(data))
return nil
}
// parsePricingData 解析价格数据(处理各种格式)
func (s *PricingService) parsePricingData(body []byte) (map[string]*LiteLLMModelPricing, error) {
// 首先解析为 map[string]json.RawMessage
var rawData map[string]json.RawMessage
if err := json.Unmarshal(body, &rawData); err != nil {
return nil, fmt.Errorf("parse raw JSON: %w", err)
}
result := make(map[string]*LiteLLMModelPricing)
skipped := 0
for modelName, rawEntry := range rawData {
// 跳过 sample_spec 等文档条目
if modelName == "sample_spec" {
continue
}
// 尝试解析每个条目
var entry LiteLLMRawEntry
if err := json.Unmarshal(rawEntry, &entry); err != nil {
skipped++
continue
}
// 只保留有有效价格的条目
if entry.InputCostPerToken == nil && entry.OutputCostPerToken == nil {
continue
}
pricing := &LiteLLMModelPricing{
LiteLLMProvider: entry.LiteLLMProvider,
Mode: entry.Mode,
SupportsPromptCaching: entry.SupportsPromptCaching,
}
if entry.InputCostPerToken != nil {
pricing.InputCostPerToken = *entry.InputCostPerToken
}
if entry.OutputCostPerToken != nil {
pricing.OutputCostPerToken = *entry.OutputCostPerToken
}
if entry.CacheCreationInputTokenCost != nil {
pricing.CacheCreationInputTokenCost = *entry.CacheCreationInputTokenCost
}
if entry.CacheReadInputTokenCost != nil {
pricing.CacheReadInputTokenCost = *entry.CacheReadInputTokenCost
}
result[modelName] = pricing
}
if skipped > 0 {
log.Printf("[Pricing] Skipped %d invalid entries", skipped)
}
if len(result) == 0 {
return nil, fmt.Errorf("no valid pricing entries found")
}
return result, nil
}
// loadPricingData 从本地文件加载价格数据
func (s *PricingService) loadPricingData(filePath string) error {
data, err := os.ReadFile(filePath)
if err != nil {
return fmt.Errorf("read file failed: %w", err)
}
// 使用灵活的解析方式
pricingData, err := s.parsePricingData(data)
if err != nil {
return fmt.Errorf("parse pricing data: %w", err)
}
// 计算哈希
hash := sha256.Sum256(data)
hashStr := hex.EncodeToString(hash[:])
s.mu.Lock()
s.pricingData = pricingData
s.localHash = hashStr
info, _ := os.Stat(filePath)
if info != nil {
s.lastUpdated = info.ModTime()
} else {
s.lastUpdated = time.Now()
}
s.mu.Unlock()
log.Printf("[Pricing] Loaded %d models from %s", len(pricingData), filePath)
return nil
}
// useFallbackPricing 使用回退价格文件
func (s *PricingService) useFallbackPricing() error {
fallbackFile := s.cfg.Pricing.FallbackFile
if _, err := os.Stat(fallbackFile); os.IsNotExist(err) {
return fmt.Errorf("fallback file not found: %s", fallbackFile)
}
log.Printf("[Pricing] Using fallback file: %s", fallbackFile)
// 复制到数据目录
data, err := os.ReadFile(fallbackFile)
if err != nil {
return fmt.Errorf("read fallback failed: %w", err)
}
pricingFile := s.getPricingFilePath()
if err := os.WriteFile(pricingFile, data, 0644); err != nil {
log.Printf("[Pricing] Failed to copy fallback: %v", err)
}
return s.loadPricingData(fallbackFile)
}
// fetchRemoteHash 从远程获取哈希值
func (s *PricingService) fetchRemoteHash() (string, error) {
client := &http.Client{Timeout: 10 * time.Second}
resp, err := client.Get(s.cfg.Pricing.HashURL)
if err != nil {
return "", err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("HTTP %d", resp.StatusCode)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
return "", err
}
// 哈希文件格式hash filename 或者纯 hash
hash := strings.TrimSpace(string(body))
parts := strings.Fields(hash)
if len(parts) > 0 {
return parts[0], nil
}
return hash, nil
}
// computeFileHash 计算文件哈希
func (s *PricingService) computeFileHash(filePath string) (string, error) {
data, err := os.ReadFile(filePath)
if err != nil {
return "", err
}
hash := sha256.Sum256(data)
return hex.EncodeToString(hash[:]), nil
}
// GetModelPricing 获取模型价格(带模糊匹配)
func (s *PricingService) GetModelPricing(modelName string) *LiteLLMModelPricing {
s.mu.RLock()
defer s.mu.RUnlock()
if modelName == "" {
return nil
}
// 标准化模型名称
modelLower := strings.ToLower(modelName)
// 1. 精确匹配
if pricing, ok := s.pricingData[modelLower]; ok {
return pricing
}
if pricing, ok := s.pricingData[modelName]; ok {
return pricing
}
// 2. 处理常见的模型名称变体
// claude-opus-4-5-20251101 -> claude-opus-4.5-20251101
normalized := strings.ReplaceAll(modelLower, "-4-5-", "-4.5-")
if pricing, ok := s.pricingData[normalized]; ok {
return pricing
}
// 3. 尝试模糊匹配(去掉版本号后缀)
// claude-opus-4-5-20251101 -> claude-opus-4.5
baseName := s.extractBaseName(modelLower)
for key, pricing := range s.pricingData {
keyBase := s.extractBaseName(strings.ToLower(key))
if keyBase == baseName {
return pricing
}
}
// 4. 基于模型系列匹配
return s.matchByModelFamily(modelLower)
}
// extractBaseName 提取基础模型名称(去掉日期版本号)
func (s *PricingService) extractBaseName(model string) string {
// 移除日期后缀 (如 -20251101, -20241022)
parts := strings.Split(model, "-")
result := make([]string, 0, len(parts))
for _, part := range parts {
// 跳过看起来像日期的部分8位数字
if len(part) == 8 && isNumeric(part) {
continue
}
// 跳过版本号(如 v1:0
if strings.Contains(part, ":") {
continue
}
result = append(result, part)
}
return strings.Join(result, "-")
}
// matchByModelFamily 基于模型系列匹配
func (s *PricingService) matchByModelFamily(model string) *LiteLLMModelPricing {
// Claude模型系列匹配规则
familyPatterns := map[string][]string{
"opus-4.5": {"claude-opus-4.5", "claude-opus-4-5"},
"opus-4": {"claude-opus-4", "claude-3-opus"},
"sonnet-4.5": {"claude-sonnet-4.5", "claude-sonnet-4-5"},
"sonnet-4": {"claude-sonnet-4", "claude-3-5-sonnet"},
"sonnet-3.5": {"claude-3-5-sonnet", "claude-3.5-sonnet"},
"sonnet-3": {"claude-3-sonnet"},
"haiku-3.5": {"claude-3-5-haiku", "claude-3.5-haiku"},
"haiku-3": {"claude-3-haiku"},
}
// 确定模型属于哪个系列
var matchedFamily string
for family, patterns := range familyPatterns {
for _, pattern := range patterns {
if strings.Contains(model, pattern) || strings.Contains(model, strings.ReplaceAll(pattern, "-", "")) {
matchedFamily = family
break
}
}
if matchedFamily != "" {
break
}
}
if matchedFamily == "" {
// 简单的系列匹配
if strings.Contains(model, "opus") {
if strings.Contains(model, "4.5") || strings.Contains(model, "4-5") {
matchedFamily = "opus-4.5"
} else {
matchedFamily = "opus-4"
}
} else if strings.Contains(model, "sonnet") {
if strings.Contains(model, "4.5") || strings.Contains(model, "4-5") {
matchedFamily = "sonnet-4.5"
} else if strings.Contains(model, "3-5") || strings.Contains(model, "3.5") {
matchedFamily = "sonnet-3.5"
} else {
matchedFamily = "sonnet-4"
}
} else if strings.Contains(model, "haiku") {
if strings.Contains(model, "3-5") || strings.Contains(model, "3.5") {
matchedFamily = "haiku-3.5"
} else {
matchedFamily = "haiku-3"
}
}
}
if matchedFamily == "" {
return nil
}
// 在价格数据中查找该系列的模型
patterns := familyPatterns[matchedFamily]
for _, pattern := range patterns {
for key, pricing := range s.pricingData {
keyLower := strings.ToLower(key)
if strings.Contains(keyLower, pattern) {
log.Printf("[Pricing] Fuzzy matched %s -> %s", model, key)
return pricing
}
}
}
return nil
}
// GetStatus 获取服务状态
func (s *PricingService) GetStatus() map[string]interface{} {
s.mu.RLock()
defer s.mu.RUnlock()
return map[string]interface{}{
"model_count": len(s.pricingData),
"last_updated": s.lastUpdated,
"local_hash": s.localHash[:min(8, len(s.localHash))],
}
}
// ForceUpdate 强制更新
func (s *PricingService) ForceUpdate() error {
return s.downloadPricingData()
}
// getPricingFilePath 获取价格文件路径
func (s *PricingService) getPricingFilePath() string {
return filepath.Join(s.cfg.Pricing.DataDir, "model_pricing.json")
}
// getHashFilePath 获取哈希文件路径
func (s *PricingService) getHashFilePath() string {
return filepath.Join(s.cfg.Pricing.DataDir, "model_pricing.sha256")
}
// isNumeric 检查字符串是否为纯数字
func isNumeric(s string) bool {
for _, c := range s {
if c < '0' || c > '9' {
return false
}
}
return true
}

View File

@@ -0,0 +1,192 @@
package service
import (
"context"
"errors"
"fmt"
"sub2api/internal/model"
"sub2api/internal/repository"
"gorm.io/gorm"
)
var (
ErrProxyNotFound = errors.New("proxy not found")
)
// CreateProxyRequest 创建代理请求
type CreateProxyRequest struct {
Name string `json:"name"`
Protocol string `json:"protocol"`
Host string `json:"host"`
Port int `json:"port"`
Username string `json:"username"`
Password string `json:"password"`
}
// UpdateProxyRequest 更新代理请求
type UpdateProxyRequest struct {
Name *string `json:"name"`
Protocol *string `json:"protocol"`
Host *string `json:"host"`
Port *int `json:"port"`
Username *string `json:"username"`
Password *string `json:"password"`
Status *string `json:"status"`
}
// ProxyService 代理管理服务
type ProxyService struct {
proxyRepo *repository.ProxyRepository
}
// NewProxyService 创建代理服务实例
func NewProxyService(proxyRepo *repository.ProxyRepository) *ProxyService {
return &ProxyService{
proxyRepo: proxyRepo,
}
}
// Create 创建代理
func (s *ProxyService) Create(ctx context.Context, req CreateProxyRequest) (*model.Proxy, error) {
// 创建代理
proxy := &model.Proxy{
Name: req.Name,
Protocol: req.Protocol,
Host: req.Host,
Port: req.Port,
Username: req.Username,
Password: req.Password,
Status: model.StatusActive,
}
if err := s.proxyRepo.Create(ctx, proxy); err != nil {
return nil, fmt.Errorf("create proxy: %w", err)
}
return proxy, nil
}
// GetByID 根据ID获取代理
func (s *ProxyService) GetByID(ctx context.Context, id int64) (*model.Proxy, error) {
proxy, err := s.proxyRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrProxyNotFound
}
return nil, fmt.Errorf("get proxy: %w", err)
}
return proxy, nil
}
// List 获取代理列表
func (s *ProxyService) List(ctx context.Context, params repository.PaginationParams) ([]model.Proxy, *repository.PaginationResult, error) {
proxies, pagination, err := s.proxyRepo.List(ctx, params)
if err != nil {
return nil, nil, fmt.Errorf("list proxies: %w", err)
}
return proxies, pagination, nil
}
// ListActive 获取活跃代理列表
func (s *ProxyService) ListActive(ctx context.Context) ([]model.Proxy, error) {
proxies, err := s.proxyRepo.ListActive(ctx)
if err != nil {
return nil, fmt.Errorf("list active proxies: %w", err)
}
return proxies, nil
}
// Update 更新代理
func (s *ProxyService) Update(ctx context.Context, id int64, req UpdateProxyRequest) (*model.Proxy, error) {
proxy, err := s.proxyRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrProxyNotFound
}
return nil, fmt.Errorf("get proxy: %w", err)
}
// 更新字段
if req.Name != nil {
proxy.Name = *req.Name
}
if req.Protocol != nil {
proxy.Protocol = *req.Protocol
}
if req.Host != nil {
proxy.Host = *req.Host
}
if req.Port != nil {
proxy.Port = *req.Port
}
if req.Username != nil {
proxy.Username = *req.Username
}
if req.Password != nil {
proxy.Password = *req.Password
}
if req.Status != nil {
proxy.Status = *req.Status
}
if err := s.proxyRepo.Update(ctx, proxy); err != nil {
return nil, fmt.Errorf("update proxy: %w", err)
}
return proxy, nil
}
// Delete 删除代理
func (s *ProxyService) Delete(ctx context.Context, id int64) error {
// 检查代理是否存在
_, err := s.proxyRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrProxyNotFound
}
return fmt.Errorf("get proxy: %w", err)
}
if err := s.proxyRepo.Delete(ctx, id); err != nil {
return fmt.Errorf("delete proxy: %w", err)
}
return nil
}
// TestConnection 测试代理连接(需要实现具体测试逻辑)
func (s *ProxyService) TestConnection(ctx context.Context, id int64) error {
proxy, err := s.proxyRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrProxyNotFound
}
return fmt.Errorf("get proxy: %w", err)
}
// TODO: 实现代理连接测试逻辑
// 可以尝试通过代理发送测试请求
_ = proxy
return nil
}
// GetURL 获取代理URL
func (s *ProxyService) GetURL(ctx context.Context, id int64) (string, error) {
proxy, err := s.proxyRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return "", ErrProxyNotFound
}
return "", fmt.Errorf("get proxy: %w", err)
}
return proxy.URL(), nil
}

View File

@@ -0,0 +1,170 @@
package service
import (
"context"
"log"
"net/http"
"strconv"
"time"
"sub2api/internal/config"
"sub2api/internal/model"
"sub2api/internal/repository"
)
// RateLimitService 处理限流和过载状态管理
type RateLimitService struct {
repos *repository.Repositories
cfg *config.Config
}
// NewRateLimitService 创建RateLimitService实例
func NewRateLimitService(repos *repository.Repositories, cfg *config.Config) *RateLimitService {
return &RateLimitService{
repos: repos,
cfg: cfg,
}
}
// HandleUpstreamError 处理上游错误响应,标记账号状态
// 返回是否应该停止该账号的调度
func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *model.Account, statusCode int, headers http.Header, responseBody []byte) (shouldDisable bool) {
// apikey 类型账号:检查自定义错误码配置
// 如果启用且错误码不在列表中,则不处理(不停止调度、不标记限流/过载)
if !account.ShouldHandleErrorCode(statusCode) {
log.Printf("Account %d: error %d skipped (not in custom error codes)", account.ID, statusCode)
return false
}
switch statusCode {
case 401:
// 认证失败:停止调度,记录错误
s.handleAuthError(ctx, account, "Authentication failed (401): invalid or expired credentials")
return true
case 403:
// 禁止访问:停止调度,记录错误
s.handleAuthError(ctx, account, "Access forbidden (403): account may be suspended or lack permissions")
return true
case 429:
s.handle429(ctx, account, headers)
return false
case 529:
s.handle529(ctx, account)
return false
default:
// 其他5xx错误记录但不停止调度
if statusCode >= 500 {
log.Printf("Account %d received upstream error %d", account.ID, statusCode)
}
return false
}
}
// handleAuthError 处理认证类错误(401/403),停止账号调度
func (s *RateLimitService) handleAuthError(ctx context.Context, account *model.Account, errorMsg string) {
if err := s.repos.Account.SetError(ctx, account.ID, errorMsg); err != nil {
log.Printf("SetError failed for account %d: %v", account.ID, err)
return
}
log.Printf("Account %d disabled due to auth error: %s", account.ID, errorMsg)
}
// handle429 处理429限流错误
// 解析响应头获取重置时间,标记账号为限流状态
func (s *RateLimitService) handle429(ctx context.Context, account *model.Account, headers http.Header) {
// 解析重置时间戳
resetTimestamp := headers.Get("anthropic-ratelimit-unified-reset")
if resetTimestamp == "" {
// 没有重置时间使用默认5分钟
resetAt := time.Now().Add(5 * time.Minute)
if err := s.repos.Account.SetRateLimited(ctx, account.ID, resetAt); err != nil {
log.Printf("SetRateLimited failed for account %d: %v", account.ID, err)
}
return
}
// 解析Unix时间戳
ts, err := strconv.ParseInt(resetTimestamp, 10, 64)
if err != nil {
log.Printf("Parse reset timestamp failed: %v", err)
resetAt := time.Now().Add(5 * time.Minute)
if err := s.repos.Account.SetRateLimited(ctx, account.ID, resetAt); err != nil {
log.Printf("SetRateLimited failed for account %d: %v", account.ID, err)
}
return
}
resetAt := time.Unix(ts, 0)
// 标记限流状态
if err := s.repos.Account.SetRateLimited(ctx, account.ID, resetAt); err != nil {
log.Printf("SetRateLimited failed for account %d: %v", account.ID, err)
return
}
// 根据重置时间反推5h窗口
windowEnd := resetAt
windowStart := resetAt.Add(-5 * time.Hour)
if err := s.repos.Account.UpdateSessionWindow(ctx, account.ID, &windowStart, &windowEnd, "rejected"); err != nil {
log.Printf("UpdateSessionWindow failed for account %d: %v", account.ID, err)
}
log.Printf("Account %d rate limited until %v", account.ID, resetAt)
}
// handle529 处理529过载错误
// 根据配置设置过载冷却时间
func (s *RateLimitService) handle529(ctx context.Context, account *model.Account) {
cooldownMinutes := s.cfg.RateLimit.OverloadCooldownMinutes
if cooldownMinutes <= 0 {
cooldownMinutes = 10 // 默认10分钟
}
until := time.Now().Add(time.Duration(cooldownMinutes) * time.Minute)
if err := s.repos.Account.SetOverloaded(ctx, account.ID, until); err != nil {
log.Printf("SetOverloaded failed for account %d: %v", account.ID, err)
return
}
log.Printf("Account %d overloaded until %v", account.ID, until)
}
// UpdateSessionWindow 从成功响应更新5h窗口状态
func (s *RateLimitService) UpdateSessionWindow(ctx context.Context, account *model.Account, headers http.Header) {
status := headers.Get("anthropic-ratelimit-unified-5h-status")
if status == "" {
return
}
// 检查是否需要初始化时间窗口
// 对于 Setup Token 账号,首次成功请求时需要预测时间窗口
var windowStart, windowEnd *time.Time
needInitWindow := account.SessionWindowEnd == nil || time.Now().After(*account.SessionWindowEnd)
if needInitWindow && (status == "allowed" || status == "allowed_warning") {
// 预测时间窗口:从当前时间的整点开始,+5小时为结束
// 例如:现在是 14:30窗口为 14:00 ~ 19:00
now := time.Now()
start := time.Date(now.Year(), now.Month(), now.Day(), now.Hour(), 0, 0, 0, now.Location())
end := start.Add(5 * time.Hour)
windowStart = &start
windowEnd = &end
log.Printf("Account %d: initializing 5h window from %v to %v (status: %s)", account.ID, start, end, status)
}
if err := s.repos.Account.UpdateSessionWindow(ctx, account.ID, windowStart, windowEnd, status); err != nil {
log.Printf("UpdateSessionWindow failed for account %d: %v", account.ID, err)
}
// 如果状态为allowed且之前有限流说明窗口已重置清除限流状态
if status == "allowed" && account.IsRateLimited() {
if err := s.repos.Account.ClearRateLimit(ctx, account.ID); err != nil {
log.Printf("ClearRateLimit failed for account %d: %v", account.ID, err)
}
}
}
// ClearRateLimit 清除账号的限流状态
func (s *RateLimitService) ClearRateLimit(ctx context.Context, accountID int64) error {
return s.repos.Account.ClearRateLimit(ctx, accountID)
}

View File

@@ -0,0 +1,392 @@
package service
import (
"context"
"crypto/rand"
"encoding/hex"
"errors"
"fmt"
"strings"
"sub2api/internal/model"
"sub2api/internal/repository"
"time"
"github.com/redis/go-redis/v9"
"gorm.io/gorm"
)
var (
ErrRedeemCodeNotFound = errors.New("redeem code not found")
ErrRedeemCodeUsed = errors.New("redeem code already used")
ErrRedeemCodeInvalid = errors.New("invalid redeem code")
ErrInsufficientBalance = errors.New("insufficient balance")
ErrRedeemRateLimited = errors.New("too many failed attempts, please try again later")
ErrRedeemCodeLocked = errors.New("redeem code is being processed, please try again")
)
const (
redeemRateLimitKeyPrefix = "redeem:rate_limit:"
redeemLockKeyPrefix = "redeem:lock:"
redeemMaxErrorsPerHour = 20
redeemRateLimitDuration = time.Hour
redeemLockDuration = 10 * time.Second // 锁超时时间,防止死锁
)
// GenerateCodesRequest 生成兑换码请求
type GenerateCodesRequest struct {
Count int `json:"count"`
Value float64 `json:"value"`
Type string `json:"type"`
}
// RedeemCodeResponse 兑换码响应
type RedeemCodeResponse struct {
Code string `json:"code"`
Value float64 `json:"value"`
Status string `json:"status"`
CreatedAt time.Time `json:"created_at"`
}
// RedeemService 兑换码服务
type RedeemService struct {
redeemRepo *repository.RedeemCodeRepository
userRepo *repository.UserRepository
subscriptionService *SubscriptionService
rdb *redis.Client
billingCacheService *BillingCacheService
}
// NewRedeemService 创建兑换码服务实例
func NewRedeemService(redeemRepo *repository.RedeemCodeRepository, userRepo *repository.UserRepository, subscriptionService *SubscriptionService, rdb *redis.Client) *RedeemService {
return &RedeemService{
redeemRepo: redeemRepo,
userRepo: userRepo,
subscriptionService: subscriptionService,
rdb: rdb,
}
}
// SetBillingCacheService 设置计费缓存服务(用于缓存失效)
func (s *RedeemService) SetBillingCacheService(billingCacheService *BillingCacheService) {
s.billingCacheService = billingCacheService
}
// GenerateRandomCode 生成随机兑换码
func (s *RedeemService) GenerateRandomCode() (string, error) {
// 生成16字节随机数据
bytes := make([]byte, 16)
if _, err := rand.Read(bytes); err != nil {
return "", fmt.Errorf("generate random bytes: %w", err)
}
// 转换为十六进制字符串
code := hex.EncodeToString(bytes)
// 格式化为 XXXX-XXXX-XXXX-XXXX 格式
parts := []string{
strings.ToUpper(code[0:8]),
strings.ToUpper(code[8:16]),
strings.ToUpper(code[16:24]),
strings.ToUpper(code[24:32]),
}
return strings.Join(parts, "-"), nil
}
// GenerateCodes 批量生成兑换码
func (s *RedeemService) GenerateCodes(ctx context.Context, req GenerateCodesRequest) ([]model.RedeemCode, error) {
if req.Count <= 0 {
return nil, errors.New("count must be greater than 0")
}
if req.Value <= 0 {
return nil, errors.New("value must be greater than 0")
}
if req.Count > 1000 {
return nil, errors.New("cannot generate more than 1000 codes at once")
}
codeType := req.Type
if codeType == "" {
codeType = model.RedeemTypeBalance
}
codes := make([]model.RedeemCode, 0, req.Count)
for i := 0; i < req.Count; i++ {
code, err := s.GenerateRandomCode()
if err != nil {
return nil, fmt.Errorf("generate code: %w", err)
}
codes = append(codes, model.RedeemCode{
Code: code,
Type: codeType,
Value: req.Value,
Status: model.StatusUnused,
})
}
// 批量插入
if err := s.redeemRepo.CreateBatch(ctx, codes); err != nil {
return nil, fmt.Errorf("create batch codes: %w", err)
}
return codes, nil
}
// checkRedeemRateLimit 检查用户兑换错误次数是否超限
func (s *RedeemService) checkRedeemRateLimit(ctx context.Context, userID int64) error {
if s.rdb == nil {
return nil
}
key := fmt.Sprintf("%s%d", redeemRateLimitKeyPrefix, userID)
count, err := s.rdb.Get(ctx, key).Int()
if err != nil && !errors.Is(err, redis.Nil) {
// Redis 出错时不阻止用户操作
return nil
}
if count >= redeemMaxErrorsPerHour {
return ErrRedeemRateLimited
}
return nil
}
// incrementRedeemErrorCount 增加用户兑换错误计数
func (s *RedeemService) incrementRedeemErrorCount(ctx context.Context, userID int64) {
if s.rdb == nil {
return
}
key := fmt.Sprintf("%s%d", redeemRateLimitKeyPrefix, userID)
pipe := s.rdb.Pipeline()
pipe.Incr(ctx, key)
pipe.Expire(ctx, key, redeemRateLimitDuration)
_, _ = pipe.Exec(ctx)
}
// acquireRedeemLock 尝试获取兑换码的分布式锁
// 返回 true 表示获取成功false 表示锁已被占用
func (s *RedeemService) acquireRedeemLock(ctx context.Context, code string) bool {
if s.rdb == nil {
return true // 无 Redis 时降级为不加锁
}
key := redeemLockKeyPrefix + code
ok, err := s.rdb.SetNX(ctx, key, "1", redeemLockDuration).Result()
if err != nil {
// Redis 出错时不阻止操作,依赖数据库层面的状态检查
return true
}
return ok
}
// releaseRedeemLock 释放兑换码的分布式锁
func (s *RedeemService) releaseRedeemLock(ctx context.Context, code string) {
if s.rdb == nil {
return
}
key := redeemLockKeyPrefix + code
s.rdb.Del(ctx, key)
}
// Redeem 使用兑换码
func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (*model.RedeemCode, error) {
// 检查限流
if err := s.checkRedeemRateLimit(ctx, userID); err != nil {
return nil, err
}
// 获取分布式锁,防止同一兑换码并发使用
if !s.acquireRedeemLock(ctx, code) {
return nil, ErrRedeemCodeLocked
}
defer s.releaseRedeemLock(ctx, code)
// 查找兑换码
redeemCode, err := s.redeemRepo.GetByCode(ctx, code)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
s.incrementRedeemErrorCount(ctx, userID)
return nil, ErrRedeemCodeNotFound
}
return nil, fmt.Errorf("get redeem code: %w", err)
}
// 检查兑换码状态
if !redeemCode.CanUse() {
s.incrementRedeemErrorCount(ctx, userID)
return nil, ErrRedeemCodeUsed
}
// 验证兑换码类型的前置条件
if redeemCode.Type == model.RedeemTypeSubscription && redeemCode.GroupID == nil {
return nil, errors.New("invalid subscription redeem code: missing group_id")
}
// 获取用户信息
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrUserNotFound
}
return nil, fmt.Errorf("get user: %w", err)
}
_ = user // 使用变量避免未使用错误
// 【关键】先标记兑换码为已使用,确保并发安全
// 利用数据库乐观锁WHERE status = 'unused')保证原子性
if err := s.redeemRepo.Use(ctx, redeemCode.ID, userID); err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
// 兑换码已被其他请求使用
return nil, ErrRedeemCodeUsed
}
return nil, fmt.Errorf("mark code as used: %w", err)
}
// 执行兑换逻辑(兑换码已被锁定,此时可安全操作)
switch redeemCode.Type {
case model.RedeemTypeBalance:
// 增加用户余额
if err := s.userRepo.UpdateBalance(ctx, userID, redeemCode.Value); err != nil {
return nil, fmt.Errorf("update user balance: %w", err)
}
// 失效余额缓存
if s.billingCacheService != nil {
go func() {
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
s.billingCacheService.InvalidateUserBalance(cacheCtx, userID)
}()
}
case model.RedeemTypeConcurrency:
// 增加用户并发数
if err := s.userRepo.UpdateConcurrency(ctx, userID, int(redeemCode.Value)); err != nil {
return nil, fmt.Errorf("update user concurrency: %w", err)
}
case model.RedeemTypeSubscription:
validityDays := redeemCode.ValidityDays
if validityDays <= 0 {
validityDays = 30
}
_, _, err := s.subscriptionService.AssignOrExtendSubscription(ctx, &AssignSubscriptionInput{
UserID: userID,
GroupID: *redeemCode.GroupID,
ValidityDays: validityDays,
AssignedBy: 0, // 系统分配
Notes: fmt.Sprintf("通过兑换码 %s 兑换", redeemCode.Code),
})
if err != nil {
return nil, fmt.Errorf("assign or extend subscription: %w", err)
}
// 失效订阅缓存
if s.billingCacheService != nil {
groupID := *redeemCode.GroupID
go func() {
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
}()
}
default:
return nil, fmt.Errorf("unsupported redeem type: %s", redeemCode.Type)
}
// 重新获取更新后的兑换码
redeemCode, err = s.redeemRepo.GetByID(ctx, redeemCode.ID)
if err != nil {
return nil, fmt.Errorf("get updated redeem code: %w", err)
}
return redeemCode, nil
}
// GetByID 根据ID获取兑换码
func (s *RedeemService) GetByID(ctx context.Context, id int64) (*model.RedeemCode, error) {
code, err := s.redeemRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrRedeemCodeNotFound
}
return nil, fmt.Errorf("get redeem code: %w", err)
}
return code, nil
}
// GetByCode 根据Code获取兑换码
func (s *RedeemService) GetByCode(ctx context.Context, code string) (*model.RedeemCode, error) {
redeemCode, err := s.redeemRepo.GetByCode(ctx, code)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrRedeemCodeNotFound
}
return nil, fmt.Errorf("get redeem code: %w", err)
}
return redeemCode, nil
}
// List 获取兑换码列表(管理员功能)
func (s *RedeemService) List(ctx context.Context, params repository.PaginationParams) ([]model.RedeemCode, *repository.PaginationResult, error) {
codes, pagination, err := s.redeemRepo.List(ctx, params)
if err != nil {
return nil, nil, fmt.Errorf("list redeem codes: %w", err)
}
return codes, pagination, nil
}
// Delete 删除兑换码(管理员功能)
func (s *RedeemService) Delete(ctx context.Context, id int64) error {
// 检查兑换码是否存在
code, err := s.redeemRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrRedeemCodeNotFound
}
return fmt.Errorf("get redeem code: %w", err)
}
// 不允许删除已使用的兑换码
if code.IsUsed() {
return errors.New("cannot delete used redeem code")
}
if err := s.redeemRepo.Delete(ctx, id); err != nil {
return fmt.Errorf("delete redeem code: %w", err)
}
return nil
}
// GetStats 获取兑换码统计信息
func (s *RedeemService) GetStats(ctx context.Context) (map[string]interface{}, error) {
// TODO: 实现统计逻辑
// 统计未使用、已使用的兑换码数量
// 统计总面值等
stats := map[string]interface{}{
"total_codes": 0,
"unused_codes": 0,
"used_codes": 0,
"total_value": 0.0,
}
return stats, nil
}
// GetUserHistory 获取用户的兑换历史
func (s *RedeemService) GetUserHistory(ctx context.Context, userID int64, limit int) ([]model.RedeemCode, error) {
codes, err := s.redeemRepo.ListByUser(ctx, userID, limit)
if err != nil {
return nil, fmt.Errorf("get user redeem history: %w", err)
}
return codes, nil
}

View File

@@ -0,0 +1,139 @@
package service
import (
"sub2api/internal/config"
"sub2api/internal/repository"
"github.com/redis/go-redis/v9"
)
// Services 服务集合容器
type Services struct {
Auth *AuthService
User *UserService
ApiKey *ApiKeyService
Group *GroupService
Account *AccountService
Proxy *ProxyService
Redeem *RedeemService
Usage *UsageService
Pricing *PricingService
Billing *BillingService
BillingCache *BillingCacheService
Admin AdminService
Gateway *GatewayService
OAuth *OAuthService
RateLimit *RateLimitService
AccountUsage *AccountUsageService
AccountTest *AccountTestService
Setting *SettingService
Email *EmailService
EmailQueue *EmailQueueService
Turnstile *TurnstileService
Subscription *SubscriptionService
Concurrency *ConcurrencyService
Identity *IdentityService
}
// NewServices 创建所有服务实例
func NewServices(repos *repository.Repositories, rdb *redis.Client, cfg *config.Config) *Services {
// 初始化价格服务
pricingService := NewPricingService(cfg)
if err := pricingService.Initialize(); err != nil {
// 价格服务初始化失败不应阻止启动,使用回退价格
println("[Service] Warning: Pricing service initialization failed:", err.Error())
}
// 初始化计费服务(依赖价格服务)
billingService := NewBillingService(cfg, pricingService)
// 初始化其他服务
authService := NewAuthService(repos.User, cfg)
userService := NewUserService(repos.User, cfg)
apiKeyService := NewApiKeyService(repos.ApiKey, repos.User, repos.Group, repos.UserSubscription, rdb, cfg)
groupService := NewGroupService(repos.Group)
accountService := NewAccountService(repos.Account, repos.Group)
proxyService := NewProxyService(repos.Proxy)
usageService := NewUsageService(repos.UsageLog, repos.User)
// 初始化订阅服务 (RedeemService 依赖)
subscriptionService := NewSubscriptionService(repos)
// 初始化兑换服务 (依赖订阅服务)
redeemService := NewRedeemService(repos.RedeemCode, repos.User, subscriptionService, rdb)
// 初始化Admin服务
adminService := NewAdminService(repos)
// 初始化OAuth服务GatewayService依赖
oauthService := NewOAuthService(repos.Proxy)
// 初始化限流服务
rateLimitService := NewRateLimitService(repos, cfg)
// 初始化计费缓存服务
billingCacheService := NewBillingCacheService(rdb, repos.User, repos.UserSubscription)
// 初始化账号使用量服务
accountUsageService := NewAccountUsageService(repos, oauthService)
// 初始化账号测试服务
accountTestService := NewAccountTestService(repos, oauthService)
// 初始化身份指纹服务
identityService := NewIdentityService(rdb)
// 初始化Gateway服务
gatewayService := NewGatewayService(repos, rdb, cfg, oauthService, billingService, rateLimitService, billingCacheService, identityService)
// 初始化设置服务
settingService := NewSettingService(repos.Setting, cfg)
emailService := NewEmailService(repos.Setting, rdb)
// 初始化邮件队列服务
emailQueueService := NewEmailQueueService(emailService, 3)
// 初始化Turnstile服务
turnstileService := NewTurnstileService(settingService)
// 设置Auth服务的依赖用于注册开关和邮件验证
authService.SetSettingService(settingService)
authService.SetEmailService(emailService)
authService.SetTurnstileService(turnstileService)
authService.SetEmailQueueService(emailQueueService)
// 初始化并发控制服务
concurrencyService := NewConcurrencyService(rdb)
// 注入计费缓存服务到需要失效缓存的服务
redeemService.SetBillingCacheService(billingCacheService)
subscriptionService.SetBillingCacheService(billingCacheService)
SetAdminServiceBillingCache(adminService, billingCacheService)
return &Services{
Auth: authService,
User: userService,
ApiKey: apiKeyService,
Group: groupService,
Account: accountService,
Proxy: proxyService,
Redeem: redeemService,
Usage: usageService,
Pricing: pricingService,
Billing: billingService,
BillingCache: billingCacheService,
Admin: adminService,
Gateway: gatewayService,
OAuth: oauthService,
RateLimit: rateLimitService,
AccountUsage: accountUsageService,
AccountTest: accountTestService,
Setting: settingService,
Email: emailService,
EmailQueue: emailQueueService,
Turnstile: turnstileService,
Subscription: subscriptionService,
Concurrency: concurrencyService,
Identity: identityService,
}
}

View File

@@ -0,0 +1,264 @@
package service
import (
"context"
"errors"
"fmt"
"strconv"
"sub2api/internal/config"
"sub2api/internal/model"
"sub2api/internal/repository"
"gorm.io/gorm"
)
var (
ErrRegistrationDisabled = errors.New("registration is currently disabled")
)
// SettingService 系统设置服务
type SettingService struct {
settingRepo *repository.SettingRepository
cfg *config.Config
}
// NewSettingService 创建系统设置服务实例
func NewSettingService(settingRepo *repository.SettingRepository, cfg *config.Config) *SettingService {
return &SettingService{
settingRepo: settingRepo,
cfg: cfg,
}
}
// GetAllSettings 获取所有系统设置
func (s *SettingService) GetAllSettings(ctx context.Context) (*model.SystemSettings, error) {
settings, err := s.settingRepo.GetAll(ctx)
if err != nil {
return nil, fmt.Errorf("get all settings: %w", err)
}
return s.parseSettings(settings), nil
}
// GetPublicSettings 获取公开设置(无需登录)
func (s *SettingService) GetPublicSettings(ctx context.Context) (*model.PublicSettings, error) {
keys := []string{
model.SettingKeyRegistrationEnabled,
model.SettingKeyEmailVerifyEnabled,
model.SettingKeyTurnstileEnabled,
model.SettingKeyTurnstileSiteKey,
model.SettingKeySiteName,
model.SettingKeySiteLogo,
model.SettingKeySiteSubtitle,
model.SettingKeyApiBaseUrl,
model.SettingKeyContactInfo,
}
settings, err := s.settingRepo.GetMultiple(ctx, keys)
if err != nil {
return nil, fmt.Errorf("get public settings: %w", err)
}
return &model.PublicSettings{
RegistrationEnabled: settings[model.SettingKeyRegistrationEnabled] == "true",
EmailVerifyEnabled: settings[model.SettingKeyEmailVerifyEnabled] == "true",
TurnstileEnabled: settings[model.SettingKeyTurnstileEnabled] == "true",
TurnstileSiteKey: settings[model.SettingKeyTurnstileSiteKey],
SiteName: s.getStringOrDefault(settings, model.SettingKeySiteName, "Sub2API"),
SiteLogo: settings[model.SettingKeySiteLogo],
SiteSubtitle: s.getStringOrDefault(settings, model.SettingKeySiteSubtitle, "Subscription to API Conversion Platform"),
ApiBaseUrl: settings[model.SettingKeyApiBaseUrl],
ContactInfo: settings[model.SettingKeyContactInfo],
}, nil
}
// UpdateSettings 更新系统设置
func (s *SettingService) UpdateSettings(ctx context.Context, settings *model.SystemSettings) error {
updates := make(map[string]string)
// 注册设置
updates[model.SettingKeyRegistrationEnabled] = strconv.FormatBool(settings.RegistrationEnabled)
updates[model.SettingKeyEmailVerifyEnabled] = strconv.FormatBool(settings.EmailVerifyEnabled)
// 邮件服务设置(只有非空才更新密码)
updates[model.SettingKeySmtpHost] = settings.SmtpHost
updates[model.SettingKeySmtpPort] = strconv.Itoa(settings.SmtpPort)
updates[model.SettingKeySmtpUsername] = settings.SmtpUsername
if settings.SmtpPassword != "" {
updates[model.SettingKeySmtpPassword] = settings.SmtpPassword
}
updates[model.SettingKeySmtpFrom] = settings.SmtpFrom
updates[model.SettingKeySmtpFromName] = settings.SmtpFromName
updates[model.SettingKeySmtpUseTLS] = strconv.FormatBool(settings.SmtpUseTLS)
// Cloudflare Turnstile 设置(只有非空才更新密钥)
updates[model.SettingKeyTurnstileEnabled] = strconv.FormatBool(settings.TurnstileEnabled)
updates[model.SettingKeyTurnstileSiteKey] = settings.TurnstileSiteKey
if settings.TurnstileSecretKey != "" {
updates[model.SettingKeyTurnstileSecretKey] = settings.TurnstileSecretKey
}
// OEM设置
updates[model.SettingKeySiteName] = settings.SiteName
updates[model.SettingKeySiteLogo] = settings.SiteLogo
updates[model.SettingKeySiteSubtitle] = settings.SiteSubtitle
updates[model.SettingKeyApiBaseUrl] = settings.ApiBaseUrl
updates[model.SettingKeyContactInfo] = settings.ContactInfo
// 默认配置
updates[model.SettingKeyDefaultConcurrency] = strconv.Itoa(settings.DefaultConcurrency)
updates[model.SettingKeyDefaultBalance] = strconv.FormatFloat(settings.DefaultBalance, 'f', 8, 64)
return s.settingRepo.SetMultiple(ctx, updates)
}
// IsRegistrationEnabled 检查是否开放注册
func (s *SettingService) IsRegistrationEnabled(ctx context.Context) bool {
value, err := s.settingRepo.GetValue(ctx, model.SettingKeyRegistrationEnabled)
if err != nil {
// 默认开放注册
return true
}
return value == "true"
}
// IsEmailVerifyEnabled 检查是否开启邮件验证
func (s *SettingService) IsEmailVerifyEnabled(ctx context.Context) bool {
value, err := s.settingRepo.GetValue(ctx, model.SettingKeyEmailVerifyEnabled)
if err != nil {
return false
}
return value == "true"
}
// GetSiteName 获取网站名称
func (s *SettingService) GetSiteName(ctx context.Context) string {
value, err := s.settingRepo.GetValue(ctx, model.SettingKeySiteName)
if err != nil || value == "" {
return "Sub2API"
}
return value
}
// GetDefaultConcurrency 获取默认并发量
func (s *SettingService) GetDefaultConcurrency(ctx context.Context) int {
value, err := s.settingRepo.GetValue(ctx, model.SettingKeyDefaultConcurrency)
if err != nil {
return s.cfg.Default.UserConcurrency
}
if v, err := strconv.Atoi(value); err == nil && v > 0 {
return v
}
return s.cfg.Default.UserConcurrency
}
// GetDefaultBalance 获取默认余额
func (s *SettingService) GetDefaultBalance(ctx context.Context) float64 {
value, err := s.settingRepo.GetValue(ctx, model.SettingKeyDefaultBalance)
if err != nil {
return s.cfg.Default.UserBalance
}
if v, err := strconv.ParseFloat(value, 64); err == nil && v >= 0 {
return v
}
return s.cfg.Default.UserBalance
}
// InitializeDefaultSettings 初始化默认设置
func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
// 检查是否已有设置
_, err := s.settingRepo.GetValue(ctx, model.SettingKeyRegistrationEnabled)
if err == nil {
// 已有设置,不需要初始化
return nil
}
if !errors.Is(err, gorm.ErrRecordNotFound) {
return fmt.Errorf("check existing settings: %w", err)
}
// 初始化默认设置
defaults := map[string]string{
model.SettingKeyRegistrationEnabled: "true",
model.SettingKeyEmailVerifyEnabled: "false",
model.SettingKeySiteName: "Sub2API",
model.SettingKeySiteLogo: "",
model.SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency),
model.SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64),
model.SettingKeySmtpPort: "587",
model.SettingKeySmtpUseTLS: "false",
}
return s.settingRepo.SetMultiple(ctx, defaults)
}
// parseSettings 解析设置到结构体
func (s *SettingService) parseSettings(settings map[string]string) *model.SystemSettings {
result := &model.SystemSettings{
RegistrationEnabled: settings[model.SettingKeyRegistrationEnabled] == "true",
EmailVerifyEnabled: settings[model.SettingKeyEmailVerifyEnabled] == "true",
SmtpHost: settings[model.SettingKeySmtpHost],
SmtpUsername: settings[model.SettingKeySmtpUsername],
SmtpFrom: settings[model.SettingKeySmtpFrom],
SmtpFromName: settings[model.SettingKeySmtpFromName],
SmtpUseTLS: settings[model.SettingKeySmtpUseTLS] == "true",
TurnstileEnabled: settings[model.SettingKeyTurnstileEnabled] == "true",
TurnstileSiteKey: settings[model.SettingKeyTurnstileSiteKey],
SiteName: s.getStringOrDefault(settings, model.SettingKeySiteName, "Sub2API"),
SiteLogo: settings[model.SettingKeySiteLogo],
SiteSubtitle: s.getStringOrDefault(settings, model.SettingKeySiteSubtitle, "Subscription to API Conversion Platform"),
ApiBaseUrl: settings[model.SettingKeyApiBaseUrl],
ContactInfo: settings[model.SettingKeyContactInfo],
}
// 解析整数类型
if port, err := strconv.Atoi(settings[model.SettingKeySmtpPort]); err == nil {
result.SmtpPort = port
} else {
result.SmtpPort = 587
}
if concurrency, err := strconv.Atoi(settings[model.SettingKeyDefaultConcurrency]); err == nil {
result.DefaultConcurrency = concurrency
} else {
result.DefaultConcurrency = s.cfg.Default.UserConcurrency
}
// 解析浮点数类型
if balance, err := strconv.ParseFloat(settings[model.SettingKeyDefaultBalance], 64); err == nil {
result.DefaultBalance = balance
} else {
result.DefaultBalance = s.cfg.Default.UserBalance
}
// 敏感信息直接返回,方便测试连接时使用
result.SmtpPassword = settings[model.SettingKeySmtpPassword]
result.TurnstileSecretKey = settings[model.SettingKeyTurnstileSecretKey]
return result
}
// getStringOrDefault 获取字符串值或默认值
func (s *SettingService) getStringOrDefault(settings map[string]string, key, defaultValue string) string {
if value, ok := settings[key]; ok && value != "" {
return value
}
return defaultValue
}
// IsTurnstileEnabled 检查是否启用 Turnstile 验证
func (s *SettingService) IsTurnstileEnabled(ctx context.Context) bool {
value, err := s.settingRepo.GetValue(ctx, model.SettingKeyTurnstileEnabled)
if err != nil {
return false
}
return value == "true"
}
// GetTurnstileSecretKey 获取 Turnstile Secret Key
func (s *SettingService) GetTurnstileSecretKey(ctx context.Context) string {
value, err := s.settingRepo.GetValue(ctx, model.SettingKeyTurnstileSecretKey)
if err != nil {
return ""
}
return value
}

View File

@@ -0,0 +1,575 @@
package service
import (
"context"
"errors"
"fmt"
"time"
"sub2api/internal/model"
"sub2api/internal/repository"
)
var (
ErrSubscriptionNotFound = errors.New("subscription not found")
ErrSubscriptionExpired = errors.New("subscription has expired")
ErrSubscriptionSuspended = errors.New("subscription is suspended")
ErrSubscriptionAlreadyExists = errors.New("subscription already exists for this user and group")
ErrGroupNotSubscriptionType = errors.New("group is not a subscription type")
ErrDailyLimitExceeded = errors.New("daily usage limit exceeded")
ErrWeeklyLimitExceeded = errors.New("weekly usage limit exceeded")
ErrMonthlyLimitExceeded = errors.New("monthly usage limit exceeded")
)
// SubscriptionService 订阅服务
type SubscriptionService struct {
repos *repository.Repositories
billingCacheService *BillingCacheService
}
// NewSubscriptionService 创建订阅服务
func NewSubscriptionService(repos *repository.Repositories) *SubscriptionService {
return &SubscriptionService{repos: repos}
}
// SetBillingCacheService 设置计费缓存服务(用于缓存失效)
func (s *SubscriptionService) SetBillingCacheService(billingCacheService *BillingCacheService) {
s.billingCacheService = billingCacheService
}
// AssignSubscriptionInput 分配订阅输入
type AssignSubscriptionInput struct {
UserID int64
GroupID int64
ValidityDays int
AssignedBy int64
Notes string
}
// AssignSubscription 分配订阅给用户(不允许重复分配)
func (s *SubscriptionService) AssignSubscription(ctx context.Context, input *AssignSubscriptionInput) (*model.UserSubscription, error) {
// 检查分组是否存在且为订阅类型
group, err := s.repos.Group.GetByID(ctx, input.GroupID)
if err != nil {
return nil, fmt.Errorf("group not found: %w", err)
}
if !group.IsSubscriptionType() {
return nil, ErrGroupNotSubscriptionType
}
// 检查是否已存在订阅
exists, err := s.repos.UserSubscription.ExistsByUserIDAndGroupID(ctx, input.UserID, input.GroupID)
if err != nil {
return nil, err
}
if exists {
return nil, ErrSubscriptionAlreadyExists
}
sub, err := s.createSubscription(ctx, input)
if err != nil {
return nil, err
}
// 失效订阅缓存
if s.billingCacheService != nil {
userID, groupID := input.UserID, input.GroupID
go func() {
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
}()
}
return sub, nil
}
// AssignOrExtendSubscription 分配或续期订阅(用于兑换码等场景)
// 如果用户已有同分组的订阅:
// - 未过期:从当前过期时间累加天数
// - 已过期:从当前时间开始计算新的过期时间,并激活订阅
// 如果没有订阅:创建新订阅
func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, input *AssignSubscriptionInput) (*model.UserSubscription, bool, error) {
// 检查分组是否存在且为订阅类型
group, err := s.repos.Group.GetByID(ctx, input.GroupID)
if err != nil {
return nil, false, fmt.Errorf("group not found: %w", err)
}
if !group.IsSubscriptionType() {
return nil, false, ErrGroupNotSubscriptionType
}
// 查询是否已有订阅
existingSub, err := s.repos.UserSubscription.GetByUserIDAndGroupID(ctx, input.UserID, input.GroupID)
if err != nil {
// 不存在记录是正常情况,其他错误需要返回
existingSub = nil
}
validityDays := input.ValidityDays
if validityDays <= 0 {
validityDays = 30
}
// 已有订阅,执行续期
if existingSub != nil {
now := time.Now()
var newExpiresAt time.Time
if existingSub.ExpiresAt.After(now) {
// 未过期:从当前过期时间累加
newExpiresAt = existingSub.ExpiresAt.AddDate(0, 0, validityDays)
} else {
// 已过期:从当前时间开始计算
newExpiresAt = now.AddDate(0, 0, validityDays)
}
// 更新过期时间
if err := s.repos.UserSubscription.ExtendExpiry(ctx, existingSub.ID, newExpiresAt); err != nil {
return nil, false, fmt.Errorf("extend subscription: %w", err)
}
// 如果订阅已过期或被暂停恢复为active状态
if existingSub.Status != model.SubscriptionStatusActive {
if err := s.repos.UserSubscription.UpdateStatus(ctx, existingSub.ID, model.SubscriptionStatusActive); err != nil {
return nil, false, fmt.Errorf("update subscription status: %w", err)
}
}
// 追加备注
if input.Notes != "" {
newNotes := existingSub.Notes
if newNotes != "" {
newNotes += "\n"
}
newNotes += input.Notes
if err := s.repos.UserSubscription.UpdateNotes(ctx, existingSub.ID, newNotes); err != nil {
// 备注更新失败不影响主流程
}
}
// 失效订阅缓存
if s.billingCacheService != nil {
userID, groupID := input.UserID, input.GroupID
go func() {
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
}()
}
// 返回更新后的订阅
sub, err := s.repos.UserSubscription.GetByID(ctx, existingSub.ID)
return sub, true, err // true 表示是续期
}
// 没有订阅,创建新订阅
sub, err := s.createSubscription(ctx, input)
if err != nil {
return nil, false, err
}
// 失效订阅缓存
if s.billingCacheService != nil {
userID, groupID := input.UserID, input.GroupID
go func() {
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
}()
}
return sub, false, nil // false 表示是新建
}
// createSubscription 创建新订阅(内部方法)
func (s *SubscriptionService) createSubscription(ctx context.Context, input *AssignSubscriptionInput) (*model.UserSubscription, error) {
validityDays := input.ValidityDays
if validityDays <= 0 {
validityDays = 30
}
now := time.Now()
sub := &model.UserSubscription{
UserID: input.UserID,
GroupID: input.GroupID,
StartsAt: now,
ExpiresAt: now.AddDate(0, 0, validityDays),
Status: model.SubscriptionStatusActive,
AssignedAt: now,
Notes: input.Notes,
CreatedAt: now,
UpdatedAt: now,
}
// 只有当 AssignedBy > 0 时才设置0 表示系统分配,如兑换码)
if input.AssignedBy > 0 {
sub.AssignedBy = &input.AssignedBy
}
if err := s.repos.UserSubscription.Create(ctx, sub); err != nil {
return nil, err
}
// 重新获取完整订阅信息(包含关联)
return s.repos.UserSubscription.GetByID(ctx, sub.ID)
}
// BulkAssignSubscriptionInput 批量分配订阅输入
type BulkAssignSubscriptionInput struct {
UserIDs []int64
GroupID int64
ValidityDays int
AssignedBy int64
Notes string
}
// BulkAssignResult 批量分配结果
type BulkAssignResult struct {
SuccessCount int
FailedCount int
Subscriptions []model.UserSubscription
Errors []string
}
// BulkAssignSubscription 批量分配订阅
func (s *SubscriptionService) BulkAssignSubscription(ctx context.Context, input *BulkAssignSubscriptionInput) (*BulkAssignResult, error) {
result := &BulkAssignResult{
Subscriptions: make([]model.UserSubscription, 0),
Errors: make([]string, 0),
}
for _, userID := range input.UserIDs {
sub, err := s.AssignSubscription(ctx, &AssignSubscriptionInput{
UserID: userID,
GroupID: input.GroupID,
ValidityDays: input.ValidityDays,
AssignedBy: input.AssignedBy,
Notes: input.Notes,
})
if err != nil {
result.FailedCount++
result.Errors = append(result.Errors, fmt.Sprintf("user %d: %v", userID, err))
} else {
result.SuccessCount++
result.Subscriptions = append(result.Subscriptions, *sub)
}
}
return result, nil
}
// RevokeSubscription 撤销订阅
func (s *SubscriptionService) RevokeSubscription(ctx context.Context, subscriptionID int64) error {
// 先获取订阅信息用于失效缓存
sub, err := s.repos.UserSubscription.GetByID(ctx, subscriptionID)
if err != nil {
return err
}
if err := s.repos.UserSubscription.Delete(ctx, subscriptionID); err != nil {
return err
}
// 失效订阅缓存
if s.billingCacheService != nil {
userID, groupID := sub.UserID, sub.GroupID
go func() {
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
}()
}
return nil
}
// ExtendSubscription 延长订阅
func (s *SubscriptionService) ExtendSubscription(ctx context.Context, subscriptionID int64, days int) (*model.UserSubscription, error) {
sub, err := s.repos.UserSubscription.GetByID(ctx, subscriptionID)
if err != nil {
return nil, ErrSubscriptionNotFound
}
// 计算新的过期时间
newExpiresAt := sub.ExpiresAt.AddDate(0, 0, days)
if err := s.repos.UserSubscription.ExtendExpiry(ctx, subscriptionID, newExpiresAt); err != nil {
return nil, err
}
// 如果订阅已过期恢复为active状态
if sub.Status == model.SubscriptionStatusExpired {
if err := s.repos.UserSubscription.UpdateStatus(ctx, subscriptionID, model.SubscriptionStatusActive); err != nil {
return nil, err
}
}
// 失效订阅缓存
if s.billingCacheService != nil {
userID, groupID := sub.UserID, sub.GroupID
go func() {
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
}()
}
return s.repos.UserSubscription.GetByID(ctx, subscriptionID)
}
// GetByID 根据ID获取订阅
func (s *SubscriptionService) GetByID(ctx context.Context, id int64) (*model.UserSubscription, error) {
return s.repos.UserSubscription.GetByID(ctx, id)
}
// GetActiveSubscription 获取用户对特定分组的有效订阅
func (s *SubscriptionService) GetActiveSubscription(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error) {
sub, err := s.repos.UserSubscription.GetActiveByUserIDAndGroupID(ctx, userID, groupID)
if err != nil {
return nil, ErrSubscriptionNotFound
}
return sub, nil
}
// ListUserSubscriptions 获取用户的所有订阅
func (s *SubscriptionService) ListUserSubscriptions(ctx context.Context, userID int64) ([]model.UserSubscription, error) {
return s.repos.UserSubscription.ListByUserID(ctx, userID)
}
// ListActiveUserSubscriptions 获取用户的所有有效订阅
func (s *SubscriptionService) ListActiveUserSubscriptions(ctx context.Context, userID int64) ([]model.UserSubscription, error) {
return s.repos.UserSubscription.ListActiveByUserID(ctx, userID)
}
// ListGroupSubscriptions 获取分组的所有订阅
func (s *SubscriptionService) ListGroupSubscriptions(ctx context.Context, groupID int64, page, pageSize int) ([]model.UserSubscription, *repository.PaginationResult, error) {
params := repository.PaginationParams{Page: page, PageSize: pageSize}
return s.repos.UserSubscription.ListByGroupID(ctx, groupID, params)
}
// List 获取所有订阅(分页,支持筛选)
func (s *SubscriptionService) List(ctx context.Context, page, pageSize int, userID, groupID *int64, status string) ([]model.UserSubscription, *repository.PaginationResult, error) {
params := repository.PaginationParams{Page: page, PageSize: pageSize}
return s.repos.UserSubscription.List(ctx, params, userID, groupID, status)
}
// CheckAndActivateWindow 检查并激活窗口(首次使用时)
func (s *SubscriptionService) CheckAndActivateWindow(ctx context.Context, sub *model.UserSubscription) error {
if sub.IsWindowActivated() {
return nil
}
now := time.Now()
return s.repos.UserSubscription.ActivateWindows(ctx, sub.ID, now)
}
// CheckAndResetWindows 检查并重置过期的窗口
func (s *SubscriptionService) CheckAndResetWindows(ctx context.Context, sub *model.UserSubscription) error {
now := time.Now()
// 日窗口重置24小时
if sub.NeedsDailyReset() {
if err := s.repos.UserSubscription.ResetDailyUsage(ctx, sub.ID, now); err != nil {
return err
}
sub.DailyWindowStart = &now
sub.DailyUsageUSD = 0
}
// 周窗口重置7天
if sub.NeedsWeeklyReset() {
if err := s.repos.UserSubscription.ResetWeeklyUsage(ctx, sub.ID, now); err != nil {
return err
}
sub.WeeklyWindowStart = &now
sub.WeeklyUsageUSD = 0
}
// 月窗口重置30天
if sub.NeedsMonthlyReset() {
if err := s.repos.UserSubscription.ResetMonthlyUsage(ctx, sub.ID, now); err != nil {
return err
}
sub.MonthlyWindowStart = &now
sub.MonthlyUsageUSD = 0
}
return nil
}
// CheckUsageLimits 检查使用限额(返回错误如果超限)
func (s *SubscriptionService) CheckUsageLimits(ctx context.Context, sub *model.UserSubscription, group *model.Group, additionalCost float64) error {
if !sub.CheckDailyLimit(group, additionalCost) {
return ErrDailyLimitExceeded
}
if !sub.CheckWeeklyLimit(group, additionalCost) {
return ErrWeeklyLimitExceeded
}
if !sub.CheckMonthlyLimit(group, additionalCost) {
return ErrMonthlyLimitExceeded
}
return nil
}
// RecordUsage 记录使用量到订阅
func (s *SubscriptionService) RecordUsage(ctx context.Context, subscriptionID int64, costUSD float64) error {
return s.repos.UserSubscription.IncrementUsage(ctx, subscriptionID, costUSD)
}
// SubscriptionProgress 订阅进度
type SubscriptionProgress struct {
ID int64 `json:"id"`
GroupName string `json:"group_name"`
ExpiresAt time.Time `json:"expires_at"`
ExpiresInDays int `json:"expires_in_days"`
Daily *UsageWindowProgress `json:"daily,omitempty"`
Weekly *UsageWindowProgress `json:"weekly,omitempty"`
Monthly *UsageWindowProgress `json:"monthly,omitempty"`
}
// UsageWindowProgress 使用窗口进度
type UsageWindowProgress struct {
LimitUSD float64 `json:"limit_usd"`
UsedUSD float64 `json:"used_usd"`
RemainingUSD float64 `json:"remaining_usd"`
Percentage float64 `json:"percentage"`
WindowStart time.Time `json:"window_start"`
ResetsAt time.Time `json:"resets_at"`
ResetsInSeconds int64 `json:"resets_in_seconds"`
}
// GetSubscriptionProgress 获取订阅使用进度
func (s *SubscriptionService) GetSubscriptionProgress(ctx context.Context, subscriptionID int64) (*SubscriptionProgress, error) {
sub, err := s.repos.UserSubscription.GetByID(ctx, subscriptionID)
if err != nil {
return nil, ErrSubscriptionNotFound
}
group := sub.Group
if group == nil {
group, err = s.repos.Group.GetByID(ctx, sub.GroupID)
if err != nil {
return nil, err
}
}
progress := &SubscriptionProgress{
ID: sub.ID,
GroupName: group.Name,
ExpiresAt: sub.ExpiresAt,
ExpiresInDays: sub.DaysRemaining(),
}
// 日进度
if group.HasDailyLimit() && sub.DailyWindowStart != nil {
limit := *group.DailyLimitUSD
resetsAt := sub.DailyWindowStart.Add(24 * time.Hour)
progress.Daily = &UsageWindowProgress{
LimitUSD: limit,
UsedUSD: sub.DailyUsageUSD,
RemainingUSD: limit - sub.DailyUsageUSD,
Percentage: (sub.DailyUsageUSD / limit) * 100,
WindowStart: *sub.DailyWindowStart,
ResetsAt: resetsAt,
ResetsInSeconds: int64(time.Until(resetsAt).Seconds()),
}
if progress.Daily.RemainingUSD < 0 {
progress.Daily.RemainingUSD = 0
}
if progress.Daily.Percentage > 100 {
progress.Daily.Percentage = 100
}
if progress.Daily.ResetsInSeconds < 0 {
progress.Daily.ResetsInSeconds = 0
}
}
// 周进度
if group.HasWeeklyLimit() && sub.WeeklyWindowStart != nil {
limit := *group.WeeklyLimitUSD
resetsAt := sub.WeeklyWindowStart.Add(7 * 24 * time.Hour)
progress.Weekly = &UsageWindowProgress{
LimitUSD: limit,
UsedUSD: sub.WeeklyUsageUSD,
RemainingUSD: limit - sub.WeeklyUsageUSD,
Percentage: (sub.WeeklyUsageUSD / limit) * 100,
WindowStart: *sub.WeeklyWindowStart,
ResetsAt: resetsAt,
ResetsInSeconds: int64(time.Until(resetsAt).Seconds()),
}
if progress.Weekly.RemainingUSD < 0 {
progress.Weekly.RemainingUSD = 0
}
if progress.Weekly.Percentage > 100 {
progress.Weekly.Percentage = 100
}
if progress.Weekly.ResetsInSeconds < 0 {
progress.Weekly.ResetsInSeconds = 0
}
}
// 月进度
if group.HasMonthlyLimit() && sub.MonthlyWindowStart != nil {
limit := *group.MonthlyLimitUSD
resetsAt := sub.MonthlyWindowStart.Add(30 * 24 * time.Hour)
progress.Monthly = &UsageWindowProgress{
LimitUSD: limit,
UsedUSD: sub.MonthlyUsageUSD,
RemainingUSD: limit - sub.MonthlyUsageUSD,
Percentage: (sub.MonthlyUsageUSD / limit) * 100,
WindowStart: *sub.MonthlyWindowStart,
ResetsAt: resetsAt,
ResetsInSeconds: int64(time.Until(resetsAt).Seconds()),
}
if progress.Monthly.RemainingUSD < 0 {
progress.Monthly.RemainingUSD = 0
}
if progress.Monthly.Percentage > 100 {
progress.Monthly.Percentage = 100
}
if progress.Monthly.ResetsInSeconds < 0 {
progress.Monthly.ResetsInSeconds = 0
}
}
return progress, nil
}
// GetUserSubscriptionsWithProgress 获取用户所有订阅及进度
func (s *SubscriptionService) GetUserSubscriptionsWithProgress(ctx context.Context, userID int64) ([]SubscriptionProgress, error) {
subs, err := s.repos.UserSubscription.ListActiveByUserID(ctx, userID)
if err != nil {
return nil, err
}
progresses := make([]SubscriptionProgress, 0, len(subs))
for _, sub := range subs {
progress, err := s.GetSubscriptionProgress(ctx, sub.ID)
if err != nil {
continue
}
progresses = append(progresses, *progress)
}
return progresses, nil
}
// UpdateExpiredSubscriptions 更新过期订阅状态(定时任务调用)
func (s *SubscriptionService) UpdateExpiredSubscriptions(ctx context.Context) (int64, error) {
return s.repos.UserSubscription.BatchUpdateExpiredStatus(ctx)
}
// ValidateSubscription 验证订阅是否有效
func (s *SubscriptionService) ValidateSubscription(ctx context.Context, sub *model.UserSubscription) error {
if sub.Status == model.SubscriptionStatusExpired {
return ErrSubscriptionExpired
}
if sub.Status == model.SubscriptionStatusSuspended {
return ErrSubscriptionSuspended
}
if sub.IsExpired() {
// 更新状态
_ = s.repos.UserSubscription.UpdateStatus(ctx, sub.ID, model.SubscriptionStatusExpired)
return ErrSubscriptionExpired
}
return nil
}

View File

@@ -0,0 +1,111 @@
package service
import (
"context"
"encoding/json"
"errors"
"fmt"
"log"
"net/http"
"net/url"
"strings"
"time"
)
var (
ErrTurnstileVerificationFailed = errors.New("turnstile verification failed")
ErrTurnstileNotConfigured = errors.New("turnstile not configured")
)
const turnstileVerifyURL = "https://challenges.cloudflare.com/turnstile/v0/siteverify"
// TurnstileService Turnstile 验证服务
type TurnstileService struct {
settingService *SettingService
httpClient *http.Client
}
// TurnstileVerifyResponse Cloudflare Turnstile 验证响应
type TurnstileVerifyResponse struct {
Success bool `json:"success"`
ChallengeTS string `json:"challenge_ts"`
Hostname string `json:"hostname"`
ErrorCodes []string `json:"error-codes"`
Action string `json:"action"`
CData string `json:"cdata"`
}
// NewTurnstileService 创建 Turnstile 服务实例
func NewTurnstileService(settingService *SettingService) *TurnstileService {
return &TurnstileService{
settingService: settingService,
httpClient: &http.Client{
Timeout: 10 * time.Second,
},
}
}
// VerifyToken 验证 Turnstile token
func (s *TurnstileService) VerifyToken(ctx context.Context, token string, remoteIP string) error {
// 检查是否启用 Turnstile
if !s.settingService.IsTurnstileEnabled(ctx) {
log.Println("[Turnstile] Disabled, skipping verification")
return nil
}
// 获取 Secret Key
secretKey := s.settingService.GetTurnstileSecretKey(ctx)
if secretKey == "" {
log.Println("[Turnstile] Secret key not configured")
return ErrTurnstileNotConfigured
}
// 如果 token 为空,返回错误
if token == "" {
log.Println("[Turnstile] Token is empty")
return ErrTurnstileVerificationFailed
}
// 构建请求
formData := url.Values{}
formData.Set("secret", secretKey)
formData.Set("response", token)
if remoteIP != "" {
formData.Set("remoteip", remoteIP)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, turnstileVerifyURL, strings.NewReader(formData.Encode()))
if err != nil {
return fmt.Errorf("create request: %w", err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
// 发送请求
log.Printf("[Turnstile] Verifying token for IP: %s", remoteIP)
resp, err := s.httpClient.Do(req)
if err != nil {
log.Printf("[Turnstile] Request failed: %v", err)
return fmt.Errorf("send request: %w", err)
}
defer resp.Body.Close()
// 解析响应
var result TurnstileVerifyResponse
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
log.Printf("[Turnstile] Failed to decode response: %v", err)
return fmt.Errorf("decode response: %w", err)
}
if !result.Success {
log.Printf("[Turnstile] Verification failed, error codes: %v", result.ErrorCodes)
return ErrTurnstileVerificationFailed
}
log.Println("[Turnstile] Verification successful")
return nil
}
// IsEnabled 检查 Turnstile 是否启用
func (s *TurnstileService) IsEnabled(ctx context.Context) bool {
return s.settingService.IsTurnstileEnabled(ctx)
}

View File

@@ -0,0 +1,621 @@
package service
import (
"archive/tar"
"bufio"
"compress/gzip"
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"os"
"os/exec"
"path/filepath"
"runtime"
"strings"
"time"
"github.com/redis/go-redis/v9"
)
const (
updateCacheKey = "update_check_cache"
updateCacheTTL = 1200 // 20 minutes
githubRepo = "Wei-Shaw/sub2api"
// Security: allowed download domains for updates
allowedDownloadHost = "github.com"
allowedAssetHost = "objects.githubusercontent.com"
// Security: max download size (500MB)
maxDownloadSize = 500 * 1024 * 1024
)
// UpdateService handles software updates
type UpdateService struct {
rdb *redis.Client
currentVersion string
buildType string // "source" for manual builds, "release" for CI builds
}
// NewUpdateService creates a new UpdateService
func NewUpdateService(rdb *redis.Client, version, buildType string) *UpdateService {
return &UpdateService{
rdb: rdb,
currentVersion: version,
buildType: buildType,
}
}
// UpdateInfo contains update information
type UpdateInfo struct {
CurrentVersion string `json:"current_version"`
LatestVersion string `json:"latest_version"`
HasUpdate bool `json:"has_update"`
ReleaseInfo *ReleaseInfo `json:"release_info,omitempty"`
Cached bool `json:"cached"`
Warning string `json:"warning,omitempty"`
BuildType string `json:"build_type"` // "source" or "release"
}
// ReleaseInfo contains GitHub release details
type ReleaseInfo struct {
Name string `json:"name"`
Body string `json:"body"`
PublishedAt string `json:"published_at"`
HtmlURL string `json:"html_url"`
Assets []Asset `json:"assets,omitempty"`
}
// Asset represents a release asset
type Asset struct {
Name string `json:"name"`
DownloadURL string `json:"download_url"`
Size int64 `json:"size"`
}
// GitHubRelease represents GitHub API response
type GitHubRelease struct {
TagName string `json:"tag_name"`
Name string `json:"name"`
Body string `json:"body"`
PublishedAt string `json:"published_at"`
HtmlUrl string `json:"html_url"`
Assets []GitHubAsset `json:"assets"`
}
type GitHubAsset struct {
Name string `json:"name"`
BrowserDownloadUrl string `json:"browser_download_url"`
Size int64 `json:"size"`
}
// CheckUpdate checks for available updates
func (s *UpdateService) CheckUpdate(ctx context.Context, force bool) (*UpdateInfo, error) {
// Try cache first
if !force {
if cached, err := s.getFromCache(ctx); err == nil && cached != nil {
return cached, nil
}
}
// Fetch from GitHub
info, err := s.fetchLatestRelease(ctx)
if err != nil {
// Return cached on error
if cached, cacheErr := s.getFromCache(ctx); cacheErr == nil && cached != nil {
cached.Warning = "Using cached data: " + err.Error()
return cached, nil
}
return &UpdateInfo{
CurrentVersion: s.currentVersion,
LatestVersion: s.currentVersion,
HasUpdate: false,
Warning: err.Error(),
BuildType: s.buildType,
}, nil
}
// Cache result
s.saveToCache(ctx, info)
return info, nil
}
// PerformUpdate downloads and applies the update
func (s *UpdateService) PerformUpdate(ctx context.Context) error {
info, err := s.CheckUpdate(ctx, true)
if err != nil {
return err
}
if !info.HasUpdate {
return fmt.Errorf("no update available")
}
// Find matching archive and checksum for current platform
archiveName := s.getArchiveName()
var downloadURL string
var checksumURL string
for _, asset := range info.ReleaseInfo.Assets {
if strings.Contains(asset.Name, archiveName) && !strings.HasSuffix(asset.Name, ".txt") {
downloadURL = asset.DownloadURL
}
if asset.Name == "checksums.txt" {
checksumURL = asset.DownloadURL
}
}
if downloadURL == "" {
return fmt.Errorf("no compatible release found for %s/%s", runtime.GOOS, runtime.GOARCH)
}
// SECURITY: Validate download URL is from trusted domain
if err := validateDownloadURL(downloadURL); err != nil {
return fmt.Errorf("invalid download URL: %w", err)
}
if checksumURL != "" {
if err := validateDownloadURL(checksumURL); err != nil {
return fmt.Errorf("invalid checksum URL: %w", err)
}
}
// Get current executable path
exePath, err := os.Executable()
if err != nil {
return fmt.Errorf("failed to get executable path: %w", err)
}
exePath, err = filepath.EvalSymlinks(exePath)
if err != nil {
return fmt.Errorf("failed to resolve symlinks: %w", err)
}
// Create temp directory for extraction
tempDir, err := os.MkdirTemp("", "sub2api-update-*")
if err != nil {
return fmt.Errorf("failed to create temp dir: %w", err)
}
defer os.RemoveAll(tempDir)
// Download archive
archivePath := filepath.Join(tempDir, filepath.Base(downloadURL))
if err := s.downloadFile(ctx, downloadURL, archivePath); err != nil {
return fmt.Errorf("download failed: %w", err)
}
// Verify checksum if available
if checksumURL != "" {
if err := s.verifyChecksum(ctx, archivePath, checksumURL); err != nil {
return fmt.Errorf("checksum verification failed: %w", err)
}
}
// Extract binary from archive
newBinaryPath := filepath.Join(tempDir, "sub2api")
if err := s.extractBinary(archivePath, newBinaryPath); err != nil {
return fmt.Errorf("extraction failed: %w", err)
}
// Backup current binary
backupFile := exePath + ".backup"
if err := os.Rename(exePath, backupFile); err != nil {
return fmt.Errorf("backup failed: %w", err)
}
// Replace with new binary
if err := copyFile(newBinaryPath, exePath); err != nil {
os.Rename(backupFile, exePath)
return fmt.Errorf("replace failed: %w", err)
}
// Make executable
if err := os.Chmod(exePath, 0755); err != nil {
return fmt.Errorf("chmod failed: %w", err)
}
return nil
}
// Rollback restores the previous version
func (s *UpdateService) Rollback() error {
exePath, err := os.Executable()
if err != nil {
return fmt.Errorf("failed to get executable path: %w", err)
}
exePath, err = filepath.EvalSymlinks(exePath)
if err != nil {
return fmt.Errorf("failed to resolve symlinks: %w", err)
}
backupFile := exePath + ".backup"
if _, err := os.Stat(backupFile); os.IsNotExist(err) {
return fmt.Errorf("no backup found")
}
// Replace current with backup
if err := os.Rename(backupFile, exePath); err != nil {
return fmt.Errorf("rollback failed: %w", err)
}
return nil
}
// RestartService triggers a service restart via systemd
func (s *UpdateService) RestartService() error {
if runtime.GOOS != "linux" {
return fmt.Errorf("systemd restart only available on Linux")
}
// Try direct systemctl first (works if running as root or with proper permissions)
cmd := exec.Command("systemctl", "restart", "sub2api")
if err := cmd.Run(); err != nil {
// Try with sudo (requires NOPASSWD sudoers entry)
sudoCmd := exec.Command("sudo", "systemctl", "restart", "sub2api")
if sudoErr := sudoCmd.Run(); sudoErr != nil {
return fmt.Errorf("systemctl restart failed: %w (sudo also failed: %v)", err, sudoErr)
}
}
return nil
}
func (s *UpdateService) fetchLatestRelease(ctx context.Context) (*UpdateInfo, error) {
url := fmt.Sprintf("https://api.github.com/repos/%s/releases/latest", githubRepo)
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return nil, err
}
req.Header.Set("Accept", "application/vnd.github.v3+json")
req.Header.Set("User-Agent", "Sub2API-Updater")
client := &http.Client{Timeout: 30 * time.Second}
resp, err := client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusNotFound {
return &UpdateInfo{
CurrentVersion: s.currentVersion,
LatestVersion: s.currentVersion,
HasUpdate: false,
Warning: "No releases found",
BuildType: s.buildType,
}, nil
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("GitHub API returned %d", resp.StatusCode)
}
var release GitHubRelease
if err := json.NewDecoder(resp.Body).Decode(&release); err != nil {
return nil, err
}
latestVersion := strings.TrimPrefix(release.TagName, "v")
assets := make([]Asset, len(release.Assets))
for i, a := range release.Assets {
assets[i] = Asset{
Name: a.Name,
DownloadURL: a.BrowserDownloadUrl,
Size: a.Size,
}
}
return &UpdateInfo{
CurrentVersion: s.currentVersion,
LatestVersion: latestVersion,
HasUpdate: compareVersions(s.currentVersion, latestVersion) < 0,
ReleaseInfo: &ReleaseInfo{
Name: release.Name,
Body: release.Body,
PublishedAt: release.PublishedAt,
HtmlURL: release.HtmlUrl,
Assets: assets,
},
Cached: false,
BuildType: s.buildType,
}, nil
}
func (s *UpdateService) downloadFile(ctx context.Context, downloadURL, dest string) error {
req, err := http.NewRequestWithContext(ctx, "GET", downloadURL, nil)
if err != nil {
return err
}
client := &http.Client{Timeout: 10 * time.Minute}
resp, err := client.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("download returned %d", resp.StatusCode)
}
// SECURITY: Check Content-Length if available
if resp.ContentLength > maxDownloadSize {
return fmt.Errorf("file too large: %d bytes (max %d)", resp.ContentLength, maxDownloadSize)
}
out, err := os.Create(dest)
if err != nil {
return err
}
defer out.Close()
// SECURITY: Use LimitReader to enforce max download size even if Content-Length is missing/wrong
limited := io.LimitReader(resp.Body, maxDownloadSize+1)
written, err := io.Copy(out, limited)
if err != nil {
return err
}
// Check if we hit the limit (downloaded more than maxDownloadSize)
if written > maxDownloadSize {
os.Remove(dest) // Clean up partial file
return fmt.Errorf("download exceeded maximum size of %d bytes", maxDownloadSize)
}
return nil
}
func (s *UpdateService) getArchiveName() string {
osName := runtime.GOOS
arch := runtime.GOARCH
return fmt.Sprintf("%s_%s", osName, arch)
}
// validateDownloadURL checks if the URL is from an allowed domain
// SECURITY: This prevents SSRF and ensures downloads only come from trusted GitHub domains
func validateDownloadURL(rawURL string) error {
parsedURL, err := url.Parse(rawURL)
if err != nil {
return fmt.Errorf("invalid URL: %w", err)
}
// Must be HTTPS
if parsedURL.Scheme != "https" {
return fmt.Errorf("only HTTPS URLs are allowed")
}
// Check against allowed hosts
host := parsedURL.Host
// GitHub release URLs can be from github.com or objects.githubusercontent.com
if host != allowedDownloadHost &&
!strings.HasSuffix(host, "."+allowedDownloadHost) &&
host != allowedAssetHost &&
!strings.HasSuffix(host, "."+allowedAssetHost) {
return fmt.Errorf("download from untrusted host: %s", host)
}
return nil
}
func (s *UpdateService) verifyChecksum(ctx context.Context, filePath, checksumURL string) error {
// Download checksums file
req, err := http.NewRequestWithContext(ctx, "GET", checksumURL, nil)
if err != nil {
return err
}
client := &http.Client{Timeout: 30 * time.Second}
resp, err := client.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("failed to download checksums: %d", resp.StatusCode)
}
// Calculate file hash
f, err := os.Open(filePath)
if err != nil {
return err
}
defer f.Close()
h := sha256.New()
if _, err := io.Copy(h, f); err != nil {
return err
}
actualHash := hex.EncodeToString(h.Sum(nil))
// Find expected hash in checksums file
fileName := filepath.Base(filePath)
scanner := bufio.NewScanner(resp.Body)
for scanner.Scan() {
line := scanner.Text()
parts := strings.Fields(line)
if len(parts) == 2 && parts[1] == fileName {
if parts[0] == actualHash {
return nil
}
return fmt.Errorf("checksum mismatch: expected %s, got %s", parts[0], actualHash)
}
}
return fmt.Errorf("checksum not found for %s", fileName)
}
func (s *UpdateService) extractBinary(archivePath, destPath string) error {
f, err := os.Open(archivePath)
if err != nil {
return err
}
defer f.Close()
var reader io.Reader = f
// Handle gzip compression
if strings.HasSuffix(archivePath, ".gz") || strings.HasSuffix(archivePath, ".tar.gz") || strings.HasSuffix(archivePath, ".tgz") {
gzr, err := gzip.NewReader(f)
if err != nil {
return err
}
defer gzr.Close()
reader = gzr
}
// Handle tar archive
if strings.Contains(archivePath, ".tar") {
tr := tar.NewReader(reader)
for {
hdr, err := tr.Next()
if err == io.EOF {
break
}
if err != nil {
return err
}
// SECURITY: Prevent Zip Slip / Path Traversal attack
// Only allow files with safe base names, no directory traversal
baseName := filepath.Base(hdr.Name)
// Check for path traversal attempts
if strings.Contains(hdr.Name, "..") {
return fmt.Errorf("path traversal attempt detected: %s", hdr.Name)
}
// Validate the entry is a regular file
if hdr.Typeflag != tar.TypeReg {
continue // Skip directories and special files
}
// Only extract the specific binary we need
if baseName == "sub2api" || baseName == "sub2api.exe" {
// Additional security: limit file size (max 500MB)
const maxBinarySize = 500 * 1024 * 1024
if hdr.Size > maxBinarySize {
return fmt.Errorf("binary too large: %d bytes (max %d)", hdr.Size, maxBinarySize)
}
out, err := os.Create(destPath)
if err != nil {
return err
}
// Use LimitReader to prevent decompression bombs
limited := io.LimitReader(tr, maxBinarySize)
if _, err := io.Copy(out, limited); err != nil {
out.Close()
return err
}
out.Close()
return nil
}
}
return fmt.Errorf("binary not found in archive")
}
// Direct copy for non-tar files (with size limit)
const maxBinarySize = 500 * 1024 * 1024
out, err := os.Create(destPath)
if err != nil {
return err
}
defer out.Close()
limited := io.LimitReader(reader, maxBinarySize)
_, err = io.Copy(out, limited)
return err
}
func copyFile(src, dst string) error {
in, err := os.Open(src)
if err != nil {
return err
}
defer in.Close()
out, err := os.Create(dst)
if err != nil {
return err
}
defer out.Close()
_, err = io.Copy(out, in)
return err
}
func (s *UpdateService) getFromCache(ctx context.Context) (*UpdateInfo, error) {
data, err := s.rdb.Get(ctx, updateCacheKey).Result()
if err != nil {
return nil, err
}
var cached struct {
Latest string `json:"latest"`
ReleaseInfo *ReleaseInfo `json:"release_info"`
Timestamp int64 `json:"timestamp"`
}
if err := json.Unmarshal([]byte(data), &cached); err != nil {
return nil, err
}
if time.Now().Unix()-cached.Timestamp > updateCacheTTL {
return nil, fmt.Errorf("cache expired")
}
return &UpdateInfo{
CurrentVersion: s.currentVersion,
LatestVersion: cached.Latest,
HasUpdate: compareVersions(s.currentVersion, cached.Latest) < 0,
ReleaseInfo: cached.ReleaseInfo,
Cached: true,
BuildType: s.buildType,
}, nil
}
func (s *UpdateService) saveToCache(ctx context.Context, info *UpdateInfo) {
cacheData := struct {
Latest string `json:"latest"`
ReleaseInfo *ReleaseInfo `json:"release_info"`
Timestamp int64 `json:"timestamp"`
}{
Latest: info.LatestVersion,
ReleaseInfo: info.ReleaseInfo,
Timestamp: time.Now().Unix(),
}
data, _ := json.Marshal(cacheData)
s.rdb.Set(ctx, updateCacheKey, data, time.Duration(updateCacheTTL)*time.Second)
}
// compareVersions compares two semantic versions
func compareVersions(current, latest string) int {
currentParts := parseVersion(current)
latestParts := parseVersion(latest)
for i := 0; i < 3; i++ {
if currentParts[i] < latestParts[i] {
return -1
}
if currentParts[i] > latestParts[i] {
return 1
}
}
return 0
}
func parseVersion(v string) [3]int {
v = strings.TrimPrefix(v, "v")
parts := strings.Split(v, ".")
result := [3]int{0, 0, 0}
for i := 0; i < len(parts) && i < 3; i++ {
fmt.Sscanf(parts[i], "%d", &result[i])
}
return result
}

View File

@@ -0,0 +1,283 @@
package service
import (
"context"
"errors"
"fmt"
"sub2api/internal/model"
"sub2api/internal/repository"
"time"
"gorm.io/gorm"
)
var (
ErrUsageLogNotFound = errors.New("usage log not found")
)
// CreateUsageLogRequest 创建使用日志请求
type CreateUsageLogRequest struct {
UserID int64 `json:"user_id"`
ApiKeyID int64 `json:"api_key_id"`
AccountID int64 `json:"account_id"`
RequestID string `json:"request_id"`
Model string `json:"model"`
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
CacheCreationTokens int `json:"cache_creation_tokens"`
CacheReadTokens int `json:"cache_read_tokens"`
CacheCreation5mTokens int `json:"cache_creation_5m_tokens"`
CacheCreation1hTokens int `json:"cache_creation_1h_tokens"`
InputCost float64 `json:"input_cost"`
OutputCost float64 `json:"output_cost"`
CacheCreationCost float64 `json:"cache_creation_cost"`
CacheReadCost float64 `json:"cache_read_cost"`
TotalCost float64 `json:"total_cost"`
ActualCost float64 `json:"actual_cost"`
RateMultiplier float64 `json:"rate_multiplier"`
Stream bool `json:"stream"`
DurationMs *int `json:"duration_ms"`
}
// UsageStats 使用统计
type UsageStats struct {
TotalRequests int64 `json:"total_requests"`
TotalInputTokens int64 `json:"total_input_tokens"`
TotalOutputTokens int64 `json:"total_output_tokens"`
TotalCacheTokens int64 `json:"total_cache_tokens"`
TotalTokens int64 `json:"total_tokens"`
TotalCost float64 `json:"total_cost"`
TotalActualCost float64 `json:"total_actual_cost"`
AverageDurationMs float64 `json:"average_duration_ms"`
}
// UsageService 使用统计服务
type UsageService struct {
usageRepo *repository.UsageLogRepository
userRepo *repository.UserRepository
}
// NewUsageService 创建使用统计服务实例
func NewUsageService(usageRepo *repository.UsageLogRepository, userRepo *repository.UserRepository) *UsageService {
return &UsageService{
usageRepo: usageRepo,
userRepo: userRepo,
}
}
// Create 创建使用日志
func (s *UsageService) Create(ctx context.Context, req CreateUsageLogRequest) (*model.UsageLog, error) {
// 验证用户存在
_, err := s.userRepo.GetByID(ctx, req.UserID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrUserNotFound
}
return nil, fmt.Errorf("get user: %w", err)
}
// 创建使用日志
usageLog := &model.UsageLog{
UserID: req.UserID,
ApiKeyID: req.ApiKeyID,
AccountID: req.AccountID,
RequestID: req.RequestID,
Model: req.Model,
InputTokens: req.InputTokens,
OutputTokens: req.OutputTokens,
CacheCreationTokens: req.CacheCreationTokens,
CacheReadTokens: req.CacheReadTokens,
CacheCreation5mTokens: req.CacheCreation5mTokens,
CacheCreation1hTokens: req.CacheCreation1hTokens,
InputCost: req.InputCost,
OutputCost: req.OutputCost,
CacheCreationCost: req.CacheCreationCost,
CacheReadCost: req.CacheReadCost,
TotalCost: req.TotalCost,
ActualCost: req.ActualCost,
RateMultiplier: req.RateMultiplier,
Stream: req.Stream,
DurationMs: req.DurationMs,
}
if err := s.usageRepo.Create(ctx, usageLog); err != nil {
return nil, fmt.Errorf("create usage log: %w", err)
}
// 扣除用户余额
if req.ActualCost > 0 {
if err := s.userRepo.UpdateBalance(ctx, req.UserID, -req.ActualCost); err != nil {
return nil, fmt.Errorf("update user balance: %w", err)
}
}
return usageLog, nil
}
// GetByID 根据ID获取使用日志
func (s *UsageService) GetByID(ctx context.Context, id int64) (*model.UsageLog, error) {
log, err := s.usageRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrUsageLogNotFound
}
return nil, fmt.Errorf("get usage log: %w", err)
}
return log, nil
}
// ListByUser 获取用户的使用日志列表
func (s *UsageService) ListByUser(ctx context.Context, userID int64, params repository.PaginationParams) ([]model.UsageLog, *repository.PaginationResult, error) {
logs, pagination, err := s.usageRepo.ListByUser(ctx, userID, params)
if err != nil {
return nil, nil, fmt.Errorf("list usage logs: %w", err)
}
return logs, pagination, nil
}
// ListByApiKey 获取API Key的使用日志列表
func (s *UsageService) ListByApiKey(ctx context.Context, apiKeyID int64, params repository.PaginationParams) ([]model.UsageLog, *repository.PaginationResult, error) {
logs, pagination, err := s.usageRepo.ListByApiKey(ctx, apiKeyID, params)
if err != nil {
return nil, nil, fmt.Errorf("list usage logs: %w", err)
}
return logs, pagination, nil
}
// ListByAccount 获取账号的使用日志列表
func (s *UsageService) ListByAccount(ctx context.Context, accountID int64, params repository.PaginationParams) ([]model.UsageLog, *repository.PaginationResult, error) {
logs, pagination, err := s.usageRepo.ListByAccount(ctx, accountID, params)
if err != nil {
return nil, nil, fmt.Errorf("list usage logs: %w", err)
}
return logs, pagination, nil
}
// GetStatsByUser 获取用户的使用统计
func (s *UsageService) GetStatsByUser(ctx context.Context, userID int64, startTime, endTime time.Time) (*UsageStats, error) {
logs, _, err := s.usageRepo.ListByUserAndTimeRange(ctx, userID, startTime, endTime)
if err != nil {
return nil, fmt.Errorf("list usage logs: %w", err)
}
return s.calculateStats(logs), nil
}
// GetStatsByApiKey 获取API Key的使用统计
func (s *UsageService) GetStatsByApiKey(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*UsageStats, error) {
logs, _, err := s.usageRepo.ListByApiKeyAndTimeRange(ctx, apiKeyID, startTime, endTime)
if err != nil {
return nil, fmt.Errorf("list usage logs: %w", err)
}
return s.calculateStats(logs), nil
}
// GetStatsByAccount 获取账号的使用统计
func (s *UsageService) GetStatsByAccount(ctx context.Context, accountID int64, startTime, endTime time.Time) (*UsageStats, error) {
logs, _, err := s.usageRepo.ListByAccountAndTimeRange(ctx, accountID, startTime, endTime)
if err != nil {
return nil, fmt.Errorf("list usage logs: %w", err)
}
return s.calculateStats(logs), nil
}
// GetStatsByModel 获取模型的使用统计
func (s *UsageService) GetStatsByModel(ctx context.Context, modelName string, startTime, endTime time.Time) (*UsageStats, error) {
logs, _, err := s.usageRepo.ListByModelAndTimeRange(ctx, modelName, startTime, endTime)
if err != nil {
return nil, fmt.Errorf("list usage logs: %w", err)
}
return s.calculateStats(logs), nil
}
// GetDailyStats 获取每日使用统计最近N天
func (s *UsageService) GetDailyStats(ctx context.Context, userID int64, days int) ([]map[string]interface{}, error) {
endTime := time.Now()
startTime := endTime.AddDate(0, 0, -days)
logs, _, err := s.usageRepo.ListByUserAndTimeRange(ctx, userID, startTime, endTime)
if err != nil {
return nil, fmt.Errorf("list usage logs: %w", err)
}
// 按日期分组统计
dailyStats := make(map[string]*UsageStats)
for _, log := range logs {
dateKey := log.CreatedAt.Format("2006-01-02")
if _, exists := dailyStats[dateKey]; !exists {
dailyStats[dateKey] = &UsageStats{}
}
stats := dailyStats[dateKey]
stats.TotalRequests++
stats.TotalInputTokens += int64(log.InputTokens)
stats.TotalOutputTokens += int64(log.OutputTokens)
stats.TotalCacheTokens += int64(log.CacheCreationTokens + log.CacheReadTokens)
stats.TotalTokens += int64(log.TotalTokens())
stats.TotalCost += log.TotalCost
stats.TotalActualCost += log.ActualCost
if log.DurationMs != nil {
stats.AverageDurationMs += float64(*log.DurationMs)
}
}
// 计算平均值并转换为数组
result := make([]map[string]interface{}, 0, len(dailyStats))
for date, stats := range dailyStats {
if stats.TotalRequests > 0 {
stats.AverageDurationMs /= float64(stats.TotalRequests)
}
result = append(result, map[string]interface{}{
"date": date,
"total_requests": stats.TotalRequests,
"total_input_tokens": stats.TotalInputTokens,
"total_output_tokens": stats.TotalOutputTokens,
"total_cache_tokens": stats.TotalCacheTokens,
"total_tokens": stats.TotalTokens,
"total_cost": stats.TotalCost,
"total_actual_cost": stats.TotalActualCost,
"average_duration_ms": stats.AverageDurationMs,
})
}
return result, nil
}
// calculateStats 计算统计数据
func (s *UsageService) calculateStats(logs []model.UsageLog) *UsageStats {
stats := &UsageStats{}
for _, log := range logs {
stats.TotalRequests++
stats.TotalInputTokens += int64(log.InputTokens)
stats.TotalOutputTokens += int64(log.OutputTokens)
stats.TotalCacheTokens += int64(log.CacheCreationTokens + log.CacheReadTokens)
stats.TotalTokens += int64(log.TotalTokens())
stats.TotalCost += log.TotalCost
stats.TotalActualCost += log.ActualCost
if log.DurationMs != nil {
stats.AverageDurationMs += float64(*log.DurationMs)
}
}
// 计算平均持续时间
if stats.TotalRequests > 0 {
stats.AverageDurationMs /= float64(stats.TotalRequests)
}
return stats
}
// Delete 删除使用日志(管理员功能,谨慎使用)
func (s *UsageService) Delete(ctx context.Context, id int64) error {
if err := s.usageRepo.Delete(ctx, id); err != nil {
return fmt.Errorf("delete usage log: %w", err)
}
return nil
}

View File

@@ -0,0 +1,177 @@
package service
import (
"context"
"errors"
"fmt"
"sub2api/internal/config"
"sub2api/internal/model"
"sub2api/internal/repository"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
)
var (
ErrUserNotFound = errors.New("user not found")
ErrPasswordIncorrect = errors.New("current password is incorrect")
ErrInsufficientPerms = errors.New("insufficient permissions")
)
// UpdateProfileRequest 更新用户资料请求
type UpdateProfileRequest struct {
Email *string `json:"email"`
Concurrency *int `json:"concurrency"`
}
// ChangePasswordRequest 修改密码请求
type ChangePasswordRequest struct {
CurrentPassword string `json:"current_password"`
NewPassword string `json:"new_password"`
}
// UserService 用户服务
type UserService struct {
userRepo *repository.UserRepository
cfg *config.Config
}
// NewUserService 创建用户服务实例
func NewUserService(userRepo *repository.UserRepository, cfg *config.Config) *UserService {
return &UserService{
userRepo: userRepo,
cfg: cfg,
}
}
// GetProfile 获取用户资料
func (s *UserService) GetProfile(ctx context.Context, userID int64) (*model.User, error) {
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrUserNotFound
}
return nil, fmt.Errorf("get user: %w", err)
}
return user, nil
}
// UpdateProfile 更新用户资料
func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req UpdateProfileRequest) (*model.User, error) {
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrUserNotFound
}
return nil, fmt.Errorf("get user: %w", err)
}
// 更新字段
if req.Email != nil {
// 检查新邮箱是否已被使用
exists, err := s.userRepo.ExistsByEmail(ctx, *req.Email)
if err != nil {
return nil, fmt.Errorf("check email exists: %w", err)
}
if exists && *req.Email != user.Email {
return nil, ErrEmailExists
}
user.Email = *req.Email
}
if req.Concurrency != nil {
user.Concurrency = *req.Concurrency
}
if err := s.userRepo.Update(ctx, user); err != nil {
return nil, fmt.Errorf("update user: %w", err)
}
return user, nil
}
// ChangePassword 修改密码
func (s *UserService) ChangePassword(ctx context.Context, userID int64, req ChangePasswordRequest) error {
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrUserNotFound
}
return fmt.Errorf("get user: %w", err)
}
// 验证当前密码
if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(req.CurrentPassword)); err != nil {
return ErrPasswordIncorrect
}
// 生成新密码哈希
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(req.NewPassword), bcrypt.DefaultCost)
if err != nil {
return fmt.Errorf("hash password: %w", err)
}
user.PasswordHash = string(hashedPassword)
if err := s.userRepo.Update(ctx, user); err != nil {
return fmt.Errorf("update user: %w", err)
}
return nil
}
// GetByID 根据ID获取用户管理员功能
func (s *UserService) GetByID(ctx context.Context, id int64) (*model.User, error) {
user, err := s.userRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrUserNotFound
}
return nil, fmt.Errorf("get user: %w", err)
}
return user, nil
}
// List 获取用户列表(管理员功能)
func (s *UserService) List(ctx context.Context, params repository.PaginationParams) ([]model.User, *repository.PaginationResult, error) {
users, pagination, err := s.userRepo.List(ctx, params)
if err != nil {
return nil, nil, fmt.Errorf("list users: %w", err)
}
return users, pagination, nil
}
// UpdateBalance 更新用户余额(管理员功能)
func (s *UserService) UpdateBalance(ctx context.Context, userID int64, amount float64) error {
if err := s.userRepo.UpdateBalance(ctx, userID, amount); err != nil {
return fmt.Errorf("update balance: %w", err)
}
return nil
}
// UpdateStatus 更新用户状态(管理员功能)
func (s *UserService) UpdateStatus(ctx context.Context, userID int64, status string) error {
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrUserNotFound
}
return fmt.Errorf("get user: %w", err)
}
user.Status = status
if err := s.userRepo.Update(ctx, user); err != nil {
return fmt.Errorf("update user: %w", err)
}
return nil
}
// Delete 删除用户(管理员功能)
func (s *UserService) Delete(ctx context.Context, userID int64) error {
if err := s.userRepo.Delete(ctx, userID); err != nil {
return fmt.Errorf("delete user: %w", err)
}
return nil
}

View File

@@ -0,0 +1,294 @@
package setup
import (
"bufio"
"fmt"
"net/mail"
"os"
"regexp"
"strconv"
"strings"
"golang.org/x/term"
)
// CLI input validation functions (matching Web API validation)
func cliValidateHostname(host string) bool {
validHost := regexp.MustCompile(`^[a-zA-Z0-9.\-:]+$`)
return validHost.MatchString(host) && len(host) <= 253
}
func cliValidateDBName(name string) bool {
validName := regexp.MustCompile(`^[a-zA-Z][a-zA-Z0-9_]*$`)
return validName.MatchString(name) && len(name) <= 63
}
func cliValidateUsername(name string) bool {
validName := regexp.MustCompile(`^[a-zA-Z0-9_]+$`)
return validName.MatchString(name) && len(name) <= 63
}
func cliValidateEmail(email string) bool {
_, err := mail.ParseAddress(email)
return err == nil && len(email) <= 254
}
func cliValidatePort(port int) bool {
return port > 0 && port <= 65535
}
func cliValidateSSLMode(mode string) bool {
validModes := map[string]bool{
"disable": true, "require": true, "verify-ca": true, "verify-full": true,
}
return validModes[mode]
}
// RunCLI runs the CLI setup wizard
func RunCLI() error {
reader := bufio.NewReader(os.Stdin)
fmt.Println()
fmt.Println("╔═══════════════════════════════════════════╗")
fmt.Println("║ Sub2API Installation Wizard ║")
fmt.Println("╚═══════════════════════════════════════════╝")
fmt.Println()
cfg := &SetupConfig{
Server: ServerConfig{
Host: "0.0.0.0",
Port: 8080,
Mode: "release",
},
JWT: JWTConfig{
ExpireHour: 24,
},
}
// Database configuration with validation
fmt.Println("── Database Configuration ──")
for {
cfg.Database.Host = promptString(reader, "PostgreSQL Host", "localhost")
if cliValidateHostname(cfg.Database.Host) {
break
}
fmt.Println(" Invalid hostname format. Use alphanumeric, dots, hyphens only.")
}
for {
cfg.Database.Port = promptInt(reader, "PostgreSQL Port", 5432)
if cliValidatePort(cfg.Database.Port) {
break
}
fmt.Println(" Invalid port. Must be between 1 and 65535.")
}
for {
cfg.Database.User = promptString(reader, "PostgreSQL User", "postgres")
if cliValidateUsername(cfg.Database.User) {
break
}
fmt.Println(" Invalid username. Use alphanumeric and underscores only.")
}
cfg.Database.Password = promptPassword("PostgreSQL Password")
for {
cfg.Database.DBName = promptString(reader, "Database Name", "sub2api")
if cliValidateDBName(cfg.Database.DBName) {
break
}
fmt.Println(" Invalid database name. Start with letter, use alphanumeric and underscores.")
}
for {
cfg.Database.SSLMode = promptString(reader, "SSL Mode", "disable")
if cliValidateSSLMode(cfg.Database.SSLMode) {
break
}
fmt.Println(" Invalid SSL mode. Use: disable, require, verify-ca, or verify-full.")
}
fmt.Println()
fmt.Print("Testing database connection... ")
if err := TestDatabaseConnection(&cfg.Database); err != nil {
fmt.Println("FAILED")
return fmt.Errorf("database connection failed: %w", err)
}
fmt.Println("OK")
// Redis configuration with validation
fmt.Println()
fmt.Println("── Redis Configuration ──")
for {
cfg.Redis.Host = promptString(reader, "Redis Host", "localhost")
if cliValidateHostname(cfg.Redis.Host) {
break
}
fmt.Println(" Invalid hostname format. Use alphanumeric, dots, hyphens only.")
}
for {
cfg.Redis.Port = promptInt(reader, "Redis Port", 6379)
if cliValidatePort(cfg.Redis.Port) {
break
}
fmt.Println(" Invalid port. Must be between 1 and 65535.")
}
cfg.Redis.Password = promptPassword("Redis Password (optional)")
for {
cfg.Redis.DB = promptInt(reader, "Redis DB", 0)
if cfg.Redis.DB >= 0 && cfg.Redis.DB <= 15 {
break
}
fmt.Println(" Invalid Redis DB. Must be between 0 and 15.")
}
fmt.Println()
fmt.Print("Testing Redis connection... ")
if err := TestRedisConnection(&cfg.Redis); err != nil {
fmt.Println("FAILED")
return fmt.Errorf("redis connection failed: %w", err)
}
fmt.Println("OK")
// Admin configuration with validation
fmt.Println()
fmt.Println("── Admin Account ──")
for {
cfg.Admin.Email = promptString(reader, "Admin Email", "admin@example.com")
if cliValidateEmail(cfg.Admin.Email) {
break
}
fmt.Println(" Invalid email format.")
}
for {
cfg.Admin.Password = promptPassword("Admin Password")
// SECURITY: Match Web API requirement of 8 characters minimum
if len(cfg.Admin.Password) < 8 {
fmt.Println(" Password must be at least 8 characters")
continue
}
if len(cfg.Admin.Password) > 128 {
fmt.Println(" Password must be at most 128 characters")
continue
}
confirm := promptPassword("Confirm Password")
if cfg.Admin.Password != confirm {
fmt.Println(" Passwords do not match")
continue
}
break
}
// Server configuration with validation
fmt.Println()
fmt.Println("── Server Configuration ──")
for {
cfg.Server.Port = promptInt(reader, "Server Port", 8080)
if cliValidatePort(cfg.Server.Port) {
break
}
fmt.Println(" Invalid port. Must be between 1 and 65535.")
}
// Confirm and install
fmt.Println()
fmt.Println("── Configuration Summary ──")
fmt.Printf("Database: %s@%s:%d/%s\n", cfg.Database.User, cfg.Database.Host, cfg.Database.Port, cfg.Database.DBName)
fmt.Printf("Redis: %s:%d\n", cfg.Redis.Host, cfg.Redis.Port)
fmt.Printf("Admin: %s\n", cfg.Admin.Email)
fmt.Printf("Server: :%d\n", cfg.Server.Port)
fmt.Println()
if !promptConfirm(reader, "Proceed with installation?") {
fmt.Println("Installation cancelled")
return nil
}
fmt.Println()
fmt.Print("Installing... ")
if err := Install(cfg); err != nil {
fmt.Println("FAILED")
return err
}
fmt.Println("OK")
fmt.Println()
fmt.Println("╔═══════════════════════════════════════════╗")
fmt.Println("║ Installation Complete! ║")
fmt.Println("╚═══════════════════════════════════════════╝")
fmt.Println()
fmt.Println("Start the server with:")
fmt.Println(" ./sub2api")
fmt.Println()
fmt.Printf("Admin panel: http://localhost:%d\n", cfg.Server.Port)
fmt.Println()
return nil
}
func promptString(reader *bufio.Reader, prompt, defaultVal string) string {
if defaultVal != "" {
fmt.Printf(" %s [%s]: ", prompt, defaultVal)
} else {
fmt.Printf(" %s: ", prompt)
}
input, _ := reader.ReadString('\n')
input = strings.TrimSpace(input)
if input == "" {
return defaultVal
}
return input
}
func promptInt(reader *bufio.Reader, prompt string, defaultVal int) int {
fmt.Printf(" %s [%d]: ", prompt, defaultVal)
input, _ := reader.ReadString('\n')
input = strings.TrimSpace(input)
if input == "" {
return defaultVal
}
val, err := strconv.Atoi(input)
if err != nil {
return defaultVal
}
return val
}
func promptPassword(prompt string) string {
fmt.Printf(" %s: ", prompt)
// Try to read password without echo
if term.IsTerminal(int(os.Stdin.Fd())) {
password, err := term.ReadPassword(int(os.Stdin.Fd()))
fmt.Println()
if err == nil {
return string(password)
}
}
// Fallback to regular input
reader := bufio.NewReader(os.Stdin)
input, _ := reader.ReadString('\n')
return strings.TrimSpace(input)
}
func promptConfirm(reader *bufio.Reader, prompt string) bool {
fmt.Printf("%s [y/N]: ", prompt)
input, _ := reader.ReadString('\n')
input = strings.TrimSpace(strings.ToLower(input))
return input == "y" || input == "yes"
}

View File

@@ -0,0 +1,344 @@
package setup
import (
"fmt"
"net/http"
"net/mail"
"regexp"
"strings"
"sync"
"sub2api/internal/pkg/response"
"github.com/gin-gonic/gin"
)
// installMutex prevents concurrent installation attempts (TOCTOU protection)
var installMutex sync.Mutex
// RegisterRoutes registers setup wizard routes
func RegisterRoutes(r *gin.Engine) {
setup := r.Group("/setup")
{
// Status endpoint is always accessible (read-only)
setup.GET("/status", getStatus)
// All modification endpoints are protected by setupGuard
protected := setup.Group("")
protected.Use(setupGuard())
{
protected.POST("/test-db", testDatabase)
protected.POST("/test-redis", testRedis)
protected.POST("/install", install)
}
}
}
// SetupStatus represents the current setup state
type SetupStatus struct {
NeedsSetup bool `json:"needs_setup"`
Step string `json:"step"`
}
// getStatus returns the current setup status
func getStatus(c *gin.Context) {
response.Success(c, SetupStatus{
NeedsSetup: NeedsSetup(),
Step: "welcome",
})
}
// setupGuard middleware ensures setup endpoints are only accessible during setup mode
func setupGuard() gin.HandlerFunc {
return func(c *gin.Context) {
if !NeedsSetup() {
response.Error(c, http.StatusForbidden, "Setup is not allowed: system is already installed")
c.Abort()
return
}
c.Next()
}
}
// validateHostname checks if a hostname/IP is safe (no injection characters)
func validateHostname(host string) bool {
// Allow only alphanumeric, dots, hyphens, and colons (for IPv6)
validHost := regexp.MustCompile(`^[a-zA-Z0-9.\-:]+$`)
return validHost.MatchString(host) && len(host) <= 253
}
// validateDBName checks if database name is safe
func validateDBName(name string) bool {
// Allow only alphanumeric and underscores, starting with letter
validName := regexp.MustCompile(`^[a-zA-Z][a-zA-Z0-9_]*$`)
return validName.MatchString(name) && len(name) <= 63
}
// validateUsername checks if username is safe
func validateUsername(name string) bool {
// Allow only alphanumeric and underscores
validName := regexp.MustCompile(`^[a-zA-Z0-9_]+$`)
return validName.MatchString(name) && len(name) <= 63
}
// validateEmail checks if email format is valid
func validateEmail(email string) bool {
_, err := mail.ParseAddress(email)
return err == nil && len(email) <= 254
}
// validatePassword checks password strength
func validatePassword(password string) error {
if len(password) < 8 {
return fmt.Errorf("password must be at least 8 characters")
}
if len(password) > 128 {
return fmt.Errorf("password must be at most 128 characters")
}
return nil
}
// validatePort checks if port is in valid range
func validatePort(port int) bool {
return port > 0 && port <= 65535
}
// validateSSLMode checks if SSL mode is valid
func validateSSLMode(mode string) bool {
validModes := map[string]bool{
"disable": true, "require": true, "verify-ca": true, "verify-full": true,
}
return validModes[mode]
}
// TestDatabaseRequest represents database test request
type TestDatabaseRequest struct {
Host string `json:"host" binding:"required"`
Port int `json:"port" binding:"required"`
User string `json:"user" binding:"required"`
Password string `json:"password"`
DBName string `json:"dbname" binding:"required"`
SSLMode string `json:"sslmode"`
}
// testDatabase tests database connection
func testDatabase(c *gin.Context) {
var req TestDatabaseRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.Error(c, http.StatusBadRequest, "Invalid request: "+err.Error())
return
}
// Security: Validate all inputs to prevent injection attacks
if !validateHostname(req.Host) {
response.Error(c, http.StatusBadRequest, "Invalid hostname format")
return
}
if !validatePort(req.Port) {
response.Error(c, http.StatusBadRequest, "Invalid port number")
return
}
if !validateUsername(req.User) {
response.Error(c, http.StatusBadRequest, "Invalid username format")
return
}
if !validateDBName(req.DBName) {
response.Error(c, http.StatusBadRequest, "Invalid database name format")
return
}
if req.SSLMode == "" {
req.SSLMode = "disable"
}
if !validateSSLMode(req.SSLMode) {
response.Error(c, http.StatusBadRequest, "Invalid SSL mode")
return
}
cfg := &DatabaseConfig{
Host: req.Host,
Port: req.Port,
User: req.User,
Password: req.Password,
DBName: req.DBName,
SSLMode: req.SSLMode,
}
if err := TestDatabaseConnection(cfg); err != nil {
response.Error(c, http.StatusBadRequest, "Connection failed: "+err.Error())
return
}
response.Success(c, gin.H{"message": "Connection successful"})
}
// TestRedisRequest represents Redis test request
type TestRedisRequest struct {
Host string `json:"host" binding:"required"`
Port int `json:"port" binding:"required"`
Password string `json:"password"`
DB int `json:"db"`
}
// testRedis tests Redis connection
func testRedis(c *gin.Context) {
var req TestRedisRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.Error(c, http.StatusBadRequest, "Invalid request: "+err.Error())
return
}
// Security: Validate inputs
if !validateHostname(req.Host) {
response.Error(c, http.StatusBadRequest, "Invalid hostname format")
return
}
if !validatePort(req.Port) {
response.Error(c, http.StatusBadRequest, "Invalid port number")
return
}
if req.DB < 0 || req.DB > 15 {
response.Error(c, http.StatusBadRequest, "Invalid Redis database number (0-15)")
return
}
cfg := &RedisConfig{
Host: req.Host,
Port: req.Port,
Password: req.Password,
DB: req.DB,
}
if err := TestRedisConnection(cfg); err != nil {
response.Error(c, http.StatusBadRequest, "Connection failed: "+err.Error())
return
}
response.Success(c, gin.H{"message": "Connection successful"})
}
// InstallRequest represents installation request
type InstallRequest struct {
Database DatabaseConfig `json:"database" binding:"required"`
Redis RedisConfig `json:"redis" binding:"required"`
Admin AdminConfig `json:"admin" binding:"required"`
Server ServerConfig `json:"server"`
}
// install performs the installation
func install(c *gin.Context) {
// TOCTOU Protection: Acquire mutex to prevent concurrent installation
installMutex.Lock()
defer installMutex.Unlock()
// Double-check after acquiring lock
if !NeedsSetup() {
response.Error(c, http.StatusForbidden, "Setup is not allowed: system is already installed")
return
}
var req InstallRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.Error(c, http.StatusBadRequest, "Invalid request: "+err.Error())
return
}
// ========== COMPREHENSIVE INPUT VALIDATION ==========
// Database validation
if !validateHostname(req.Database.Host) {
response.Error(c, http.StatusBadRequest, "Invalid database hostname")
return
}
if !validatePort(req.Database.Port) {
response.Error(c, http.StatusBadRequest, "Invalid database port")
return
}
if !validateUsername(req.Database.User) {
response.Error(c, http.StatusBadRequest, "Invalid database username")
return
}
if !validateDBName(req.Database.DBName) {
response.Error(c, http.StatusBadRequest, "Invalid database name")
return
}
// Redis validation
if !validateHostname(req.Redis.Host) {
response.Error(c, http.StatusBadRequest, "Invalid Redis hostname")
return
}
if !validatePort(req.Redis.Port) {
response.Error(c, http.StatusBadRequest, "Invalid Redis port")
return
}
if req.Redis.DB < 0 || req.Redis.DB > 15 {
response.Error(c, http.StatusBadRequest, "Invalid Redis database number")
return
}
// Admin validation
if !validateEmail(req.Admin.Email) {
response.Error(c, http.StatusBadRequest, "Invalid admin email format")
return
}
if err := validatePassword(req.Admin.Password); err != nil {
response.Error(c, http.StatusBadRequest, err.Error())
return
}
// Server validation
if req.Server.Port != 0 && !validatePort(req.Server.Port) {
response.Error(c, http.StatusBadRequest, "Invalid server port")
return
}
// ========== SET DEFAULTS ==========
if req.Database.SSLMode == "" {
req.Database.SSLMode = "disable"
}
if !validateSSLMode(req.Database.SSLMode) {
response.Error(c, http.StatusBadRequest, "Invalid SSL mode")
return
}
if req.Server.Host == "" {
req.Server.Host = "0.0.0.0"
}
if req.Server.Port == 0 {
req.Server.Port = 8080
}
if req.Server.Mode == "" {
req.Server.Mode = "release"
}
// Validate server mode
if req.Server.Mode != "release" && req.Server.Mode != "debug" {
response.Error(c, http.StatusBadRequest, "Invalid server mode (must be 'release' or 'debug')")
return
}
// Trim whitespace from string inputs
req.Admin.Email = strings.TrimSpace(req.Admin.Email)
req.Database.Host = strings.TrimSpace(req.Database.Host)
req.Database.User = strings.TrimSpace(req.Database.User)
req.Database.DBName = strings.TrimSpace(req.Database.DBName)
req.Redis.Host = strings.TrimSpace(req.Redis.Host)
cfg := &SetupConfig{
Database: req.Database,
Redis: req.Redis,
Admin: req.Admin,
Server: req.Server,
JWT: JWTConfig{
ExpireHour: 24,
},
}
if err := Install(cfg); err != nil {
response.Error(c, http.StatusInternalServerError, "Installation failed: "+err.Error())
return
}
response.Success(c, gin.H{
"message": "Installation completed successfully",
"restart": true,
})
}

View File

@@ -0,0 +1,564 @@
package setup
import (
"context"
"crypto/rand"
"encoding/hex"
"fmt"
"log"
"os"
"strconv"
"time"
"github.com/redis/go-redis/v9"
"golang.org/x/crypto/bcrypt"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"gopkg.in/yaml.v3"
)
// Config paths
const (
ConfigFile = "config.yaml"
EnvFile = ".env"
)
// SetupConfig holds the setup configuration
type SetupConfig struct {
Database DatabaseConfig `json:"database" yaml:"database"`
Redis RedisConfig `json:"redis" yaml:"redis"`
Admin AdminConfig `json:"admin" yaml:"-"` // Not stored in config file
Server ServerConfig `json:"server" yaml:"server"`
JWT JWTConfig `json:"jwt" yaml:"jwt"`
Timezone string `json:"timezone" yaml:"timezone"` // e.g. "Asia/Shanghai", "UTC"
}
type DatabaseConfig struct {
Host string `json:"host" yaml:"host"`
Port int `json:"port" yaml:"port"`
User string `json:"user" yaml:"user"`
Password string `json:"password" yaml:"password"`
DBName string `json:"dbname" yaml:"dbname"`
SSLMode string `json:"sslmode" yaml:"sslmode"`
}
type RedisConfig struct {
Host string `json:"host" yaml:"host"`
Port int `json:"port" yaml:"port"`
Password string `json:"password" yaml:"password"`
DB int `json:"db" yaml:"db"`
}
type AdminConfig struct {
Email string `json:"email"`
Password string `json:"password"`
}
type ServerConfig struct {
Host string `json:"host" yaml:"host"`
Port int `json:"port" yaml:"port"`
Mode string `json:"mode" yaml:"mode"`
}
type JWTConfig struct {
Secret string `json:"secret" yaml:"secret"`
ExpireHour int `json:"expire_hour" yaml:"expire_hour"`
}
// NeedsSetup checks if the system needs initial setup
// Uses multiple checks to prevent attackers from forcing re-setup by deleting config
func NeedsSetup() bool {
// Check 1: Config file must not exist
if _, err := os.Stat(ConfigFile); !os.IsNotExist(err) {
return false // Config exists, no setup needed
}
// Check 2: Installation lock file (harder to bypass)
lockFile := ".installed"
if _, err := os.Stat(lockFile); !os.IsNotExist(err) {
return false // Lock file exists, already installed
}
return true
}
// TestDatabaseConnection tests the database connection
func TestDatabaseConnection(cfg *DatabaseConfig) error {
dsn := fmt.Sprintf(
"host=%s port=%d user=%s password=%s dbname=%s sslmode=%s",
cfg.Host, cfg.Port, cfg.User, cfg.Password, cfg.DBName, cfg.SSLMode,
)
db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{})
if err != nil {
return fmt.Errorf("failed to connect: %w", err)
}
sqlDB, err := db.DB()
if err != nil {
return fmt.Errorf("failed to get db instance: %w", err)
}
defer sqlDB.Close()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := sqlDB.PingContext(ctx); err != nil {
return fmt.Errorf("ping failed: %w", err)
}
return nil
}
// TestRedisConnection tests the Redis connection
func TestRedisConnection(cfg *RedisConfig) error {
rdb := redis.NewClient(&redis.Options{
Addr: fmt.Sprintf("%s:%d", cfg.Host, cfg.Port),
Password: cfg.Password,
DB: cfg.DB,
})
defer rdb.Close()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := rdb.Ping(ctx).Err(); err != nil {
return fmt.Errorf("ping failed: %w", err)
}
return nil
}
// Install performs the installation with the given configuration
func Install(cfg *SetupConfig) error {
// Security check: prevent re-installation if already installed
if !NeedsSetup() {
return fmt.Errorf("system is already installed, re-installation is not allowed")
}
// Generate JWT secret if not provided
if cfg.JWT.Secret == "" {
cfg.JWT.Secret = generateSecret(32)
}
// Test connections
if err := TestDatabaseConnection(&cfg.Database); err != nil {
return fmt.Errorf("database connection failed: %w", err)
}
if err := TestRedisConnection(&cfg.Redis); err != nil {
return fmt.Errorf("redis connection failed: %w", err)
}
// Initialize database
if err := initializeDatabase(cfg); err != nil {
return fmt.Errorf("database initialization failed: %w", err)
}
// Create admin user
if err := createAdminUser(cfg); err != nil {
return fmt.Errorf("admin user creation failed: %w", err)
}
// Write config file
if err := writeConfigFile(cfg); err != nil {
return fmt.Errorf("config file creation failed: %w", err)
}
// Create installation lock file to prevent re-setup attacks
if err := createInstallLock(); err != nil {
return fmt.Errorf("failed to create install lock: %w", err)
}
return nil
}
// createInstallLock creates a lock file to prevent re-installation attacks
func createInstallLock() error {
lockFile := ".installed"
content := fmt.Sprintf("installed_at=%s\n", time.Now().UTC().Format(time.RFC3339))
return os.WriteFile(lockFile, []byte(content), 0400) // Read-only for owner
}
func initializeDatabase(cfg *SetupConfig) error {
dsn := fmt.Sprintf(
"host=%s port=%d user=%s password=%s dbname=%s sslmode=%s",
cfg.Database.Host, cfg.Database.Port, cfg.Database.User,
cfg.Database.Password, cfg.Database.DBName, cfg.Database.SSLMode,
)
db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{})
if err != nil {
return err
}
sqlDB, err := db.DB()
if err != nil {
return err
}
defer sqlDB.Close()
// Run auto-migration for all models
return db.AutoMigrate(
&User{},
&Group{},
&APIKey{},
&Account{},
&Proxy{},
&RedeemCode{},
&UsageLog{},
&UserSubscription{},
&Setting{},
)
}
func createAdminUser(cfg *SetupConfig) error {
dsn := fmt.Sprintf(
"host=%s port=%d user=%s password=%s dbname=%s sslmode=%s",
cfg.Database.Host, cfg.Database.Port, cfg.Database.User,
cfg.Database.Password, cfg.Database.DBName, cfg.Database.SSLMode,
)
db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{})
if err != nil {
return err
}
sqlDB, err := db.DB()
if err != nil {
return err
}
defer sqlDB.Close()
// Check if admin already exists
var count int64
db.Model(&User{}).Where("role = ?", "admin").Count(&count)
if count > 0 {
return nil // Admin already exists
}
// Hash password
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(cfg.Admin.Password), bcrypt.DefaultCost)
if err != nil {
return err
}
// Create admin user
admin := &User{
Email: cfg.Admin.Email,
PasswordHash: string(hashedPassword),
Role: "admin",
Status: "active",
Balance: 0,
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
return db.Create(admin).Error
}
func writeConfigFile(cfg *SetupConfig) error {
// Ensure timezone has a default value
tz := cfg.Timezone
if tz == "" {
tz = "Asia/Shanghai"
}
// Prepare config for YAML (exclude sensitive data and admin config)
yamlConfig := struct {
Server ServerConfig `yaml:"server"`
Database DatabaseConfig `yaml:"database"`
Redis RedisConfig `yaml:"redis"`
JWT struct {
Secret string `yaml:"secret"`
ExpireHour int `yaml:"expire_hour"`
} `yaml:"jwt"`
Default struct {
GroupID uint `yaml:"group_id"`
} `yaml:"default"`
RateLimit struct {
RequestsPerMinute int `yaml:"requests_per_minute"`
BurstSize int `yaml:"burst_size"`
} `yaml:"rate_limit"`
Timezone string `yaml:"timezone"`
}{
Server: cfg.Server,
Database: cfg.Database,
Redis: cfg.Redis,
JWT: struct {
Secret string `yaml:"secret"`
ExpireHour int `yaml:"expire_hour"`
}{
Secret: cfg.JWT.Secret,
ExpireHour: cfg.JWT.ExpireHour,
},
Default: struct {
GroupID uint `yaml:"group_id"`
}{
GroupID: 1,
},
RateLimit: struct {
RequestsPerMinute int `yaml:"requests_per_minute"`
BurstSize int `yaml:"burst_size"`
}{
RequestsPerMinute: 60,
BurstSize: 10,
},
Timezone: tz,
}
data, err := yaml.Marshal(&yamlConfig)
if err != nil {
return err
}
return os.WriteFile(ConfigFile, data, 0600)
}
func generateSecret(length int) string {
bytes := make([]byte, length)
rand.Read(bytes)
return hex.EncodeToString(bytes)
}
// Minimal model definitions for migration (to avoid circular import)
type User struct {
ID uint `gorm:"primaryKey"`
Email string `gorm:"uniqueIndex;not null"`
PasswordHash string `gorm:"not null"`
Role string `gorm:"default:user"`
Status string `gorm:"default:active"`
Balance float64 `gorm:"default:0"`
CreatedAt time.Time
UpdatedAt time.Time
}
type Group struct {
ID uint `gorm:"primaryKey"`
Name string `gorm:"uniqueIndex;not null"`
Description string `gorm:"type:text"`
RateMultiplier float64 `gorm:"default:1.0"`
IsExclusive bool `gorm:"default:false"`
Priority int `gorm:"default:0"`
Status string `gorm:"default:active"`
CreatedAt time.Time
UpdatedAt time.Time
}
type APIKey struct {
ID uint `gorm:"primaryKey"`
UserID uint `gorm:"index;not null"`
Key string `gorm:"uniqueIndex;not null"`
Name string
GroupID *uint
Status string `gorm:"default:active"`
CreatedAt time.Time
UpdatedAt time.Time
}
type Account struct {
ID uint `gorm:"primaryKey"`
Platform string `gorm:"not null"`
Type string `gorm:"not null"`
Credentials string `gorm:"type:text"`
Status string `gorm:"default:active"`
Priority int `gorm:"default:0"`
ProxyID *uint
CreatedAt time.Time
UpdatedAt time.Time
}
type Proxy struct {
ID uint `gorm:"primaryKey"`
Name string `gorm:"not null"`
Protocol string `gorm:"not null"`
Host string `gorm:"not null"`
Port int `gorm:"not null"`
Username string
Password string
Status string `gorm:"default:active"`
CreatedAt time.Time
UpdatedAt time.Time
}
type RedeemCode struct {
ID uint `gorm:"primaryKey"`
Code string `gorm:"uniqueIndex;not null"`
Value float64 `gorm:"not null"`
Status string `gorm:"default:unused"`
UsedBy *uint
UsedAt *time.Time
ExpiresAt *time.Time
CreatedAt time.Time
}
type UsageLog struct {
ID uint `gorm:"primaryKey"`
UserID uint `gorm:"index"`
APIKeyID uint `gorm:"index"`
AccountID *uint `gorm:"index"`
Model string `gorm:"index"`
InputTokens int
OutputTokens int
Cost float64
CreatedAt time.Time
}
type UserSubscription struct {
ID uint `gorm:"primaryKey"`
UserID uint `gorm:"index;not null"`
GroupID uint `gorm:"index;not null"`
Quota int64
Used int64 `gorm:"default:0"`
Status string
ExpiresAt *time.Time
CreatedAt time.Time
UpdatedAt time.Time
}
type Setting struct {
ID uint `gorm:"primaryKey"`
Key string `gorm:"uniqueIndex;not null"`
Value string `gorm:"type:text"`
CreatedAt time.Time
UpdatedAt time.Time
}
func (User) TableName() string { return "users" }
func (Group) TableName() string { return "groups" }
func (APIKey) TableName() string { return "api_keys" }
func (Account) TableName() string { return "accounts" }
func (Proxy) TableName() string { return "proxies" }
func (RedeemCode) TableName() string { return "redeem_codes" }
func (UsageLog) TableName() string { return "usage_logs" }
func (UserSubscription) TableName() string { return "user_subscriptions" }
func (Setting) TableName() string { return "settings" }
// =============================================================================
// Auto Setup for Docker Deployment
// =============================================================================
// AutoSetupEnabled checks if auto setup is enabled via environment variable
func AutoSetupEnabled() bool {
val := os.Getenv("AUTO_SETUP")
return val == "true" || val == "1" || val == "yes"
}
// getEnvOrDefault gets environment variable or returns default value
func getEnvOrDefault(key, defaultValue string) string {
if val := os.Getenv(key); val != "" {
return val
}
return defaultValue
}
// getEnvIntOrDefault gets environment variable as int or returns default value
func getEnvIntOrDefault(key string, defaultValue int) int {
if val := os.Getenv(key); val != "" {
if i, err := strconv.Atoi(val); err == nil {
return i
}
}
return defaultValue
}
// AutoSetupFromEnv performs automatic setup using environment variables
// This is designed for Docker deployment where all config is passed via env vars
func AutoSetupFromEnv() error {
log.Println("Auto setup enabled, configuring from environment variables...")
// Get timezone from TZ or TIMEZONE env var (TZ is standard for Docker)
tz := getEnvOrDefault("TZ", "")
if tz == "" {
tz = getEnvOrDefault("TIMEZONE", "Asia/Shanghai")
}
// Build config from environment variables
cfg := &SetupConfig{
Database: DatabaseConfig{
Host: getEnvOrDefault("DATABASE_HOST", "localhost"),
Port: getEnvIntOrDefault("DATABASE_PORT", 5432),
User: getEnvOrDefault("DATABASE_USER", "postgres"),
Password: getEnvOrDefault("DATABASE_PASSWORD", ""),
DBName: getEnvOrDefault("DATABASE_DBNAME", "sub2api"),
SSLMode: getEnvOrDefault("DATABASE_SSLMODE", "disable"),
},
Redis: RedisConfig{
Host: getEnvOrDefault("REDIS_HOST", "localhost"),
Port: getEnvIntOrDefault("REDIS_PORT", 6379),
Password: getEnvOrDefault("REDIS_PASSWORD", ""),
DB: getEnvIntOrDefault("REDIS_DB", 0),
},
Admin: AdminConfig{
Email: getEnvOrDefault("ADMIN_EMAIL", "admin@sub2api.local"),
Password: getEnvOrDefault("ADMIN_PASSWORD", ""),
},
Server: ServerConfig{
Host: getEnvOrDefault("SERVER_HOST", "0.0.0.0"),
Port: getEnvIntOrDefault("SERVER_PORT", 8080),
Mode: getEnvOrDefault("SERVER_MODE", "release"),
},
JWT: JWTConfig{
Secret: getEnvOrDefault("JWT_SECRET", ""),
ExpireHour: getEnvIntOrDefault("JWT_EXPIRE_HOUR", 24),
},
Timezone: tz,
}
// Generate JWT secret if not provided
if cfg.JWT.Secret == "" {
cfg.JWT.Secret = generateSecret(32)
log.Println("Generated JWT secret automatically")
}
// Generate admin password if not provided
if cfg.Admin.Password == "" {
cfg.Admin.Password = generateSecret(16)
log.Printf("Generated admin password: %s", cfg.Admin.Password)
log.Println("IMPORTANT: Save this password! It will not be shown again.")
}
// Test database connection
log.Println("Testing database connection...")
if err := TestDatabaseConnection(&cfg.Database); err != nil {
return fmt.Errorf("database connection failed: %w", err)
}
log.Println("Database connection successful")
// Test Redis connection
log.Println("Testing Redis connection...")
if err := TestRedisConnection(&cfg.Redis); err != nil {
return fmt.Errorf("redis connection failed: %w", err)
}
log.Println("Redis connection successful")
// Initialize database
log.Println("Initializing database...")
if err := initializeDatabase(cfg); err != nil {
return fmt.Errorf("database initialization failed: %w", err)
}
log.Println("Database initialized successfully")
// Create admin user
log.Println("Creating admin user...")
if err := createAdminUser(cfg); err != nil {
return fmt.Errorf("admin user creation failed: %w", err)
}
log.Printf("Admin user created: %s", cfg.Admin.Email)
// Write config file
log.Println("Writing configuration file...")
if err := writeConfigFile(cfg); err != nil {
return fmt.Errorf("config file creation failed: %w", err)
}
log.Println("Configuration file created")
// Create installation lock file
if err := createInstallLock(); err != nil {
return fmt.Errorf("failed to create install lock: %w", err)
}
log.Println("Installation lock created")
log.Println("Auto setup completed successfully!")
return nil
}

View File

@@ -0,0 +1,79 @@
package web
import (
"embed"
"io"
"io/fs"
"net/http"
"strings"
"github.com/gin-gonic/gin"
)
//go:embed dist/*
var frontendFS embed.FS
// ServeEmbeddedFrontend returns a Gin handler that serves embedded frontend assets
// and handles SPA routing by falling back to index.html for non-API routes.
func ServeEmbeddedFrontend() gin.HandlerFunc {
distFS, err := fs.Sub(frontendFS, "dist")
if err != nil {
panic("failed to get dist subdirectory: " + err.Error())
}
fileServer := http.FileServer(http.FS(distFS))
return func(c *gin.Context) {
path := c.Request.URL.Path
// Skip API and gateway routes
if strings.HasPrefix(path, "/api/") ||
strings.HasPrefix(path, "/v1/") ||
strings.HasPrefix(path, "/setup/") ||
path == "/health" {
c.Next()
return
}
// Try to serve static file
cleanPath := strings.TrimPrefix(path, "/")
if cleanPath == "" {
cleanPath = "index.html"
}
if file, err := distFS.Open(cleanPath); err == nil {
file.Close()
fileServer.ServeHTTP(c.Writer, c.Request)
c.Abort()
return
}
// SPA fallback: serve index.html for all other routes
serveIndexHTML(c, distFS)
}
}
func serveIndexHTML(c *gin.Context, fsys fs.FS) {
file, err := fsys.Open("index.html")
if err != nil {
c.String(http.StatusNotFound, "Frontend not found")
c.Abort()
return
}
defer file.Close()
content, err := io.ReadAll(file)
if err != nil {
c.String(http.StatusInternalServerError, "Failed to read index.html")
c.Abort()
return
}
c.Data(http.StatusOK, "text/html; charset=utf-8", content)
c.Abort()
}
// HasEmbeddedFrontend checks if frontend assets are embedded
func HasEmbeddedFrontend() bool {
_, err := frontendFS.ReadFile("dist/index.html")
return err == nil
}

View File

@@ -0,0 +1,183 @@
-- Sub2API 初始化数据库迁移脚本
-- PostgreSQL 15+
-- 1. proxies 代理IP表无外键依赖
CREATE TABLE IF NOT EXISTS proxies (
id BIGSERIAL PRIMARY KEY,
name VARCHAR(100) NOT NULL,
protocol VARCHAR(20) NOT NULL, -- http/https/socks5
host VARCHAR(255) NOT NULL,
port INT NOT NULL,
username VARCHAR(100),
password VARCHAR(100),
status VARCHAR(20) NOT NULL DEFAULT 'active', -- active/disabled
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
deleted_at TIMESTAMPTZ
);
CREATE INDEX IF NOT EXISTS idx_proxies_status ON proxies(status);
CREATE INDEX IF NOT EXISTS idx_proxies_deleted_at ON proxies(deleted_at);
-- 2. groups 分组表(无外键依赖)
CREATE TABLE IF NOT EXISTS groups (
id BIGSERIAL PRIMARY KEY,
name VARCHAR(100) NOT NULL UNIQUE,
description TEXT,
rate_multiplier DECIMAL(10, 4) NOT NULL DEFAULT 1.0, -- 费率倍率
is_exclusive BOOLEAN NOT NULL DEFAULT FALSE, -- 是否专属分组
status VARCHAR(20) NOT NULL DEFAULT 'active', -- active/disabled
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
deleted_at TIMESTAMPTZ
);
CREATE INDEX IF NOT EXISTS idx_groups_name ON groups(name);
CREATE INDEX IF NOT EXISTS idx_groups_status ON groups(status);
CREATE INDEX IF NOT EXISTS idx_groups_is_exclusive ON groups(is_exclusive);
CREATE INDEX IF NOT EXISTS idx_groups_deleted_at ON groups(deleted_at);
-- 3. users 用户表(无外键依赖)
CREATE TABLE IF NOT EXISTS users (
id BIGSERIAL PRIMARY KEY,
email VARCHAR(255) NOT NULL UNIQUE,
password_hash VARCHAR(255) NOT NULL,
role VARCHAR(20) NOT NULL DEFAULT 'user', -- admin/user
balance DECIMAL(20, 8) NOT NULL DEFAULT 0, -- 余额(可为负数)
concurrency INT NOT NULL DEFAULT 5, -- 并发数限制
status VARCHAR(20) NOT NULL DEFAULT 'active', -- active/disabled
allowed_groups BIGINT[] DEFAULT NULL, -- 允许绑定的分组ID列表
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
deleted_at TIMESTAMPTZ
);
CREATE INDEX IF NOT EXISTS idx_users_email ON users(email);
CREATE INDEX IF NOT EXISTS idx_users_status ON users(status);
CREATE INDEX IF NOT EXISTS idx_users_deleted_at ON users(deleted_at);
-- 4. accounts 上游账号表依赖proxies
CREATE TABLE IF NOT EXISTS accounts (
id BIGSERIAL PRIMARY KEY,
name VARCHAR(100) NOT NULL,
platform VARCHAR(50) NOT NULL, -- anthropic/openai/gemini
type VARCHAR(20) NOT NULL, -- oauth/apikey
credentials JSONB NOT NULL DEFAULT '{}', -- 凭证信息(加密存储)
extra JSONB NOT NULL DEFAULT '{}', -- 扩展信息
proxy_id BIGINT REFERENCES proxies(id) ON DELETE SET NULL,
concurrency INT NOT NULL DEFAULT 3, -- 账号并发限制
priority INT NOT NULL DEFAULT 50, -- 调度优先级(1-100越小越高)
status VARCHAR(20) NOT NULL DEFAULT 'active', -- active/disabled/error
error_message TEXT,
last_used_at TIMESTAMPTZ,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
deleted_at TIMESTAMPTZ
);
CREATE INDEX IF NOT EXISTS idx_accounts_platform ON accounts(platform);
CREATE INDEX IF NOT EXISTS idx_accounts_type ON accounts(type);
CREATE INDEX IF NOT EXISTS idx_accounts_status ON accounts(status);
CREATE INDEX IF NOT EXISTS idx_accounts_proxy_id ON accounts(proxy_id);
CREATE INDEX IF NOT EXISTS idx_accounts_priority ON accounts(priority);
CREATE INDEX IF NOT EXISTS idx_accounts_last_used_at ON accounts(last_used_at);
CREATE INDEX IF NOT EXISTS idx_accounts_deleted_at ON accounts(deleted_at);
-- 5. api_keys API密钥表依赖users, groups
CREATE TABLE IF NOT EXISTS api_keys (
id BIGSERIAL PRIMARY KEY,
user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
key VARCHAR(64) NOT NULL UNIQUE, -- sk-xxx格式
name VARCHAR(100) NOT NULL,
group_id BIGINT REFERENCES groups(id) ON DELETE SET NULL,
status VARCHAR(20) NOT NULL DEFAULT 'active', -- active/disabled
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
deleted_at TIMESTAMPTZ
);
CREATE INDEX IF NOT EXISTS idx_api_keys_key ON api_keys(key);
CREATE INDEX IF NOT EXISTS idx_api_keys_user_id ON api_keys(user_id);
CREATE INDEX IF NOT EXISTS idx_api_keys_group_id ON api_keys(group_id);
CREATE INDEX IF NOT EXISTS idx_api_keys_status ON api_keys(status);
CREATE INDEX IF NOT EXISTS idx_api_keys_deleted_at ON api_keys(deleted_at);
-- 6. account_groups 账号-分组关联表依赖accounts, groups
CREATE TABLE IF NOT EXISTS account_groups (
account_id BIGINT NOT NULL REFERENCES accounts(id) ON DELETE CASCADE,
group_id BIGINT NOT NULL REFERENCES groups(id) ON DELETE CASCADE,
priority INT NOT NULL DEFAULT 50, -- 分组内优先级
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
PRIMARY KEY (account_id, group_id)
);
CREATE INDEX IF NOT EXISTS idx_account_groups_group_id ON account_groups(group_id);
CREATE INDEX IF NOT EXISTS idx_account_groups_priority ON account_groups(priority);
-- 7. redeem_codes 卡密表依赖users
CREATE TABLE IF NOT EXISTS redeem_codes (
id BIGSERIAL PRIMARY KEY,
code VARCHAR(32) NOT NULL UNIQUE, -- 兑换码
type VARCHAR(20) NOT NULL DEFAULT 'balance', -- balance
value DECIMAL(20, 8) NOT NULL, -- 面值USD
status VARCHAR(20) NOT NULL DEFAULT 'unused', -- unused/used
used_by BIGINT REFERENCES users(id) ON DELETE SET NULL,
used_at TIMESTAMPTZ,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_redeem_codes_code ON redeem_codes(code);
CREATE INDEX IF NOT EXISTS idx_redeem_codes_status ON redeem_codes(status);
CREATE INDEX IF NOT EXISTS idx_redeem_codes_used_by ON redeem_codes(used_by);
-- 8. usage_logs 使用记录表依赖users, api_keys, accounts
CREATE TABLE IF NOT EXISTS usage_logs (
id BIGSERIAL PRIMARY KEY,
user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
api_key_id BIGINT NOT NULL REFERENCES api_keys(id) ON DELETE CASCADE,
account_id BIGINT NOT NULL REFERENCES accounts(id) ON DELETE CASCADE,
request_id VARCHAR(64),
model VARCHAR(100) NOT NULL,
-- Token使用量4类
input_tokens INT NOT NULL DEFAULT 0,
output_tokens INT NOT NULL DEFAULT 0,
cache_creation_tokens INT NOT NULL DEFAULT 0,
cache_read_tokens INT NOT NULL DEFAULT 0,
-- 详细的缓存创建分类
cache_creation_5m_tokens INT NOT NULL DEFAULT 0,
cache_creation_1h_tokens INT NOT NULL DEFAULT 0,
-- 费用USD
input_cost DECIMAL(20, 10) NOT NULL DEFAULT 0,
output_cost DECIMAL(20, 10) NOT NULL DEFAULT 0,
cache_creation_cost DECIMAL(20, 10) NOT NULL DEFAULT 0,
cache_read_cost DECIMAL(20, 10) NOT NULL DEFAULT 0,
total_cost DECIMAL(20, 10) NOT NULL DEFAULT 0, -- 原始总费用
actual_cost DECIMAL(20, 10) NOT NULL DEFAULT 0, -- 实际扣除费用
-- 元数据
stream BOOLEAN NOT NULL DEFAULT FALSE,
duration_ms INT,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_usage_logs_user_id ON usage_logs(user_id);
CREATE INDEX IF NOT EXISTS idx_usage_logs_api_key_id ON usage_logs(api_key_id);
CREATE INDEX IF NOT EXISTS idx_usage_logs_account_id ON usage_logs(account_id);
CREATE INDEX IF NOT EXISTS idx_usage_logs_model ON usage_logs(model);
CREATE INDEX IF NOT EXISTS idx_usage_logs_created_at ON usage_logs(created_at);
CREATE INDEX IF NOT EXISTS idx_usage_logs_user_created ON usage_logs(user_id, created_at);
-- 插入默认管理员用户
-- 密码: admin123 (bcrypt hash)
INSERT INTO users (email, password_hash, role, balance, concurrency, status)
VALUES ('admin@sub2api.com', '$2a$10$N9qo8uLOickgx2ZMRZoMye.IjJbDdJeCo0U2bBPJj9lS/5LqD.C.C', 'admin', 0, 10, 'active')
ON CONFLICT (email) DO NOTHING;
-- 插入默认分组
INSERT INTO groups (name, description, rate_multiplier, is_exclusive, status)
VALUES ('default', '默认分组', 1.0, false, 'active')
ON CONFLICT (name) DO NOTHING;

View File

@@ -0,0 +1,33 @@
-- Sub2API 账号类型迁移脚本
-- 将 'official' 类型账号迁移为 'oauth' 或 'setup-token'
-- 根据 credentials->>'scope' 字段判断:
-- - 包含 'user:profile' 的是 'oauth' 类型
-- - 只有 'user:inference' 的是 'setup-token' 类型
-- 1. 将包含 profile scope 的 official 账号迁移为 oauth
UPDATE accounts
SET type = 'oauth',
updated_at = NOW()
WHERE type = 'official'
AND credentials->>'scope' LIKE '%user:profile%';
-- 2. 将只有 inference scope 的 official 账号迁移为 setup-token
UPDATE accounts
SET type = 'setup-token',
updated_at = NOW()
WHERE type = 'official'
AND (
credentials->>'scope' = 'user:inference'
OR credentials->>'scope' NOT LIKE '%user:profile%'
);
-- 3. 处理没有 scope 字段的旧账号(默认为 oauth
UPDATE accounts
SET type = 'oauth',
updated_at = NOW()
WHERE type = 'official'
AND (credentials->>'scope' IS NULL OR credentials->>'scope' = '');
-- 4. 验证迁移结果(查询是否还有 official 类型账号)
-- SELECT COUNT(*) FROM accounts WHERE type = 'official';
-- 如果结果为 0说明迁移成功

View File

@@ -0,0 +1,65 @@
-- Sub2API 订阅功能迁移脚本
-- 添加订阅分组和用户订阅功能
-- 1. 扩展 groups 表添加订阅相关字段
ALTER TABLE groups ADD COLUMN IF NOT EXISTS platform VARCHAR(50) NOT NULL DEFAULT 'anthropic';
ALTER TABLE groups ADD COLUMN IF NOT EXISTS subscription_type VARCHAR(20) NOT NULL DEFAULT 'standard';
ALTER TABLE groups ADD COLUMN IF NOT EXISTS daily_limit_usd DECIMAL(20, 8) DEFAULT NULL;
ALTER TABLE groups ADD COLUMN IF NOT EXISTS weekly_limit_usd DECIMAL(20, 8) DEFAULT NULL;
ALTER TABLE groups ADD COLUMN IF NOT EXISTS monthly_limit_usd DECIMAL(20, 8) DEFAULT NULL;
ALTER TABLE groups ADD COLUMN IF NOT EXISTS default_validity_days INT NOT NULL DEFAULT 30;
-- 添加索引
CREATE INDEX IF NOT EXISTS idx_groups_platform ON groups(platform);
CREATE INDEX IF NOT EXISTS idx_groups_subscription_type ON groups(subscription_type);
-- 2. 创建 user_subscriptions 用户订阅表
CREATE TABLE IF NOT EXISTS user_subscriptions (
id BIGSERIAL PRIMARY KEY,
user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
group_id BIGINT NOT NULL REFERENCES groups(id) ON DELETE CASCADE,
-- 订阅有效期
starts_at TIMESTAMPTZ NOT NULL,
expires_at TIMESTAMPTZ NOT NULL,
status VARCHAR(20) NOT NULL DEFAULT 'active', -- active/expired/suspended
-- 滑动窗口起始时间NULL=未激活)
daily_window_start TIMESTAMPTZ,
weekly_window_start TIMESTAMPTZ,
monthly_window_start TIMESTAMPTZ,
-- 当前窗口已用额度USD基于 total_cost 计算)
daily_usage_usd DECIMAL(20, 10) NOT NULL DEFAULT 0,
weekly_usage_usd DECIMAL(20, 10) NOT NULL DEFAULT 0,
monthly_usage_usd DECIMAL(20, 10) NOT NULL DEFAULT 0,
-- 管理员分配信息
assigned_by BIGINT REFERENCES users(id) ON DELETE SET NULL,
assigned_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
notes TEXT,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
-- 唯一约束:每个用户对每个分组只能有一个订阅
UNIQUE(user_id, group_id)
);
-- user_subscriptions 索引
CREATE INDEX IF NOT EXISTS idx_user_subscriptions_user_id ON user_subscriptions(user_id);
CREATE INDEX IF NOT EXISTS idx_user_subscriptions_group_id ON user_subscriptions(group_id);
CREATE INDEX IF NOT EXISTS idx_user_subscriptions_status ON user_subscriptions(status);
CREATE INDEX IF NOT EXISTS idx_user_subscriptions_expires_at ON user_subscriptions(expires_at);
CREATE INDEX IF NOT EXISTS idx_user_subscriptions_assigned_by ON user_subscriptions(assigned_by);
-- 3. 扩展 usage_logs 表添加分组和订阅关联
ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS group_id BIGINT REFERENCES groups(id) ON DELETE SET NULL;
ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS subscription_id BIGINT REFERENCES user_subscriptions(id) ON DELETE SET NULL;
ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS rate_multiplier DECIMAL(10, 4) NOT NULL DEFAULT 1;
ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS first_token_ms INT;
-- usage_logs 新索引
CREATE INDEX IF NOT EXISTS idx_usage_logs_group_id ON usage_logs(group_id);
CREATE INDEX IF NOT EXISTS idx_usage_logs_subscription_id ON usage_logs(subscription_id);
CREATE INDEX IF NOT EXISTS idx_usage_logs_sub_created ON usage_logs(subscription_id, created_at);

View File

@@ -0,0 +1,37 @@
# Model Pricing Data
This directory contains a local copy of the mirrored model pricing data as a fallback mechanism.
## Source
The original file is maintained by the LiteLLM project and mirrored into the `price-mirror` branch of this repository via GitHub Actions:
- Mirror branch (configurable via `PRICE_MIRROR_REPO`): https://raw.githubusercontent.com/<your-repo>/price-mirror/model_prices_and_context_window.json
- Upstream source: https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json
## Purpose
This local copy serves as a fallback when the remote file cannot be downloaded due to:
- Network restrictions
- Firewall rules
- DNS resolution issues
- GitHub being blocked in certain regions
- Docker container network limitations
## Update Process
The pricingService will:
1. First attempt to download the latest version from GitHub
2. If download fails, use this local copy as fallback
3. Log a warning when using the fallback file
## Manual Update
To manually update this file with the latest pricing data (if automation is unavailable):
```bash
curl -s https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json -o model_prices_and_context_window.json
```
## File Format
The file contains JSON data with model pricing information including:
- Model names and identifiers
- Input/output token costs
- Context window sizes
- Model capabilities
Last updated: 2025-08-10

File diff suppressed because it is too large Load Diff

55
deploy/.env.example Normal file
View File

@@ -0,0 +1,55 @@
# =============================================================================
# Sub2API Docker Environment Configuration
# =============================================================================
# Copy this file to .env and modify as needed:
# cp .env.example .env
# nano .env
#
# Then start with: docker-compose up -d
# =============================================================================
# -----------------------------------------------------------------------------
# Server Configuration
# -----------------------------------------------------------------------------
# Bind address for host port mapping
BIND_HOST=0.0.0.0
# Server port (exposed on host)
SERVER_PORT=8080
# Server mode: release or debug
SERVER_MODE=release
# Timezone
TZ=Asia/Shanghai
# -----------------------------------------------------------------------------
# PostgreSQL Configuration (REQUIRED)
# -----------------------------------------------------------------------------
POSTGRES_USER=sub2api
POSTGRES_PASSWORD=change_this_secure_password
POSTGRES_DB=sub2api
# -----------------------------------------------------------------------------
# Redis Configuration
# -----------------------------------------------------------------------------
# Leave empty for no password (default for local development)
REDIS_PASSWORD=
REDIS_DB=0
# -----------------------------------------------------------------------------
# Admin Account
# -----------------------------------------------------------------------------
# Email for the admin account
ADMIN_EMAIL=admin@sub2api.local
# Password for admin account
# Leave empty to auto-generate (will be shown in logs on first run)
ADMIN_PASSWORD=
# -----------------------------------------------------------------------------
# JWT Configuration
# -----------------------------------------------------------------------------
# Leave empty to auto-generate (recommended)
JWT_SECRET=
JWT_EXPIRE_HOUR=24

Some files were not shown because too many files have changed in this diff Show More