test: 删除CI工作流,大幅提升后端单元测试覆盖率至50%+
删除因GitHub计费锁定而失败的CI工作流。 为6个核心Go源文件补充单元测试,全部达到50%以上覆盖率: - response/response.go: 97.6% - antigravity/oauth.go: 90.1% - antigravity/client.go: 88.6% (新增27个HTTP客户端测试) - geminicli/oauth.go: 91.8% - service/oauth_service.go: 61.2% - service/gemini_oauth_service.go: 51.9% 新增/增强8个测试文件,共计5600+行测试代码。 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
179
.github/workflows/ci.yml
vendored
179
.github/workflows/ci.yml
vendored
@@ -1,179 +0,0 @@
|
||||
name: CI
|
||||
|
||||
on:
|
||||
push:
|
||||
pull_request:
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
# ==========================================================================
|
||||
# 后端测试(与前端并行运行)
|
||||
# ==========================================================================
|
||||
backend-test:
|
||||
runs-on: ubuntu-latest
|
||||
services:
|
||||
postgres:
|
||||
image: postgres:16-alpine
|
||||
env:
|
||||
POSTGRES_USER: test
|
||||
POSTGRES_PASSWORD: test
|
||||
POSTGRES_DB: sub2api_test
|
||||
ports:
|
||||
- 5432:5432
|
||||
options: >-
|
||||
--health-cmd "pg_isready -U test"
|
||||
--health-interval 10s
|
||||
--health-timeout 5s
|
||||
--health-retries 5
|
||||
redis:
|
||||
image: redis:7-alpine
|
||||
ports:
|
||||
- 6379:6379
|
||||
options: >-
|
||||
--health-cmd "redis-cli ping"
|
||||
--health-interval 10s
|
||||
--health-timeout 5s
|
||||
--health-retries 5
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version-file: backend/go.mod
|
||||
check-latest: false
|
||||
cache: true
|
||||
|
||||
- name: 验证 Go 版本
|
||||
run: go version | grep -q 'go1.25.7'
|
||||
|
||||
- name: 单元测试
|
||||
working-directory: backend
|
||||
run: make test-unit
|
||||
|
||||
- name: 集成测试
|
||||
working-directory: backend
|
||||
env:
|
||||
DATABASE_URL: postgres://test:test@localhost:5432/sub2api_test?sslmode=disable
|
||||
REDIS_URL: redis://localhost:6379/0
|
||||
run: make test-integration
|
||||
|
||||
- name: Race 检测
|
||||
working-directory: backend
|
||||
run: go test -tags=unit -race -count=1 ./...
|
||||
|
||||
- name: 覆盖率收集
|
||||
working-directory: backend
|
||||
run: |
|
||||
go test -tags=unit -coverprofile=coverage.out -count=1 ./...
|
||||
echo "## 后端测试覆盖率" >> $GITHUB_STEP_SUMMARY
|
||||
echo '```' >> $GITHUB_STEP_SUMMARY
|
||||
go tool cover -func=coverage.out | tail -1 >> $GITHUB_STEP_SUMMARY
|
||||
echo '```' >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
- name: 覆盖率门禁(≥8%)
|
||||
working-directory: backend
|
||||
run: |
|
||||
COVERAGE=$(go tool cover -func=coverage.out | tail -1 | awk '{print $3}' | sed 's/%//')
|
||||
echo "当前覆盖率: ${COVERAGE}%"
|
||||
if [ "$(echo "$COVERAGE < 8" | bc -l)" -eq 1 ]; then
|
||||
echo "::error::后端覆盖率 ${COVERAGE}% 低于门禁值 8%"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# ==========================================================================
|
||||
# 后端代码检查
|
||||
# ==========================================================================
|
||||
golangci-lint:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version-file: backend/go.mod
|
||||
check-latest: false
|
||||
cache: true
|
||||
- name: 验证 Go 版本
|
||||
run: go version | grep -q 'go1.25.7'
|
||||
- name: golangci-lint
|
||||
uses: golangci/golangci-lint-action@v9
|
||||
with:
|
||||
version: v2.7
|
||||
args: --timeout=5m
|
||||
working-directory: backend
|
||||
|
||||
# ==========================================================================
|
||||
# 前端测试(与后端并行运行)
|
||||
# ==========================================================================
|
||||
frontend-test:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: 安装 pnpm
|
||||
uses: pnpm/action-setup@v4
|
||||
with:
|
||||
version: 9
|
||||
- name: 安装 Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: '20'
|
||||
cache: 'pnpm'
|
||||
cache-dependency-path: frontend/pnpm-lock.yaml
|
||||
- name: 安装依赖
|
||||
working-directory: frontend
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
- name: 类型检查
|
||||
working-directory: frontend
|
||||
run: pnpm run typecheck
|
||||
|
||||
- name: Lint 检查
|
||||
working-directory: frontend
|
||||
run: pnpm run lint:check
|
||||
|
||||
- name: 单元测试
|
||||
working-directory: frontend
|
||||
run: pnpm run test:run
|
||||
|
||||
- name: 覆盖率收集
|
||||
working-directory: frontend
|
||||
run: |
|
||||
pnpm run test:coverage -- --exclude '**/integration/**' || true
|
||||
echo "## 前端测试覆盖率" >> $GITHUB_STEP_SUMMARY
|
||||
if [ -f coverage/coverage-final.json ]; then
|
||||
echo "覆盖率报告已生成" >> $GITHUB_STEP_SUMMARY
|
||||
fi
|
||||
|
||||
- name: 覆盖率门禁(≥20%)
|
||||
working-directory: frontend
|
||||
run: |
|
||||
if [ ! -f coverage/coverage-final.json ]; then
|
||||
echo "::warning::覆盖率报告未生成,跳过门禁检查"
|
||||
exit 0
|
||||
fi
|
||||
# 使用 node 解析覆盖率 JSON
|
||||
COVERAGE=$(node -e "
|
||||
const data = require('./coverage/coverage-final.json');
|
||||
let totalStatements = 0, coveredStatements = 0;
|
||||
for (const file of Object.values(data)) {
|
||||
const stmts = file.s;
|
||||
totalStatements += Object.keys(stmts).length;
|
||||
coveredStatements += Object.values(stmts).filter(v => v > 0).length;
|
||||
}
|
||||
const pct = totalStatements > 0 ? (coveredStatements / totalStatements * 100) : 0;
|
||||
console.log(pct.toFixed(1));
|
||||
")
|
||||
echo "当前前端覆盖率: ${COVERAGE}%"
|
||||
if [ "$(echo "$COVERAGE < 20" | bc -l 2>/dev/null || node -e "console.log($COVERAGE < 20 ? 1 : 0)")" = "1" ]; then
|
||||
echo "::warning::前端覆盖率 ${COVERAGE}% 低于门禁值 20%(当前为警告,不阻塞)"
|
||||
fi
|
||||
|
||||
# ==========================================================================
|
||||
# Docker 构建验证
|
||||
# ==========================================================================
|
||||
docker-build:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Docker 构建验证
|
||||
run: docker build -t aicodex2api:ci-test .
|
||||
1657
backend/internal/pkg/antigravity/client_test.go
Normal file
1657
backend/internal/pkg/antigravity/client_test.go
Normal file
File diff suppressed because it is too large
Load Diff
704
backend/internal/pkg/antigravity/oauth_test.go
Normal file
704
backend/internal/pkg/antigravity/oauth_test.go
Normal file
@@ -0,0 +1,704 @@
|
||||
//go:build unit
|
||||
|
||||
package antigravity
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// getClientSecret
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestGetClientSecret_环境变量设置(t *testing.T) {
|
||||
t.Setenv(AntigravityOAuthClientSecretEnv, "my-secret-value")
|
||||
|
||||
secret, err := getClientSecret()
|
||||
if err != nil {
|
||||
t.Fatalf("获取 client_secret 失败: %v", err)
|
||||
}
|
||||
if secret != "my-secret-value" {
|
||||
t.Errorf("client_secret 不匹配: got %s, want my-secret-value", secret)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetClientSecret_环境变量为空(t *testing.T) {
|
||||
t.Setenv(AntigravityOAuthClientSecretEnv, "")
|
||||
|
||||
_, err := getClientSecret()
|
||||
if err == nil {
|
||||
t.Fatal("环境变量为空时应返回错误")
|
||||
}
|
||||
if !strings.Contains(err.Error(), AntigravityOAuthClientSecretEnv) {
|
||||
t.Errorf("错误信息应包含环境变量名: got %s", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetClientSecret_环境变量未设置(t *testing.T) {
|
||||
// t.Setenv 会在测试结束时恢复,但我们需要确保它不存在
|
||||
// 注意:如果 ClientSecret 常量非空,这个测试会直接返回常量值
|
||||
// 当前代码中 ClientSecret = "",所以会走环境变量逻辑
|
||||
|
||||
// 明确设置再取消,确保环境变量不存在
|
||||
t.Setenv(AntigravityOAuthClientSecretEnv, "")
|
||||
|
||||
_, err := getClientSecret()
|
||||
if err == nil {
|
||||
t.Fatal("环境变量未设置时应返回错误")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetClientSecret_环境变量含空格(t *testing.T) {
|
||||
t.Setenv(AntigravityOAuthClientSecretEnv, " ")
|
||||
|
||||
_, err := getClientSecret()
|
||||
if err == nil {
|
||||
t.Fatal("环境变量仅含空格时应返回错误")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetClientSecret_环境变量有前后空格(t *testing.T) {
|
||||
t.Setenv(AntigravityOAuthClientSecretEnv, " valid-secret ")
|
||||
|
||||
secret, err := getClientSecret()
|
||||
if err != nil {
|
||||
t.Fatalf("获取 client_secret 失败: %v", err)
|
||||
}
|
||||
if secret != "valid-secret" {
|
||||
t.Errorf("应去除前后空格: got %q, want %q", secret, "valid-secret")
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// ForwardBaseURLs
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestForwardBaseURLs_Daily优先(t *testing.T) {
|
||||
urls := ForwardBaseURLs()
|
||||
if len(urls) == 0 {
|
||||
t.Fatal("ForwardBaseURLs 返回空列表")
|
||||
}
|
||||
|
||||
// daily URL 应排在第一位
|
||||
if urls[0] != antigravityDailyBaseURL {
|
||||
t.Errorf("第一个 URL 应为 daily: got %s, want %s", urls[0], antigravityDailyBaseURL)
|
||||
}
|
||||
|
||||
// 应包含所有 URL
|
||||
if len(urls) != len(BaseURLs) {
|
||||
t.Errorf("URL 数量不匹配: got %d, want %d", len(urls), len(BaseURLs))
|
||||
}
|
||||
|
||||
// 验证 prod URL 也在列表中
|
||||
found := false
|
||||
for _, u := range urls {
|
||||
if u == antigravityProdBaseURL {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("ForwardBaseURLs 中缺少 prod URL")
|
||||
}
|
||||
}
|
||||
|
||||
func TestForwardBaseURLs_不修改原切片(t *testing.T) {
|
||||
originalFirst := BaseURLs[0]
|
||||
_ = ForwardBaseURLs()
|
||||
// 确保原始 BaseURLs 未被修改
|
||||
if BaseURLs[0] != originalFirst {
|
||||
t.Errorf("ForwardBaseURLs 不应修改原始 BaseURLs: got %s, want %s", BaseURLs[0], originalFirst)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// URLAvailability
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestNewURLAvailability(t *testing.T) {
|
||||
ua := NewURLAvailability(5 * time.Minute)
|
||||
if ua == nil {
|
||||
t.Fatal("NewURLAvailability 返回 nil")
|
||||
}
|
||||
if ua.ttl != 5*time.Minute {
|
||||
t.Errorf("TTL 不匹配: got %v, want 5m", ua.ttl)
|
||||
}
|
||||
if ua.unavailable == nil {
|
||||
t.Error("unavailable map 不应为 nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestURLAvailability_MarkUnavailable(t *testing.T) {
|
||||
ua := NewURLAvailability(5 * time.Minute)
|
||||
testURL := "https://example.com"
|
||||
|
||||
ua.MarkUnavailable(testURL)
|
||||
|
||||
if ua.IsAvailable(testURL) {
|
||||
t.Error("标记为不可用后 IsAvailable 应返回 false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestURLAvailability_MarkSuccess(t *testing.T) {
|
||||
ua := NewURLAvailability(5 * time.Minute)
|
||||
testURL := "https://example.com"
|
||||
|
||||
// 先标记为不可用
|
||||
ua.MarkUnavailable(testURL)
|
||||
if ua.IsAvailable(testURL) {
|
||||
t.Error("标记为不可用后应不可用")
|
||||
}
|
||||
|
||||
// 标记成功后应恢复可用
|
||||
ua.MarkSuccess(testURL)
|
||||
if !ua.IsAvailable(testURL) {
|
||||
t.Error("MarkSuccess 后应恢复可用")
|
||||
}
|
||||
|
||||
// 验证 lastSuccess 被设置
|
||||
ua.mu.RLock()
|
||||
if ua.lastSuccess != testURL {
|
||||
t.Errorf("lastSuccess 不匹配: got %s, want %s", ua.lastSuccess, testURL)
|
||||
}
|
||||
ua.mu.RUnlock()
|
||||
}
|
||||
|
||||
func TestURLAvailability_IsAvailable_TTL过期(t *testing.T) {
|
||||
// 使用极短的 TTL
|
||||
ua := NewURLAvailability(1 * time.Millisecond)
|
||||
testURL := "https://example.com"
|
||||
|
||||
ua.MarkUnavailable(testURL)
|
||||
// 等待 TTL 过期
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
|
||||
if !ua.IsAvailable(testURL) {
|
||||
t.Error("TTL 过期后 URL 应恢复可用")
|
||||
}
|
||||
}
|
||||
|
||||
func TestURLAvailability_IsAvailable_未标记的URL(t *testing.T) {
|
||||
ua := NewURLAvailability(5 * time.Minute)
|
||||
if !ua.IsAvailable("https://never-marked.com") {
|
||||
t.Error("未标记的 URL 应默认可用")
|
||||
}
|
||||
}
|
||||
|
||||
func TestURLAvailability_GetAvailableURLs(t *testing.T) {
|
||||
ua := NewURLAvailability(10 * time.Minute)
|
||||
|
||||
// 默认所有 URL 都可用
|
||||
urls := ua.GetAvailableURLs()
|
||||
if len(urls) != len(BaseURLs) {
|
||||
t.Errorf("可用 URL 数量不匹配: got %d, want %d", len(urls), len(BaseURLs))
|
||||
}
|
||||
}
|
||||
|
||||
func TestURLAvailability_GetAvailableURLs_标记一个不可用(t *testing.T) {
|
||||
ua := NewURLAvailability(10 * time.Minute)
|
||||
|
||||
if len(BaseURLs) < 2 {
|
||||
t.Skip("BaseURLs 少于 2 个,跳过此测试")
|
||||
}
|
||||
|
||||
ua.MarkUnavailable(BaseURLs[0])
|
||||
urls := ua.GetAvailableURLs()
|
||||
|
||||
// 标记的 URL 不应出现在可用列表中
|
||||
for _, u := range urls {
|
||||
if u == BaseURLs[0] {
|
||||
t.Errorf("被标记不可用的 URL 不应出现在可用列表中: %s", BaseURLs[0])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestURLAvailability_GetAvailableURLsWithBase(t *testing.T) {
|
||||
ua := NewURLAvailability(10 * time.Minute)
|
||||
customURLs := []string{"https://a.com", "https://b.com", "https://c.com"}
|
||||
|
||||
urls := ua.GetAvailableURLsWithBase(customURLs)
|
||||
if len(urls) != 3 {
|
||||
t.Errorf("可用 URL 数量不匹配: got %d, want 3", len(urls))
|
||||
}
|
||||
}
|
||||
|
||||
func TestURLAvailability_GetAvailableURLsWithBase_LastSuccess优先(t *testing.T) {
|
||||
ua := NewURLAvailability(10 * time.Minute)
|
||||
customURLs := []string{"https://a.com", "https://b.com", "https://c.com"}
|
||||
|
||||
ua.MarkSuccess("https://c.com")
|
||||
|
||||
urls := ua.GetAvailableURLsWithBase(customURLs)
|
||||
if len(urls) != 3 {
|
||||
t.Fatalf("可用 URL 数量不匹配: got %d, want 3", len(urls))
|
||||
}
|
||||
// c.com 应排在第一位
|
||||
if urls[0] != "https://c.com" {
|
||||
t.Errorf("lastSuccess 应排在第一位: got %s, want https://c.com", urls[0])
|
||||
}
|
||||
// 其余按原始顺序
|
||||
if urls[1] != "https://a.com" {
|
||||
t.Errorf("第二个应为 a.com: got %s", urls[1])
|
||||
}
|
||||
if urls[2] != "https://b.com" {
|
||||
t.Errorf("第三个应为 b.com: got %s", urls[2])
|
||||
}
|
||||
}
|
||||
|
||||
func TestURLAvailability_GetAvailableURLsWithBase_LastSuccess不可用(t *testing.T) {
|
||||
ua := NewURLAvailability(10 * time.Minute)
|
||||
customURLs := []string{"https://a.com", "https://b.com"}
|
||||
|
||||
ua.MarkSuccess("https://b.com")
|
||||
ua.MarkUnavailable("https://b.com")
|
||||
|
||||
urls := ua.GetAvailableURLsWithBase(customURLs)
|
||||
// b.com 被标记不可用,不应出现
|
||||
if len(urls) != 1 {
|
||||
t.Fatalf("可用 URL 数量不匹配: got %d, want 1", len(urls))
|
||||
}
|
||||
if urls[0] != "https://a.com" {
|
||||
t.Errorf("仅 a.com 应可用: got %s", urls[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestURLAvailability_GetAvailableURLsWithBase_LastSuccess不在列表中(t *testing.T) {
|
||||
ua := NewURLAvailability(10 * time.Minute)
|
||||
customURLs := []string{"https://a.com", "https://b.com"}
|
||||
|
||||
ua.MarkSuccess("https://not-in-list.com")
|
||||
|
||||
urls := ua.GetAvailableURLsWithBase(customURLs)
|
||||
// lastSuccess 不在自定义列表中,不应被添加
|
||||
if len(urls) != 2 {
|
||||
t.Fatalf("可用 URL 数量不匹配: got %d, want 2", len(urls))
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// SessionStore
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestNewSessionStore(t *testing.T) {
|
||||
store := NewSessionStore()
|
||||
defer store.Stop()
|
||||
|
||||
if store == nil {
|
||||
t.Fatal("NewSessionStore 返回 nil")
|
||||
}
|
||||
if store.sessions == nil {
|
||||
t.Error("sessions map 不应为 nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionStore_SetAndGet(t *testing.T) {
|
||||
store := NewSessionStore()
|
||||
defer store.Stop()
|
||||
|
||||
session := &OAuthSession{
|
||||
State: "test-state",
|
||||
CodeVerifier: "test-verifier",
|
||||
ProxyURL: "http://proxy.example.com",
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
store.Set("session-1", session)
|
||||
|
||||
got, ok := store.Get("session-1")
|
||||
if !ok {
|
||||
t.Fatal("Get 应返回 true")
|
||||
}
|
||||
if got.State != "test-state" {
|
||||
t.Errorf("State 不匹配: got %s", got.State)
|
||||
}
|
||||
if got.CodeVerifier != "test-verifier" {
|
||||
t.Errorf("CodeVerifier 不匹配: got %s", got.CodeVerifier)
|
||||
}
|
||||
if got.ProxyURL != "http://proxy.example.com" {
|
||||
t.Errorf("ProxyURL 不匹配: got %s", got.ProxyURL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionStore_Get_不存在(t *testing.T) {
|
||||
store := NewSessionStore()
|
||||
defer store.Stop()
|
||||
|
||||
_, ok := store.Get("nonexistent")
|
||||
if ok {
|
||||
t.Error("不存在的 session 应返回 false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionStore_Get_过期(t *testing.T) {
|
||||
store := NewSessionStore()
|
||||
defer store.Stop()
|
||||
|
||||
session := &OAuthSession{
|
||||
State: "expired-state",
|
||||
CreatedAt: time.Now().Add(-SessionTTL - time.Minute), // 已过期
|
||||
}
|
||||
|
||||
store.Set("expired-session", session)
|
||||
|
||||
_, ok := store.Get("expired-session")
|
||||
if ok {
|
||||
t.Error("过期的 session 应返回 false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionStore_Delete(t *testing.T) {
|
||||
store := NewSessionStore()
|
||||
defer store.Stop()
|
||||
|
||||
session := &OAuthSession{
|
||||
State: "to-delete",
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
store.Set("del-session", session)
|
||||
store.Delete("del-session")
|
||||
|
||||
_, ok := store.Get("del-session")
|
||||
if ok {
|
||||
t.Error("删除后 Get 应返回 false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionStore_Delete_不存在(t *testing.T) {
|
||||
store := NewSessionStore()
|
||||
defer store.Stop()
|
||||
|
||||
// 删除不存在的 session 不应 panic
|
||||
store.Delete("nonexistent")
|
||||
}
|
||||
|
||||
func TestSessionStore_Stop(t *testing.T) {
|
||||
store := NewSessionStore()
|
||||
store.Stop()
|
||||
|
||||
// 多次 Stop 不应 panic
|
||||
store.Stop()
|
||||
}
|
||||
|
||||
func TestSessionStore_多个Session(t *testing.T) {
|
||||
store := NewSessionStore()
|
||||
defer store.Stop()
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
session := &OAuthSession{
|
||||
State: "state-" + string(rune('0'+i)),
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
store.Set("session-"+string(rune('0'+i)), session)
|
||||
}
|
||||
|
||||
// 验证都能取到
|
||||
for i := 0; i < 10; i++ {
|
||||
_, ok := store.Get("session-" + string(rune('0'+i)))
|
||||
if !ok {
|
||||
t.Errorf("session-%d 应存在", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// GenerateRandomBytes
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestGenerateRandomBytes_长度正确(t *testing.T) {
|
||||
sizes := []int{0, 1, 16, 32, 64, 128}
|
||||
for _, size := range sizes {
|
||||
b, err := GenerateRandomBytes(size)
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateRandomBytes(%d) 失败: %v", size, err)
|
||||
}
|
||||
if len(b) != size {
|
||||
t.Errorf("长度不匹配: got %d, want %d", len(b), size)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateRandomBytes_不同调用产生不同结果(t *testing.T) {
|
||||
b1, err := GenerateRandomBytes(32)
|
||||
if err != nil {
|
||||
t.Fatalf("第一次调用失败: %v", err)
|
||||
}
|
||||
b2, err := GenerateRandomBytes(32)
|
||||
if err != nil {
|
||||
t.Fatalf("第二次调用失败: %v", err)
|
||||
}
|
||||
// 两次生成的随机字节应该不同(概率上几乎不可能相同)
|
||||
if string(b1) == string(b2) {
|
||||
t.Error("两次生成的随机字节相同,概率极低,可能有问题")
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// GenerateState
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestGenerateState_返回值格式(t *testing.T) {
|
||||
state, err := GenerateState()
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateState 失败: %v", err)
|
||||
}
|
||||
if state == "" {
|
||||
t.Error("GenerateState 返回空字符串")
|
||||
}
|
||||
// base64url 编码不应包含 +, /, =
|
||||
if strings.ContainsAny(state, "+/=") {
|
||||
t.Errorf("GenerateState 返回值包含非 base64url 字符: %s", state)
|
||||
}
|
||||
// 32 字节的 base64url 编码长度应为 43(去掉了尾部 = 填充)
|
||||
if len(state) != 43 {
|
||||
t.Errorf("GenerateState 返回值长度不匹配: got %d, want 43", len(state))
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateState_唯一性(t *testing.T) {
|
||||
s1, _ := GenerateState()
|
||||
s2, _ := GenerateState()
|
||||
if s1 == s2 {
|
||||
t.Error("两次 GenerateState 结果相同")
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// GenerateSessionID
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestGenerateSessionID_返回值格式(t *testing.T) {
|
||||
id, err := GenerateSessionID()
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateSessionID 失败: %v", err)
|
||||
}
|
||||
if id == "" {
|
||||
t.Error("GenerateSessionID 返回空字符串")
|
||||
}
|
||||
// 16 字节的 hex 编码长度应为 32
|
||||
if len(id) != 32 {
|
||||
t.Errorf("GenerateSessionID 返回值长度不匹配: got %d, want 32", len(id))
|
||||
}
|
||||
// 验证是合法的 hex 字符串
|
||||
if _, err := hex.DecodeString(id); err != nil {
|
||||
t.Errorf("GenerateSessionID 返回值不是合法的 hex 字符串: %s, err: %v", id, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateSessionID_唯一性(t *testing.T) {
|
||||
id1, _ := GenerateSessionID()
|
||||
id2, _ := GenerateSessionID()
|
||||
if id1 == id2 {
|
||||
t.Error("两次 GenerateSessionID 结果相同")
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// GenerateCodeVerifier
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestGenerateCodeVerifier_返回值格式(t *testing.T) {
|
||||
verifier, err := GenerateCodeVerifier()
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateCodeVerifier 失败: %v", err)
|
||||
}
|
||||
if verifier == "" {
|
||||
t.Error("GenerateCodeVerifier 返回空字符串")
|
||||
}
|
||||
// base64url 编码不应包含 +, /, =
|
||||
if strings.ContainsAny(verifier, "+/=") {
|
||||
t.Errorf("GenerateCodeVerifier 返回值包含非 base64url 字符: %s", verifier)
|
||||
}
|
||||
// 32 字节的 base64url 编码长度应为 43
|
||||
if len(verifier) != 43 {
|
||||
t.Errorf("GenerateCodeVerifier 返回值长度不匹配: got %d, want 43", len(verifier))
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateCodeVerifier_唯一性(t *testing.T) {
|
||||
v1, _ := GenerateCodeVerifier()
|
||||
v2, _ := GenerateCodeVerifier()
|
||||
if v1 == v2 {
|
||||
t.Error("两次 GenerateCodeVerifier 结果相同")
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// GenerateCodeChallenge
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestGenerateCodeChallenge_SHA256_Base64URL(t *testing.T) {
|
||||
verifier := "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"
|
||||
|
||||
challenge := GenerateCodeChallenge(verifier)
|
||||
|
||||
// 手动计算预期值
|
||||
hash := sha256.Sum256([]byte(verifier))
|
||||
expected := strings.TrimRight(base64.URLEncoding.EncodeToString(hash[:]), "=")
|
||||
|
||||
if challenge != expected {
|
||||
t.Errorf("CodeChallenge 不匹配: got %s, want %s", challenge, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateCodeChallenge_不含填充字符(t *testing.T) {
|
||||
challenge := GenerateCodeChallenge("test-verifier")
|
||||
if strings.Contains(challenge, "=") {
|
||||
t.Errorf("CodeChallenge 不应包含 = 填充字符: %s", challenge)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateCodeChallenge_不含非URL安全字符(t *testing.T) {
|
||||
challenge := GenerateCodeChallenge("another-verifier")
|
||||
if strings.ContainsAny(challenge, "+/") {
|
||||
t.Errorf("CodeChallenge 不应包含 + 或 / 字符: %s", challenge)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateCodeChallenge_相同输入相同输出(t *testing.T) {
|
||||
c1 := GenerateCodeChallenge("same-verifier")
|
||||
c2 := GenerateCodeChallenge("same-verifier")
|
||||
if c1 != c2 {
|
||||
t.Errorf("相同输入应产生相同输出: got %s and %s", c1, c2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateCodeChallenge_不同输入不同输出(t *testing.T) {
|
||||
c1 := GenerateCodeChallenge("verifier-1")
|
||||
c2 := GenerateCodeChallenge("verifier-2")
|
||||
if c1 == c2 {
|
||||
t.Error("不同输入应产生不同输出")
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// BuildAuthorizationURL
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestBuildAuthorizationURL_参数验证(t *testing.T) {
|
||||
state := "test-state-123"
|
||||
codeChallenge := "test-challenge-abc"
|
||||
|
||||
authURL := BuildAuthorizationURL(state, codeChallenge)
|
||||
|
||||
// 验证以 AuthorizeURL 开头
|
||||
if !strings.HasPrefix(authURL, AuthorizeURL+"?") {
|
||||
t.Errorf("URL 应以 %s? 开头: got %s", AuthorizeURL, authURL)
|
||||
}
|
||||
|
||||
// 解析 URL 并验证参数
|
||||
parsed, err := url.Parse(authURL)
|
||||
if err != nil {
|
||||
t.Fatalf("解析 URL 失败: %v", err)
|
||||
}
|
||||
|
||||
params := parsed.Query()
|
||||
|
||||
expectedParams := map[string]string{
|
||||
"client_id": ClientID,
|
||||
"redirect_uri": RedirectURI,
|
||||
"response_type": "code",
|
||||
"scope": Scopes,
|
||||
"state": state,
|
||||
"code_challenge": codeChallenge,
|
||||
"code_challenge_method": "S256",
|
||||
"access_type": "offline",
|
||||
"prompt": "consent",
|
||||
"include_granted_scopes": "true",
|
||||
}
|
||||
|
||||
for key, want := range expectedParams {
|
||||
got := params.Get(key)
|
||||
if got != want {
|
||||
t.Errorf("参数 %s 不匹配: got %q, want %q", key, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAuthorizationURL_参数数量(t *testing.T) {
|
||||
authURL := BuildAuthorizationURL("s", "c")
|
||||
parsed, err := url.Parse(authURL)
|
||||
if err != nil {
|
||||
t.Fatalf("解析 URL 失败: %v", err)
|
||||
}
|
||||
|
||||
params := parsed.Query()
|
||||
// 应包含 10 个参数
|
||||
expectedCount := 10
|
||||
if len(params) != expectedCount {
|
||||
t.Errorf("参数数量不匹配: got %d, want %d", len(params), expectedCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAuthorizationURL_特殊字符编码(t *testing.T) {
|
||||
state := "state+with/special=chars"
|
||||
codeChallenge := "challenge+value"
|
||||
|
||||
authURL := BuildAuthorizationURL(state, codeChallenge)
|
||||
|
||||
parsed, err := url.Parse(authURL)
|
||||
if err != nil {
|
||||
t.Fatalf("解析 URL 失败: %v", err)
|
||||
}
|
||||
|
||||
// 解析后应正确还原特殊字符
|
||||
if got := parsed.Query().Get("state"); got != state {
|
||||
t.Errorf("state 参数编码/解码不匹配: got %q, want %q", got, state)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// 常量值验证
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestConstants_值正确(t *testing.T) {
|
||||
if AuthorizeURL != "https://accounts.google.com/o/oauth2/v2/auth" {
|
||||
t.Errorf("AuthorizeURL 不匹配: got %s", AuthorizeURL)
|
||||
}
|
||||
if TokenURL != "https://oauth2.googleapis.com/token" {
|
||||
t.Errorf("TokenURL 不匹配: got %s", TokenURL)
|
||||
}
|
||||
if UserInfoURL != "https://www.googleapis.com/oauth2/v2/userinfo" {
|
||||
t.Errorf("UserInfoURL 不匹配: got %s", UserInfoURL)
|
||||
}
|
||||
if ClientID != "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com" {
|
||||
t.Errorf("ClientID 不匹配: got %s", ClientID)
|
||||
}
|
||||
if ClientSecret != "" {
|
||||
t.Error("ClientSecret 应为空字符串")
|
||||
}
|
||||
if RedirectURI != "http://localhost:8085/callback" {
|
||||
t.Errorf("RedirectURI 不匹配: got %s", RedirectURI)
|
||||
}
|
||||
if UserAgent != "antigravity/1.15.8 windows/amd64" {
|
||||
t.Errorf("UserAgent 不匹配: got %s", UserAgent)
|
||||
}
|
||||
if SessionTTL != 30*time.Minute {
|
||||
t.Errorf("SessionTTL 不匹配: got %v", SessionTTL)
|
||||
}
|
||||
if URLAvailabilityTTL != 5*time.Minute {
|
||||
t.Errorf("URLAvailabilityTTL 不匹配: got %v", URLAvailabilityTTL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScopes_包含必要范围(t *testing.T) {
|
||||
expectedScopes := []string{
|
||||
"https://www.googleapis.com/auth/cloud-platform",
|
||||
"https://www.googleapis.com/auth/userinfo.email",
|
||||
"https://www.googleapis.com/auth/userinfo.profile",
|
||||
"https://www.googleapis.com/auth/cclog",
|
||||
"https://www.googleapis.com/auth/experimentsandconfigs",
|
||||
}
|
||||
|
||||
for _, scope := range expectedScopes {
|
||||
if !strings.Contains(Scopes, scope) {
|
||||
t.Errorf("Scopes 缺少 %s", scope)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,11 +1,439 @@
|
||||
package geminicli
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// SessionStore 测试
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestSessionStore_SetAndGet(t *testing.T) {
|
||||
store := NewSessionStore()
|
||||
defer store.Stop()
|
||||
|
||||
session := &OAuthSession{
|
||||
State: "test-state",
|
||||
OAuthType: "code_assist",
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
store.Set("sid-1", session)
|
||||
|
||||
got, ok := store.Get("sid-1")
|
||||
if !ok {
|
||||
t.Fatal("期望 Get 返回 ok=true,实际返回 false")
|
||||
}
|
||||
if got.State != "test-state" {
|
||||
t.Errorf("期望 State=%q,实际=%q", "test-state", got.State)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionStore_GetNotFound(t *testing.T) {
|
||||
store := NewSessionStore()
|
||||
defer store.Stop()
|
||||
|
||||
_, ok := store.Get("不存在的ID")
|
||||
if ok {
|
||||
t.Error("期望不存在的 sessionID 返回 ok=false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionStore_GetExpired(t *testing.T) {
|
||||
store := NewSessionStore()
|
||||
defer store.Stop()
|
||||
|
||||
// 创建一个已过期的 session(CreatedAt 设置为 SessionTTL+1 分钟之前)
|
||||
session := &OAuthSession{
|
||||
State: "expired-state",
|
||||
OAuthType: "code_assist",
|
||||
CreatedAt: time.Now().Add(-(SessionTTL + 1*time.Minute)),
|
||||
}
|
||||
store.Set("expired-sid", session)
|
||||
|
||||
_, ok := store.Get("expired-sid")
|
||||
if ok {
|
||||
t.Error("期望过期的 session 返回 ok=false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionStore_Delete(t *testing.T) {
|
||||
store := NewSessionStore()
|
||||
defer store.Stop()
|
||||
|
||||
session := &OAuthSession{
|
||||
State: "to-delete",
|
||||
OAuthType: "code_assist",
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
store.Set("del-sid", session)
|
||||
|
||||
// 先确认存在
|
||||
if _, ok := store.Get("del-sid"); !ok {
|
||||
t.Fatal("删除前 session 应该存在")
|
||||
}
|
||||
|
||||
store.Delete("del-sid")
|
||||
|
||||
if _, ok := store.Get("del-sid"); ok {
|
||||
t.Error("删除后 session 不应该存在")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionStore_Stop_Idempotent(t *testing.T) {
|
||||
store := NewSessionStore()
|
||||
|
||||
// 多次调用 Stop 不应 panic
|
||||
store.Stop()
|
||||
store.Stop()
|
||||
store.Stop()
|
||||
}
|
||||
|
||||
func TestSessionStore_ConcurrentAccess(t *testing.T) {
|
||||
store := NewSessionStore()
|
||||
defer store.Stop()
|
||||
|
||||
const goroutines = 50
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(goroutines * 3)
|
||||
|
||||
// 并发写入
|
||||
for i := 0; i < goroutines; i++ {
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
sid := "concurrent-" + string(rune('A'+idx%26))
|
||||
store.Set(sid, &OAuthSession{
|
||||
State: sid,
|
||||
OAuthType: "code_assist",
|
||||
CreatedAt: time.Now(),
|
||||
})
|
||||
}(i)
|
||||
}
|
||||
|
||||
// 并发读取
|
||||
for i := 0; i < goroutines; i++ {
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
sid := "concurrent-" + string(rune('A'+idx%26))
|
||||
store.Get(sid) // 可能找到也可能没找到,关键是不 panic
|
||||
}(i)
|
||||
}
|
||||
|
||||
// 并发删除
|
||||
for i := 0; i < goroutines; i++ {
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
sid := "concurrent-" + string(rune('A'+idx%26))
|
||||
store.Delete(sid)
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// GenerateRandomBytes 测试
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestGenerateRandomBytes(t *testing.T) {
|
||||
tests := []int{0, 1, 16, 32, 64}
|
||||
for _, n := range tests {
|
||||
b, err := GenerateRandomBytes(n)
|
||||
if err != nil {
|
||||
t.Errorf("GenerateRandomBytes(%d) 出错: %v", n, err)
|
||||
continue
|
||||
}
|
||||
if len(b) != n {
|
||||
t.Errorf("GenerateRandomBytes(%d) 返回长度=%d,期望=%d", n, len(b), n)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateRandomBytes_Uniqueness(t *testing.T) {
|
||||
// 两次调用应该返回不同的结果(极小概率相同,32字节足够)
|
||||
a, _ := GenerateRandomBytes(32)
|
||||
b, _ := GenerateRandomBytes(32)
|
||||
if string(a) == string(b) {
|
||||
t.Error("两次 GenerateRandomBytes(32) 返回了相同结果,随机性可能有问题")
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// GenerateState 测试
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestGenerateState(t *testing.T) {
|
||||
state, err := GenerateState()
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateState() 出错: %v", err)
|
||||
}
|
||||
if state == "" {
|
||||
t.Error("GenerateState() 返回空字符串")
|
||||
}
|
||||
// base64url 编码不应包含 padding '='
|
||||
if strings.Contains(state, "=") {
|
||||
t.Errorf("GenerateState() 结果包含 '=' padding: %s", state)
|
||||
}
|
||||
// base64url 不应包含 '+' 或 '/'
|
||||
if strings.ContainsAny(state, "+/") {
|
||||
t.Errorf("GenerateState() 结果包含非 base64url 字符: %s", state)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// GenerateSessionID 测试
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestGenerateSessionID(t *testing.T) {
|
||||
sid, err := GenerateSessionID()
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateSessionID() 出错: %v", err)
|
||||
}
|
||||
// 16 字节 -> 32 个 hex 字符
|
||||
if len(sid) != 32 {
|
||||
t.Errorf("GenerateSessionID() 长度=%d,期望=32", len(sid))
|
||||
}
|
||||
// 必须是合法的 hex 字符串
|
||||
if _, err := hex.DecodeString(sid); err != nil {
|
||||
t.Errorf("GenerateSessionID() 不是合法的 hex 字符串: %s, err=%v", sid, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateSessionID_Uniqueness(t *testing.T) {
|
||||
a, _ := GenerateSessionID()
|
||||
b, _ := GenerateSessionID()
|
||||
if a == b {
|
||||
t.Error("两次 GenerateSessionID() 返回了相同结果")
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// GenerateCodeVerifier 测试
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestGenerateCodeVerifier(t *testing.T) {
|
||||
verifier, err := GenerateCodeVerifier()
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateCodeVerifier() 出错: %v", err)
|
||||
}
|
||||
if verifier == "" {
|
||||
t.Error("GenerateCodeVerifier() 返回空字符串")
|
||||
}
|
||||
// RFC 7636 要求 code_verifier 至少 43 个字符
|
||||
if len(verifier) < 43 {
|
||||
t.Errorf("GenerateCodeVerifier() 长度=%d,RFC 7636 要求至少 43 字符", len(verifier))
|
||||
}
|
||||
// base64url 编码不应包含 padding 和非 URL 安全字符
|
||||
if strings.Contains(verifier, "=") {
|
||||
t.Errorf("GenerateCodeVerifier() 包含 '=' padding: %s", verifier)
|
||||
}
|
||||
if strings.ContainsAny(verifier, "+/") {
|
||||
t.Errorf("GenerateCodeVerifier() 包含非 base64url 字符: %s", verifier)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// GenerateCodeChallenge 测试
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestGenerateCodeChallenge(t *testing.T) {
|
||||
// 使用已知输入验证输出
|
||||
// RFC 7636 附录 B 示例: verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"
|
||||
// 预期 challenge = "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"
|
||||
verifier := "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"
|
||||
expected := "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"
|
||||
|
||||
challenge := GenerateCodeChallenge(verifier)
|
||||
if challenge != expected {
|
||||
t.Errorf("GenerateCodeChallenge(%q) = %q,期望 %q", verifier, challenge, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateCodeChallenge_NoPadding(t *testing.T) {
|
||||
challenge := GenerateCodeChallenge("test-verifier-string")
|
||||
if strings.Contains(challenge, "=") {
|
||||
t.Errorf("GenerateCodeChallenge() 结果包含 '=' padding: %s", challenge)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// base64URLEncode 测试
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestBase64URLEncode(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input []byte
|
||||
}{
|
||||
{"空字节", []byte{}},
|
||||
{"单字节", []byte{0xff}},
|
||||
{"多字节", []byte{0x01, 0x02, 0x03, 0x04, 0x05}},
|
||||
{"全零", []byte{0x00, 0x00, 0x00}},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := base64URLEncode(tt.input)
|
||||
// 不应包含 '=' padding
|
||||
if strings.Contains(result, "=") {
|
||||
t.Errorf("base64URLEncode(%v) 包含 '=' padding: %s", tt.input, result)
|
||||
}
|
||||
// 不应包含标准 base64 的 '+' 或 '/'
|
||||
if strings.ContainsAny(result, "+/") {
|
||||
t.Errorf("base64URLEncode(%v) 包含非 URL 安全字符: %s", tt.input, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// hasRestrictedScope 测试
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestHasRestrictedScope(t *testing.T) {
|
||||
tests := []struct {
|
||||
scope string
|
||||
expected bool
|
||||
}{
|
||||
// 受限 scope
|
||||
{"https://www.googleapis.com/auth/generative-language", true},
|
||||
{"https://www.googleapis.com/auth/generative-language.retriever", true},
|
||||
{"https://www.googleapis.com/auth/generative-language.tuning", true},
|
||||
{"https://www.googleapis.com/auth/drive", true},
|
||||
{"https://www.googleapis.com/auth/drive.readonly", true},
|
||||
{"https://www.googleapis.com/auth/drive.file", true},
|
||||
// 非受限 scope
|
||||
{"https://www.googleapis.com/auth/cloud-platform", false},
|
||||
{"https://www.googleapis.com/auth/userinfo.email", false},
|
||||
{"https://www.googleapis.com/auth/userinfo.profile", false},
|
||||
// 边界情况
|
||||
{"", false},
|
||||
{"random-scope", false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.scope, func(t *testing.T) {
|
||||
got := hasRestrictedScope(tt.scope)
|
||||
if got != tt.expected {
|
||||
t.Errorf("hasRestrictedScope(%q) = %v,期望 %v", tt.scope, got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// BuildAuthorizationURL 测试
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestBuildAuthorizationURL(t *testing.T) {
|
||||
t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-secret")
|
||||
|
||||
authURL, err := BuildAuthorizationURL(
|
||||
OAuthConfig{},
|
||||
"test-state",
|
||||
"test-challenge",
|
||||
"https://example.com/callback",
|
||||
"",
|
||||
"code_assist",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("BuildAuthorizationURL() 出错: %v", err)
|
||||
}
|
||||
|
||||
// 检查返回的 URL 包含期望的参数
|
||||
checks := []string{
|
||||
"response_type=code",
|
||||
"client_id=" + GeminiCLIOAuthClientID,
|
||||
"redirect_uri=",
|
||||
"state=test-state",
|
||||
"code_challenge=test-challenge",
|
||||
"code_challenge_method=S256",
|
||||
"access_type=offline",
|
||||
"prompt=consent",
|
||||
"include_granted_scopes=true",
|
||||
}
|
||||
for _, check := range checks {
|
||||
if !strings.Contains(authURL, check) {
|
||||
t.Errorf("BuildAuthorizationURL() URL 缺少参数 %q\nURL: %s", check, authURL)
|
||||
}
|
||||
}
|
||||
|
||||
// 不应包含 project_id(因为传的是空字符串)
|
||||
if strings.Contains(authURL, "project_id=") {
|
||||
t.Errorf("BuildAuthorizationURL() 空 projectID 时不应包含 project_id 参数")
|
||||
}
|
||||
|
||||
// URL 应该以正确的授权端点开头
|
||||
if !strings.HasPrefix(authURL, AuthorizeURL+"?") {
|
||||
t.Errorf("BuildAuthorizationURL() URL 应以 %s? 开头,实际: %s", AuthorizeURL, authURL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAuthorizationURL_EmptyRedirectURI(t *testing.T) {
|
||||
t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-secret")
|
||||
|
||||
_, err := BuildAuthorizationURL(
|
||||
OAuthConfig{},
|
||||
"test-state",
|
||||
"test-challenge",
|
||||
"", // 空 redirectURI
|
||||
"",
|
||||
"code_assist",
|
||||
)
|
||||
if err == nil {
|
||||
t.Error("BuildAuthorizationURL() 空 redirectURI 应该报错")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "redirect_uri") {
|
||||
t.Errorf("错误消息应包含 'redirect_uri',实际: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAuthorizationURL_WithProjectID(t *testing.T) {
|
||||
t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-secret")
|
||||
|
||||
authURL, err := BuildAuthorizationURL(
|
||||
OAuthConfig{},
|
||||
"test-state",
|
||||
"test-challenge",
|
||||
"https://example.com/callback",
|
||||
"my-project-123",
|
||||
"code_assist",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("BuildAuthorizationURL() 出错: %v", err)
|
||||
}
|
||||
if !strings.Contains(authURL, "project_id=my-project-123") {
|
||||
t.Errorf("BuildAuthorizationURL() 带 projectID 时应包含 project_id 参数\nURL: %s", authURL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAuthorizationURL_OAuthConfigError(t *testing.T) {
|
||||
// 不设置环境变量,也不提供 client 凭据,EffectiveOAuthConfig 应该报错
|
||||
t.Setenv(GeminiCLIOAuthClientSecretEnv, "")
|
||||
|
||||
_, err := BuildAuthorizationURL(
|
||||
OAuthConfig{},
|
||||
"test-state",
|
||||
"test-challenge",
|
||||
"https://example.com/callback",
|
||||
"",
|
||||
"code_assist",
|
||||
)
|
||||
if err == nil {
|
||||
t.Error("当 EffectiveOAuthConfig 失败时,BuildAuthorizationURL 应该返回错误")
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// EffectiveOAuthConfig 测试 - 原有测试
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestEffectiveOAuthConfig_GoogleOne(t *testing.T) {
|
||||
// 内置的 Gemini CLI client secret 不嵌入在此仓库中。
|
||||
// 测试通过环境变量设置一个假的 secret 来模拟运维配置。
|
||||
t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-built-in-secret")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input OAuthConfig
|
||||
@@ -15,7 +443,7 @@ func TestEffectiveOAuthConfig_GoogleOne(t *testing.T) {
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "Google One with built-in client (empty config)",
|
||||
name: "Google One 使用内置客户端(空配置)",
|
||||
input: OAuthConfig{},
|
||||
oauthType: "google_one",
|
||||
wantClientID: GeminiCLIOAuthClientID,
|
||||
@@ -23,18 +451,18 @@ func TestEffectiveOAuthConfig_GoogleOne(t *testing.T) {
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Google One always uses built-in client (even if custom credentials passed)",
|
||||
name: "Google One 使用自定义客户端(传入自定义凭据时使用自定义)",
|
||||
input: OAuthConfig{
|
||||
ClientID: "custom-client-id",
|
||||
ClientSecret: "custom-client-secret",
|
||||
},
|
||||
oauthType: "google_one",
|
||||
wantClientID: "custom-client-id",
|
||||
wantScopes: DefaultCodeAssistScopes, // Uses code assist scopes even with custom client
|
||||
wantScopes: DefaultCodeAssistScopes,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Google One with built-in client and custom scopes (should filter restricted scopes)",
|
||||
name: "Google One 内置客户端 + 自定义 scopes(应过滤受限 scopes)",
|
||||
input: OAuthConfig{
|
||||
Scopes: "https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/generative-language.retriever https://www.googleapis.com/auth/drive.readonly",
|
||||
},
|
||||
@@ -44,7 +472,7 @@ func TestEffectiveOAuthConfig_GoogleOne(t *testing.T) {
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Google One with built-in client and only restricted scopes (should fallback to default)",
|
||||
name: "Google One 内置客户端 + 仅受限 scopes(应回退到默认)",
|
||||
input: OAuthConfig{
|
||||
Scopes: "https://www.googleapis.com/auth/generative-language.retriever https://www.googleapis.com/auth/drive.readonly",
|
||||
},
|
||||
@@ -54,7 +482,7 @@ func TestEffectiveOAuthConfig_GoogleOne(t *testing.T) {
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Code Assist with built-in client",
|
||||
name: "Code Assist 使用内置客户端",
|
||||
input: OAuthConfig{},
|
||||
oauthType: "code_assist",
|
||||
wantClientID: GeminiCLIOAuthClientID,
|
||||
@@ -84,7 +512,9 @@ func TestEffectiveOAuthConfig_GoogleOne(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestEffectiveOAuthConfig_ScopeFiltering(t *testing.T) {
|
||||
// Test that Google One with built-in client filters out restricted scopes
|
||||
t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-built-in-secret")
|
||||
|
||||
// 测试 Google One + 内置客户端过滤受限 scopes
|
||||
cfg, err := EffectiveOAuthConfig(OAuthConfig{
|
||||
Scopes: "https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/userinfo.email https://www.googleapis.com/auth/generative-language.retriever https://www.googleapis.com/auth/drive.readonly https://www.googleapis.com/auth/userinfo.profile",
|
||||
}, "google_one")
|
||||
@@ -93,21 +523,240 @@ func TestEffectiveOAuthConfig_ScopeFiltering(t *testing.T) {
|
||||
t.Fatalf("EffectiveOAuthConfig() error = %v", err)
|
||||
}
|
||||
|
||||
// Should only contain cloud-platform, userinfo.email, and userinfo.profile
|
||||
// Should NOT contain generative-language or drive scopes
|
||||
// 应仅包含 cloud-platform、userinfo.email 和 userinfo.profile
|
||||
// 不应包含 generative-language 或 drive scopes
|
||||
if strings.Contains(cfg.Scopes, "generative-language") {
|
||||
t.Errorf("Scopes should not contain generative-language when using built-in client, got: %v", cfg.Scopes)
|
||||
t.Errorf("使用内置客户端时 Scopes 不应包含 generative-language,实际: %v", cfg.Scopes)
|
||||
}
|
||||
if strings.Contains(cfg.Scopes, "drive") {
|
||||
t.Errorf("Scopes should not contain drive when using built-in client, got: %v", cfg.Scopes)
|
||||
t.Errorf("使用内置客户端时 Scopes 不应包含 drive,实际: %v", cfg.Scopes)
|
||||
}
|
||||
if !strings.Contains(cfg.Scopes, "cloud-platform") {
|
||||
t.Errorf("Scopes should contain cloud-platform, got: %v", cfg.Scopes)
|
||||
t.Errorf("Scopes 应包含 cloud-platform,实际: %v", cfg.Scopes)
|
||||
}
|
||||
if !strings.Contains(cfg.Scopes, "userinfo.email") {
|
||||
t.Errorf("Scopes should contain userinfo.email, got: %v", cfg.Scopes)
|
||||
t.Errorf("Scopes 应包含 userinfo.email,实际: %v", cfg.Scopes)
|
||||
}
|
||||
if !strings.Contains(cfg.Scopes, "userinfo.profile") {
|
||||
t.Errorf("Scopes should contain userinfo.profile, got: %v", cfg.Scopes)
|
||||
t.Errorf("Scopes 应包含 userinfo.profile,实际: %v", cfg.Scopes)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// EffectiveOAuthConfig 测试 - 新增分支覆盖
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestEffectiveOAuthConfig_OnlyClientID_NoSecret(t *testing.T) {
|
||||
// 只提供 clientID 不提供 secret 应报错
|
||||
_, err := EffectiveOAuthConfig(OAuthConfig{
|
||||
ClientID: "some-client-id",
|
||||
}, "code_assist")
|
||||
if err == nil {
|
||||
t.Error("只提供 ClientID 不提供 ClientSecret 应该报错")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "client_id") || !strings.Contains(err.Error(), "client_secret") {
|
||||
t.Errorf("错误消息应提及 client_id 和 client_secret,实际: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEffectiveOAuthConfig_OnlyClientSecret_NoID(t *testing.T) {
|
||||
// 只提供 secret 不提供 clientID 应报错
|
||||
_, err := EffectiveOAuthConfig(OAuthConfig{
|
||||
ClientSecret: "some-client-secret",
|
||||
}, "code_assist")
|
||||
if err == nil {
|
||||
t.Error("只提供 ClientSecret 不提供 ClientID 应该报错")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "client_id") || !strings.Contains(err.Error(), "client_secret") {
|
||||
t.Errorf("错误消息应提及 client_id 和 client_secret,实际: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEffectiveOAuthConfig_AIStudio_DefaultScopes_BuiltinClient(t *testing.T) {
|
||||
t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-built-in-secret")
|
||||
|
||||
// ai_studio 类型,使用内置客户端,scopes 为空 -> 应使用 DefaultCodeAssistScopes(因为内置客户端不能请求 generative-language scope)
|
||||
cfg, err := EffectiveOAuthConfig(OAuthConfig{}, "ai_studio")
|
||||
if err != nil {
|
||||
t.Fatalf("EffectiveOAuthConfig() error = %v", err)
|
||||
}
|
||||
if cfg.Scopes != DefaultCodeAssistScopes {
|
||||
t.Errorf("ai_studio + 内置客户端应使用 DefaultCodeAssistScopes,实际: %q", cfg.Scopes)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEffectiveOAuthConfig_AIStudio_DefaultScopes_CustomClient(t *testing.T) {
|
||||
// ai_studio 类型,使用自定义客户端,scopes 为空 -> 应使用 DefaultAIStudioScopes
|
||||
cfg, err := EffectiveOAuthConfig(OAuthConfig{
|
||||
ClientID: "custom-id",
|
||||
ClientSecret: "custom-secret",
|
||||
}, "ai_studio")
|
||||
if err != nil {
|
||||
t.Fatalf("EffectiveOAuthConfig() error = %v", err)
|
||||
}
|
||||
if cfg.Scopes != DefaultAIStudioScopes {
|
||||
t.Errorf("ai_studio + 自定义客户端应使用 DefaultAIStudioScopes,实际: %q", cfg.Scopes)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEffectiveOAuthConfig_AIStudio_ScopeNormalization(t *testing.T) {
|
||||
// ai_studio 类型,旧的 generative-language scope 应被归一化为 generative-language.retriever
|
||||
cfg, err := EffectiveOAuthConfig(OAuthConfig{
|
||||
ClientID: "custom-id",
|
||||
ClientSecret: "custom-secret",
|
||||
Scopes: "https://www.googleapis.com/auth/generative-language https://www.googleapis.com/auth/cloud-platform",
|
||||
}, "ai_studio")
|
||||
if err != nil {
|
||||
t.Fatalf("EffectiveOAuthConfig() error = %v", err)
|
||||
}
|
||||
if strings.Contains(cfg.Scopes, "auth/generative-language ") || strings.HasSuffix(cfg.Scopes, "auth/generative-language") {
|
||||
// 确保不包含未归一化的旧 scope(仅 generative-language 而非 generative-language.retriever)
|
||||
parts := strings.Fields(cfg.Scopes)
|
||||
for _, p := range parts {
|
||||
if p == "https://www.googleapis.com/auth/generative-language" {
|
||||
t.Errorf("ai_studio 应将 generative-language 归一化为 generative-language.retriever,实际 scopes: %q", cfg.Scopes)
|
||||
}
|
||||
}
|
||||
}
|
||||
if !strings.Contains(cfg.Scopes, "generative-language.retriever") {
|
||||
t.Errorf("ai_studio 归一化后应包含 generative-language.retriever,实际: %q", cfg.Scopes)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEffectiveOAuthConfig_CommaSeparatedScopes(t *testing.T) {
|
||||
t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-built-in-secret")
|
||||
|
||||
// 逗号分隔的 scopes 应被归一化为空格分隔
|
||||
cfg, err := EffectiveOAuthConfig(OAuthConfig{
|
||||
ClientID: "custom-id",
|
||||
ClientSecret: "custom-secret",
|
||||
Scopes: "https://www.googleapis.com/auth/cloud-platform,https://www.googleapis.com/auth/userinfo.email",
|
||||
}, "code_assist")
|
||||
if err != nil {
|
||||
t.Fatalf("EffectiveOAuthConfig() error = %v", err)
|
||||
}
|
||||
// 应该用空格分隔,而非逗号
|
||||
if strings.Contains(cfg.Scopes, ",") {
|
||||
t.Errorf("逗号分隔的 scopes 应被归一化为空格分隔,实际: %q", cfg.Scopes)
|
||||
}
|
||||
if !strings.Contains(cfg.Scopes, "cloud-platform") {
|
||||
t.Errorf("归一化后应包含 cloud-platform,实际: %q", cfg.Scopes)
|
||||
}
|
||||
if !strings.Contains(cfg.Scopes, "userinfo.email") {
|
||||
t.Errorf("归一化后应包含 userinfo.email,实际: %q", cfg.Scopes)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEffectiveOAuthConfig_MixedCommaAndSpaceScopes(t *testing.T) {
|
||||
// 混合逗号和空格分隔的 scopes
|
||||
cfg, err := EffectiveOAuthConfig(OAuthConfig{
|
||||
ClientID: "custom-id",
|
||||
ClientSecret: "custom-secret",
|
||||
Scopes: "https://www.googleapis.com/auth/cloud-platform, https://www.googleapis.com/auth/userinfo.email https://www.googleapis.com/auth/userinfo.profile",
|
||||
}, "code_assist")
|
||||
if err != nil {
|
||||
t.Fatalf("EffectiveOAuthConfig() error = %v", err)
|
||||
}
|
||||
parts := strings.Fields(cfg.Scopes)
|
||||
if len(parts) != 3 {
|
||||
t.Errorf("归一化后应有 3 个 scope,实际: %d,scopes: %q", len(parts), cfg.Scopes)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEffectiveOAuthConfig_WhitespaceTriming(t *testing.T) {
|
||||
// 输入中的前后空白应被清理
|
||||
cfg, err := EffectiveOAuthConfig(OAuthConfig{
|
||||
ClientID: " custom-id ",
|
||||
ClientSecret: " custom-secret ",
|
||||
Scopes: " https://www.googleapis.com/auth/cloud-platform ",
|
||||
}, "code_assist")
|
||||
if err != nil {
|
||||
t.Fatalf("EffectiveOAuthConfig() error = %v", err)
|
||||
}
|
||||
if cfg.ClientID != "custom-id" {
|
||||
t.Errorf("ClientID 应去除前后空白,实际: %q", cfg.ClientID)
|
||||
}
|
||||
if cfg.ClientSecret != "custom-secret" {
|
||||
t.Errorf("ClientSecret 应去除前后空白,实际: %q", cfg.ClientSecret)
|
||||
}
|
||||
if cfg.Scopes != "https://www.googleapis.com/auth/cloud-platform" {
|
||||
t.Errorf("Scopes 应去除前后空白,实际: %q", cfg.Scopes)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEffectiveOAuthConfig_NoEnvSecret(t *testing.T) {
|
||||
// 不设置环境变量且不提供凭据,应该报错
|
||||
t.Setenv(GeminiCLIOAuthClientSecretEnv, "")
|
||||
|
||||
_, err := EffectiveOAuthConfig(OAuthConfig{}, "code_assist")
|
||||
if err == nil {
|
||||
t.Error("没有内置 secret 且未提供凭据时应该报错")
|
||||
}
|
||||
if !strings.Contains(err.Error(), GeminiCLIOAuthClientSecretEnv) {
|
||||
t.Errorf("错误消息应提及环境变量 %s,实际: %v", GeminiCLIOAuthClientSecretEnv, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEffectiveOAuthConfig_AIStudio_BuiltinClient_CustomScopes(t *testing.T) {
|
||||
t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-built-in-secret")
|
||||
|
||||
// ai_studio + 内置客户端 + 自定义 scopes -> 应过滤受限 scopes
|
||||
cfg, err := EffectiveOAuthConfig(OAuthConfig{
|
||||
Scopes: "https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/generative-language.retriever",
|
||||
}, "ai_studio")
|
||||
if err != nil {
|
||||
t.Fatalf("EffectiveOAuthConfig() error = %v", err)
|
||||
}
|
||||
// 内置客户端应过滤 generative-language.retriever
|
||||
if strings.Contains(cfg.Scopes, "generative-language") {
|
||||
t.Errorf("ai_studio + 内置客户端应过滤受限 scopes,实际: %q", cfg.Scopes)
|
||||
}
|
||||
if !strings.Contains(cfg.Scopes, "cloud-platform") {
|
||||
t.Errorf("应保留 cloud-platform scope,实际: %q", cfg.Scopes)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEffectiveOAuthConfig_UnknownOAuthType_DefaultScopes(t *testing.T) {
|
||||
t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-built-in-secret")
|
||||
|
||||
// 未知的 oauthType 应回退到默认的 code_assist scopes
|
||||
cfg, err := EffectiveOAuthConfig(OAuthConfig{}, "unknown_type")
|
||||
if err != nil {
|
||||
t.Fatalf("EffectiveOAuthConfig() error = %v", err)
|
||||
}
|
||||
if cfg.Scopes != DefaultCodeAssistScopes {
|
||||
t.Errorf("未知 oauthType 应使用 DefaultCodeAssistScopes,实际: %q", cfg.Scopes)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEffectiveOAuthConfig_EmptyOAuthType_DefaultScopes(t *testing.T) {
|
||||
t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-built-in-secret")
|
||||
|
||||
// 空的 oauthType 应走 default 分支,使用 DefaultCodeAssistScopes
|
||||
cfg, err := EffectiveOAuthConfig(OAuthConfig{}, "")
|
||||
if err != nil {
|
||||
t.Fatalf("EffectiveOAuthConfig() error = %v", err)
|
||||
}
|
||||
if cfg.Scopes != DefaultCodeAssistScopes {
|
||||
t.Errorf("空 oauthType 应使用 DefaultCodeAssistScopes,实际: %q", cfg.Scopes)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEffectiveOAuthConfig_CustomClient_NoScopeFiltering(t *testing.T) {
|
||||
// 自定义客户端 + google_one + 包含受限 scopes -> 不应被过滤(因为不是内置客户端)
|
||||
cfg, err := EffectiveOAuthConfig(OAuthConfig{
|
||||
ClientID: "custom-id",
|
||||
ClientSecret: "custom-secret",
|
||||
Scopes: "https://www.googleapis.com/auth/generative-language.retriever https://www.googleapis.com/auth/drive.readonly",
|
||||
}, "google_one")
|
||||
if err != nil {
|
||||
t.Fatalf("EffectiveOAuthConfig() error = %v", err)
|
||||
}
|
||||
// 自定义客户端不应过滤任何 scope
|
||||
if !strings.Contains(cfg.Scopes, "generative-language.retriever") {
|
||||
t.Errorf("自定义客户端不应过滤 generative-language.retriever,实际: %q", cfg.Scopes)
|
||||
}
|
||||
if !strings.Contains(cfg.Scopes, "drive.readonly") {
|
||||
t.Errorf("自定义客户端不应过滤 drive.readonly,实际: %q", cfg.Scopes)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,6 +14,44 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// ---------- 辅助函数 ----------
|
||||
|
||||
// parseResponseBody 从 httptest.ResponseRecorder 中解析 JSON 响应体
|
||||
func parseResponseBody(t *testing.T, w *httptest.ResponseRecorder) Response {
|
||||
t.Helper()
|
||||
var got Response
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &got))
|
||||
return got
|
||||
}
|
||||
|
||||
// parsePaginatedBody 从响应体中解析分页数据(Data 字段是 PaginatedData)
|
||||
func parsePaginatedBody(t *testing.T, w *httptest.ResponseRecorder) (Response, PaginatedData) {
|
||||
t.Helper()
|
||||
// 先用 raw json 解析,因为 Data 是 any 类型
|
||||
var raw struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Reason string `json:"reason,omitempty"`
|
||||
Data json.RawMessage `json:"data,omitempty"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &raw))
|
||||
|
||||
var pd PaginatedData
|
||||
require.NoError(t, json.Unmarshal(raw.Data, &pd))
|
||||
|
||||
return Response{Code: raw.Code, Message: raw.Message, Reason: raw.Reason}, pd
|
||||
}
|
||||
|
||||
// newContextWithQuery 创建一个带有 URL query 参数的 gin.Context 用于测试 ParsePagination
|
||||
func newContextWithQuery(query string) (*httptest.ResponseRecorder, *gin.Context) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/?"+query, nil)
|
||||
return w, c
|
||||
}
|
||||
|
||||
// ---------- 现有测试 ----------
|
||||
|
||||
func TestErrorWithDetails(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
@@ -169,3 +207,582 @@ func TestErrorFrom(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- 新增测试 ----------
|
||||
|
||||
func TestSuccess(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
data any
|
||||
wantCode int
|
||||
wantBody Response
|
||||
}{
|
||||
{
|
||||
name: "返回字符串数据",
|
||||
data: "hello",
|
||||
wantCode: http.StatusOK,
|
||||
wantBody: Response{Code: 0, Message: "success", Data: "hello"},
|
||||
},
|
||||
{
|
||||
name: "返回nil数据",
|
||||
data: nil,
|
||||
wantCode: http.StatusOK,
|
||||
wantBody: Response{Code: 0, Message: "success"},
|
||||
},
|
||||
{
|
||||
name: "返回map数据",
|
||||
data: map[string]string{"key": "value"},
|
||||
wantCode: http.StatusOK,
|
||||
wantBody: Response{Code: 0, Message: "success"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
Success(c, tt.data)
|
||||
|
||||
require.Equal(t, tt.wantCode, w.Code)
|
||||
|
||||
// 只验证 code 和 message,data 字段类型在 JSON 反序列化时会变成 map/slice
|
||||
got := parseResponseBody(t, w)
|
||||
require.Equal(t, 0, got.Code)
|
||||
require.Equal(t, "success", got.Message)
|
||||
|
||||
if tt.data == nil {
|
||||
require.Nil(t, got.Data)
|
||||
} else {
|
||||
require.NotNil(t, got.Data)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreated(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
data any
|
||||
wantCode int
|
||||
}{
|
||||
{
|
||||
name: "创建成功_返回数据",
|
||||
data: map[string]int{"id": 42},
|
||||
wantCode: http.StatusCreated,
|
||||
},
|
||||
{
|
||||
name: "创建成功_nil数据",
|
||||
data: nil,
|
||||
wantCode: http.StatusCreated,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
Created(c, tt.data)
|
||||
|
||||
require.Equal(t, tt.wantCode, w.Code)
|
||||
|
||||
got := parseResponseBody(t, w)
|
||||
require.Equal(t, 0, got.Code)
|
||||
require.Equal(t, "success", got.Message)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestError(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
statusCode int
|
||||
message string
|
||||
}{
|
||||
{
|
||||
name: "400错误",
|
||||
statusCode: http.StatusBadRequest,
|
||||
message: "bad request",
|
||||
},
|
||||
{
|
||||
name: "500错误",
|
||||
statusCode: http.StatusInternalServerError,
|
||||
message: "internal error",
|
||||
},
|
||||
{
|
||||
name: "自定义状态码",
|
||||
statusCode: 418,
|
||||
message: "I'm a teapot",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
Error(c, tt.statusCode, tt.message)
|
||||
|
||||
require.Equal(t, tt.statusCode, w.Code)
|
||||
|
||||
got := parseResponseBody(t, w)
|
||||
require.Equal(t, tt.statusCode, got.Code)
|
||||
require.Equal(t, tt.message, got.Message)
|
||||
require.Empty(t, got.Reason)
|
||||
require.Nil(t, got.Metadata)
|
||||
require.Nil(t, got.Data)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBadRequest(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
BadRequest(c, "参数无效")
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, w.Code)
|
||||
got := parseResponseBody(t, w)
|
||||
require.Equal(t, http.StatusBadRequest, got.Code)
|
||||
require.Equal(t, "参数无效", got.Message)
|
||||
}
|
||||
|
||||
func TestUnauthorized(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
Unauthorized(c, "未登录")
|
||||
|
||||
require.Equal(t, http.StatusUnauthorized, w.Code)
|
||||
got := parseResponseBody(t, w)
|
||||
require.Equal(t, http.StatusUnauthorized, got.Code)
|
||||
require.Equal(t, "未登录", got.Message)
|
||||
}
|
||||
|
||||
func TestForbidden(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
Forbidden(c, "无权限")
|
||||
|
||||
require.Equal(t, http.StatusForbidden, w.Code)
|
||||
got := parseResponseBody(t, w)
|
||||
require.Equal(t, http.StatusForbidden, got.Code)
|
||||
require.Equal(t, "无权限", got.Message)
|
||||
}
|
||||
|
||||
func TestNotFound(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
NotFound(c, "资源不存在")
|
||||
|
||||
require.Equal(t, http.StatusNotFound, w.Code)
|
||||
got := parseResponseBody(t, w)
|
||||
require.Equal(t, http.StatusNotFound, got.Code)
|
||||
require.Equal(t, "资源不存在", got.Message)
|
||||
}
|
||||
|
||||
func TestInternalError(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
InternalError(c, "服务器内部错误")
|
||||
|
||||
require.Equal(t, http.StatusInternalServerError, w.Code)
|
||||
got := parseResponseBody(t, w)
|
||||
require.Equal(t, http.StatusInternalServerError, got.Code)
|
||||
require.Equal(t, "服务器内部错误", got.Message)
|
||||
}
|
||||
|
||||
func TestPaginated(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
items any
|
||||
total int64
|
||||
page int
|
||||
pageSize int
|
||||
wantPages int
|
||||
wantTotal int64
|
||||
wantPage int
|
||||
wantPageSize int
|
||||
}{
|
||||
{
|
||||
name: "标准分页_多页",
|
||||
items: []string{"a", "b"},
|
||||
total: 25,
|
||||
page: 1,
|
||||
pageSize: 10,
|
||||
wantPages: 3,
|
||||
wantTotal: 25,
|
||||
wantPage: 1,
|
||||
wantPageSize: 10,
|
||||
},
|
||||
{
|
||||
name: "总数刚好整除",
|
||||
items: []string{"a"},
|
||||
total: 20,
|
||||
page: 2,
|
||||
pageSize: 10,
|
||||
wantPages: 2,
|
||||
wantTotal: 20,
|
||||
wantPage: 2,
|
||||
wantPageSize: 10,
|
||||
},
|
||||
{
|
||||
name: "总数为0_pages至少为1",
|
||||
items: []string{},
|
||||
total: 0,
|
||||
page: 1,
|
||||
pageSize: 10,
|
||||
wantPages: 1,
|
||||
wantTotal: 0,
|
||||
wantPage: 1,
|
||||
wantPageSize: 10,
|
||||
},
|
||||
{
|
||||
name: "单页数据",
|
||||
items: []int{1, 2, 3},
|
||||
total: 3,
|
||||
page: 1,
|
||||
pageSize: 20,
|
||||
wantPages: 1,
|
||||
wantTotal: 3,
|
||||
wantPage: 1,
|
||||
wantPageSize: 20,
|
||||
},
|
||||
{
|
||||
name: "总数为1",
|
||||
items: []string{"only"},
|
||||
total: 1,
|
||||
page: 1,
|
||||
pageSize: 10,
|
||||
wantPages: 1,
|
||||
wantTotal: 1,
|
||||
wantPage: 1,
|
||||
wantPageSize: 10,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
Paginated(c, tt.items, tt.total, tt.page, tt.pageSize)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
resp, pd := parsePaginatedBody(t, w)
|
||||
require.Equal(t, 0, resp.Code)
|
||||
require.Equal(t, "success", resp.Message)
|
||||
require.Equal(t, tt.wantTotal, pd.Total)
|
||||
require.Equal(t, tt.wantPage, pd.Page)
|
||||
require.Equal(t, tt.wantPageSize, pd.PageSize)
|
||||
require.Equal(t, tt.wantPages, pd.Pages)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPaginatedWithResult(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
items any
|
||||
pagination *PaginationResult
|
||||
wantTotal int64
|
||||
wantPage int
|
||||
wantPageSize int
|
||||
wantPages int
|
||||
}{
|
||||
{
|
||||
name: "正常分页结果",
|
||||
items: []string{"a", "b"},
|
||||
pagination: &PaginationResult{
|
||||
Total: 50,
|
||||
Page: 3,
|
||||
PageSize: 10,
|
||||
Pages: 5,
|
||||
},
|
||||
wantTotal: 50,
|
||||
wantPage: 3,
|
||||
wantPageSize: 10,
|
||||
wantPages: 5,
|
||||
},
|
||||
{
|
||||
name: "pagination为nil_使用默认值",
|
||||
items: []string{},
|
||||
pagination: nil,
|
||||
wantTotal: 0,
|
||||
wantPage: 1,
|
||||
wantPageSize: 20,
|
||||
wantPages: 1,
|
||||
},
|
||||
{
|
||||
name: "单页结果",
|
||||
items: []int{1},
|
||||
pagination: &PaginationResult{
|
||||
Total: 1,
|
||||
Page: 1,
|
||||
PageSize: 20,
|
||||
Pages: 1,
|
||||
},
|
||||
wantTotal: 1,
|
||||
wantPage: 1,
|
||||
wantPageSize: 20,
|
||||
wantPages: 1,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
PaginatedWithResult(c, tt.items, tt.pagination)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
resp, pd := parsePaginatedBody(t, w)
|
||||
require.Equal(t, 0, resp.Code)
|
||||
require.Equal(t, "success", resp.Message)
|
||||
require.Equal(t, tt.wantTotal, pd.Total)
|
||||
require.Equal(t, tt.wantPage, pd.Page)
|
||||
require.Equal(t, tt.wantPageSize, pd.PageSize)
|
||||
require.Equal(t, tt.wantPages, pd.Pages)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParsePagination(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
query string
|
||||
wantPage int
|
||||
wantPageSize int
|
||||
}{
|
||||
{
|
||||
name: "无参数_使用默认值",
|
||||
query: "",
|
||||
wantPage: 1,
|
||||
wantPageSize: 20,
|
||||
},
|
||||
{
|
||||
name: "仅指定page",
|
||||
query: "page=3",
|
||||
wantPage: 3,
|
||||
wantPageSize: 20,
|
||||
},
|
||||
{
|
||||
name: "仅指定page_size",
|
||||
query: "page_size=50",
|
||||
wantPage: 1,
|
||||
wantPageSize: 50,
|
||||
},
|
||||
{
|
||||
name: "同时指定page和page_size",
|
||||
query: "page=2&page_size=30",
|
||||
wantPage: 2,
|
||||
wantPageSize: 30,
|
||||
},
|
||||
{
|
||||
name: "使用limit代替page_size",
|
||||
query: "limit=15",
|
||||
wantPage: 1,
|
||||
wantPageSize: 15,
|
||||
},
|
||||
{
|
||||
name: "page_size优先于limit",
|
||||
query: "page_size=25&limit=50",
|
||||
wantPage: 1,
|
||||
wantPageSize: 25,
|
||||
},
|
||||
{
|
||||
name: "page为0_使用默认值",
|
||||
query: "page=0",
|
||||
wantPage: 1,
|
||||
wantPageSize: 20,
|
||||
},
|
||||
{
|
||||
name: "page_size超过1000_使用默认值",
|
||||
query: "page_size=1001",
|
||||
wantPage: 1,
|
||||
wantPageSize: 20,
|
||||
},
|
||||
{
|
||||
name: "page_size恰好1000_有效",
|
||||
query: "page_size=1000",
|
||||
wantPage: 1,
|
||||
wantPageSize: 1000,
|
||||
},
|
||||
{
|
||||
name: "page为非数字_使用默认值",
|
||||
query: "page=abc",
|
||||
wantPage: 1,
|
||||
wantPageSize: 20,
|
||||
},
|
||||
{
|
||||
name: "page_size为非数字_使用默认值",
|
||||
query: "page_size=xyz",
|
||||
wantPage: 1,
|
||||
wantPageSize: 20,
|
||||
},
|
||||
{
|
||||
name: "limit为非数字_使用默认值",
|
||||
query: "limit=abc",
|
||||
wantPage: 1,
|
||||
wantPageSize: 20,
|
||||
},
|
||||
{
|
||||
name: "page_size为0_使用默认值",
|
||||
query: "page_size=0",
|
||||
wantPage: 1,
|
||||
wantPageSize: 20,
|
||||
},
|
||||
{
|
||||
name: "limit为0_使用默认值",
|
||||
query: "limit=0",
|
||||
wantPage: 1,
|
||||
wantPageSize: 20,
|
||||
},
|
||||
{
|
||||
name: "大页码",
|
||||
query: "page=999&page_size=100",
|
||||
wantPage: 999,
|
||||
wantPageSize: 100,
|
||||
},
|
||||
{
|
||||
name: "page_size为1_最小有效值",
|
||||
query: "page_size=1",
|
||||
wantPage: 1,
|
||||
wantPageSize: 1,
|
||||
},
|
||||
{
|
||||
name: "混合数字和字母的page",
|
||||
query: "page=12a",
|
||||
wantPage: 1,
|
||||
wantPageSize: 20,
|
||||
},
|
||||
{
|
||||
name: "limit超过1000_使用默认值",
|
||||
query: "limit=2000",
|
||||
wantPage: 1,
|
||||
wantPageSize: 20,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, c := newContextWithQuery(tt.query)
|
||||
|
||||
page, pageSize := ParsePagination(c)
|
||||
|
||||
require.Equal(t, tt.wantPage, page, "page 不符合预期")
|
||||
require.Equal(t, tt.wantPageSize, pageSize, "pageSize 不符合预期")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_parseInt(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantVal int
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "正常数字",
|
||||
input: "123",
|
||||
wantVal: 123,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "零",
|
||||
input: "0",
|
||||
wantVal: 0,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "单个数字",
|
||||
input: "5",
|
||||
wantVal: 5,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "大数字",
|
||||
input: "99999",
|
||||
wantVal: 99999,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "包含字母_返回0",
|
||||
input: "abc",
|
||||
wantVal: 0,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "数字开头接字母_返回0",
|
||||
input: "12a",
|
||||
wantVal: 0,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "包含负号_返回0",
|
||||
input: "-1",
|
||||
wantVal: 0,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "包含小数点_返回0",
|
||||
input: "1.5",
|
||||
wantVal: 0,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "包含空格_返回0",
|
||||
input: "1 2",
|
||||
wantVal: 0,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "空字符串",
|
||||
input: "",
|
||||
wantVal: 0,
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
val, err := parseInt(tt.input)
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
require.Equal(t, tt.wantVal, val)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
607
backend/internal/service/oauth_service_test.go
Normal file
607
backend/internal/service/oauth_service_test.go
Normal file
@@ -0,0 +1,607 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
)
|
||||
|
||||
// --- mock: ClaudeOAuthClient ---
|
||||
|
||||
type mockClaudeOAuthClient struct {
|
||||
getOrgUUIDFunc func(ctx context.Context, sessionKey, proxyURL string) (string, error)
|
||||
getAuthCodeFunc func(ctx context.Context, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL string) (string, error)
|
||||
exchangeCodeFunc func(ctx context.Context, code, codeVerifier, state, proxyURL string, isSetupToken bool) (*oauth.TokenResponse, error)
|
||||
refreshTokenFunc func(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error)
|
||||
}
|
||||
|
||||
func (m *mockClaudeOAuthClient) GetOrganizationUUID(ctx context.Context, sessionKey, proxyURL string) (string, error) {
|
||||
if m.getOrgUUIDFunc != nil {
|
||||
return m.getOrgUUIDFunc(ctx, sessionKey, proxyURL)
|
||||
}
|
||||
panic("GetOrganizationUUID not implemented")
|
||||
}
|
||||
|
||||
func (m *mockClaudeOAuthClient) GetAuthorizationCode(ctx context.Context, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL string) (string, error) {
|
||||
if m.getAuthCodeFunc != nil {
|
||||
return m.getAuthCodeFunc(ctx, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL)
|
||||
}
|
||||
panic("GetAuthorizationCode not implemented")
|
||||
}
|
||||
|
||||
func (m *mockClaudeOAuthClient) ExchangeCodeForToken(ctx context.Context, code, codeVerifier, state, proxyURL string, isSetupToken bool) (*oauth.TokenResponse, error) {
|
||||
if m.exchangeCodeFunc != nil {
|
||||
return m.exchangeCodeFunc(ctx, code, codeVerifier, state, proxyURL, isSetupToken)
|
||||
}
|
||||
panic("ExchangeCodeForToken not implemented")
|
||||
}
|
||||
|
||||
func (m *mockClaudeOAuthClient) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) {
|
||||
if m.refreshTokenFunc != nil {
|
||||
return m.refreshTokenFunc(ctx, refreshToken, proxyURL)
|
||||
}
|
||||
panic("RefreshToken not implemented")
|
||||
}
|
||||
|
||||
// --- mock: ProxyRepository (最小实现,仅覆盖 OAuthService 依赖的方法) ---
|
||||
|
||||
type mockProxyRepoForOAuth struct {
|
||||
getByIDFunc func(ctx context.Context, id int64) (*Proxy, error)
|
||||
}
|
||||
|
||||
func (m *mockProxyRepoForOAuth) Create(ctx context.Context, proxy *Proxy) error {
|
||||
panic("Create not implemented")
|
||||
}
|
||||
func (m *mockProxyRepoForOAuth) GetByID(ctx context.Context, id int64) (*Proxy, error) {
|
||||
if m.getByIDFunc != nil {
|
||||
return m.getByIDFunc(ctx, id)
|
||||
}
|
||||
return nil, fmt.Errorf("proxy not found")
|
||||
}
|
||||
func (m *mockProxyRepoForOAuth) ListByIDs(ctx context.Context, ids []int64) ([]Proxy, error) {
|
||||
panic("ListByIDs not implemented")
|
||||
}
|
||||
func (m *mockProxyRepoForOAuth) Update(ctx context.Context, proxy *Proxy) error {
|
||||
panic("Update not implemented")
|
||||
}
|
||||
func (m *mockProxyRepoForOAuth) Delete(ctx context.Context, id int64) error {
|
||||
panic("Delete not implemented")
|
||||
}
|
||||
func (m *mockProxyRepoForOAuth) List(ctx context.Context, params pagination.PaginationParams) ([]Proxy, *pagination.PaginationResult, error) {
|
||||
panic("List not implemented")
|
||||
}
|
||||
func (m *mockProxyRepoForOAuth) ListWithFilters(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]Proxy, *pagination.PaginationResult, error) {
|
||||
panic("ListWithFilters not implemented")
|
||||
}
|
||||
func (m *mockProxyRepoForOAuth) ListWithFiltersAndAccountCount(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]ProxyWithAccountCount, *pagination.PaginationResult, error) {
|
||||
panic("ListWithFiltersAndAccountCount not implemented")
|
||||
}
|
||||
func (m *mockProxyRepoForOAuth) ListActive(ctx context.Context) ([]Proxy, error) {
|
||||
panic("ListActive not implemented")
|
||||
}
|
||||
func (m *mockProxyRepoForOAuth) ListActiveWithAccountCount(ctx context.Context) ([]ProxyWithAccountCount, error) {
|
||||
panic("ListActiveWithAccountCount not implemented")
|
||||
}
|
||||
func (m *mockProxyRepoForOAuth) ExistsByHostPortAuth(ctx context.Context, host string, port int, username, password string) (bool, error) {
|
||||
panic("ExistsByHostPortAuth not implemented")
|
||||
}
|
||||
func (m *mockProxyRepoForOAuth) CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error) {
|
||||
panic("CountAccountsByProxyID not implemented")
|
||||
}
|
||||
func (m *mockProxyRepoForOAuth) ListAccountSummariesByProxyID(ctx context.Context, proxyID int64) ([]ProxyAccountSummary, error) {
|
||||
panic("ListAccountSummariesByProxyID not implemented")
|
||||
}
|
||||
|
||||
// =====================
|
||||
// 测试用例
|
||||
// =====================
|
||||
|
||||
func TestNewOAuthService(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
proxyRepo := &mockProxyRepoForOAuth{}
|
||||
client := &mockClaudeOAuthClient{}
|
||||
svc := NewOAuthService(proxyRepo, client)
|
||||
|
||||
if svc == nil {
|
||||
t.Fatal("NewOAuthService 返回 nil")
|
||||
}
|
||||
if svc.proxyRepo != proxyRepo {
|
||||
t.Fatal("proxyRepo 未正确设置")
|
||||
}
|
||||
if svc.oauthClient != client {
|
||||
t.Fatal("oauthClient 未正确设置")
|
||||
}
|
||||
if svc.sessionStore == nil {
|
||||
t.Fatal("sessionStore 应被自动初始化")
|
||||
}
|
||||
|
||||
// 清理
|
||||
svc.Stop()
|
||||
}
|
||||
|
||||
func TestOAuthService_GenerateAuthURL(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
svc := NewOAuthService(&mockProxyRepoForOAuth{}, &mockClaudeOAuthClient{})
|
||||
defer svc.Stop()
|
||||
|
||||
result, err := svc.GenerateAuthURL(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateAuthURL 返回错误: %v", err)
|
||||
}
|
||||
if result == nil {
|
||||
t.Fatal("GenerateAuthURL 返回 nil")
|
||||
}
|
||||
if result.AuthURL == "" {
|
||||
t.Fatal("AuthURL 为空")
|
||||
}
|
||||
if result.SessionID == "" {
|
||||
t.Fatal("SessionID 为空")
|
||||
}
|
||||
|
||||
// 验证 session 已存储
|
||||
session, ok := svc.sessionStore.Get(result.SessionID)
|
||||
if !ok {
|
||||
t.Fatal("session 未在 sessionStore 中找到")
|
||||
}
|
||||
if session.Scope != oauth.ScopeOAuth {
|
||||
t.Fatalf("scope 不匹配: got=%q want=%q", session.Scope, oauth.ScopeOAuth)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuthService_GenerateAuthURL_WithProxy(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
proxyRepo := &mockProxyRepoForOAuth{
|
||||
getByIDFunc: func(ctx context.Context, id int64) (*Proxy, error) {
|
||||
return &Proxy{
|
||||
ID: 1,
|
||||
Protocol: "http",
|
||||
Host: "proxy.example.com",
|
||||
Port: 8080,
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
svc := NewOAuthService(proxyRepo, &mockClaudeOAuthClient{})
|
||||
defer svc.Stop()
|
||||
|
||||
proxyID := int64(1)
|
||||
result, err := svc.GenerateAuthURL(context.Background(), &proxyID)
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateAuthURL 返回错误: %v", err)
|
||||
}
|
||||
|
||||
session, ok := svc.sessionStore.Get(result.SessionID)
|
||||
if !ok {
|
||||
t.Fatal("session 未在 sessionStore 中找到")
|
||||
}
|
||||
if session.ProxyURL != "http://proxy.example.com:8080" {
|
||||
t.Fatalf("ProxyURL 不匹配: got=%q", session.ProxyURL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuthService_GenerateSetupTokenURL(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
svc := NewOAuthService(&mockProxyRepoForOAuth{}, &mockClaudeOAuthClient{})
|
||||
defer svc.Stop()
|
||||
|
||||
result, err := svc.GenerateSetupTokenURL(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateSetupTokenURL 返回错误: %v", err)
|
||||
}
|
||||
if result == nil {
|
||||
t.Fatal("GenerateSetupTokenURL 返回 nil")
|
||||
}
|
||||
|
||||
// 验证 scope 是 inference
|
||||
session, ok := svc.sessionStore.Get(result.SessionID)
|
||||
if !ok {
|
||||
t.Fatal("session 未在 sessionStore 中找到")
|
||||
}
|
||||
if session.Scope != oauth.ScopeInference {
|
||||
t.Fatalf("scope 不匹配: got=%q want=%q", session.Scope, oauth.ScopeInference)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuthService_ExchangeCode_SessionNotFound(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
svc := NewOAuthService(&mockProxyRepoForOAuth{}, &mockClaudeOAuthClient{})
|
||||
defer svc.Stop()
|
||||
|
||||
_, err := svc.ExchangeCode(context.Background(), &ExchangeCodeInput{
|
||||
SessionID: "nonexistent-session",
|
||||
Code: "test-code",
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("ExchangeCode 应返回错误(session 不存在)")
|
||||
}
|
||||
if err.Error() != "session not found or expired" {
|
||||
t.Fatalf("错误信息不匹配: got=%q", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuthService_ExchangeCode_Success(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
exchangeCalled := false
|
||||
client := &mockClaudeOAuthClient{
|
||||
exchangeCodeFunc: func(ctx context.Context, code, codeVerifier, state, proxyURL string, isSetupToken bool) (*oauth.TokenResponse, error) {
|
||||
exchangeCalled = true
|
||||
if code != "auth-code-123" {
|
||||
t.Errorf("code 不匹配: got=%q", code)
|
||||
}
|
||||
if isSetupToken {
|
||||
t.Error("isSetupToken 应为 false(ScopeOAuth)")
|
||||
}
|
||||
return &oauth.TokenResponse{
|
||||
AccessToken: "access-token-abc",
|
||||
TokenType: "Bearer",
|
||||
ExpiresIn: 3600,
|
||||
RefreshToken: "refresh-token-xyz",
|
||||
Scope: oauth.ScopeOAuth,
|
||||
Organization: &oauth.OrgInfo{UUID: "org-uuid-111"},
|
||||
Account: &oauth.AccountInfo{UUID: "acc-uuid-222", EmailAddress: "test@example.com"},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
svc := NewOAuthService(&mockProxyRepoForOAuth{}, client)
|
||||
defer svc.Stop()
|
||||
|
||||
// 先生成 URL 以创建 session
|
||||
result, err := svc.GenerateAuthURL(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateAuthURL 返回错误: %v", err)
|
||||
}
|
||||
|
||||
// 交换 code
|
||||
tokenInfo, err := svc.ExchangeCode(context.Background(), &ExchangeCodeInput{
|
||||
SessionID: result.SessionID,
|
||||
Code: "auth-code-123",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("ExchangeCode 返回错误: %v", err)
|
||||
}
|
||||
|
||||
if !exchangeCalled {
|
||||
t.Fatal("ExchangeCodeForToken 未被调用")
|
||||
}
|
||||
if tokenInfo.AccessToken != "access-token-abc" {
|
||||
t.Fatalf("AccessToken 不匹配: got=%q", tokenInfo.AccessToken)
|
||||
}
|
||||
if tokenInfo.TokenType != "Bearer" {
|
||||
t.Fatalf("TokenType 不匹配: got=%q", tokenInfo.TokenType)
|
||||
}
|
||||
if tokenInfo.RefreshToken != "refresh-token-xyz" {
|
||||
t.Fatalf("RefreshToken 不匹配: got=%q", tokenInfo.RefreshToken)
|
||||
}
|
||||
if tokenInfo.OrgUUID != "org-uuid-111" {
|
||||
t.Fatalf("OrgUUID 不匹配: got=%q", tokenInfo.OrgUUID)
|
||||
}
|
||||
if tokenInfo.AccountUUID != "acc-uuid-222" {
|
||||
t.Fatalf("AccountUUID 不匹配: got=%q", tokenInfo.AccountUUID)
|
||||
}
|
||||
if tokenInfo.EmailAddress != "test@example.com" {
|
||||
t.Fatalf("EmailAddress 不匹配: got=%q", tokenInfo.EmailAddress)
|
||||
}
|
||||
if tokenInfo.ExpiresIn != 3600 {
|
||||
t.Fatalf("ExpiresIn 不匹配: got=%d", tokenInfo.ExpiresIn)
|
||||
}
|
||||
if tokenInfo.ExpiresAt == 0 {
|
||||
t.Fatal("ExpiresAt 不应为 0")
|
||||
}
|
||||
|
||||
// 验证 session 已被删除
|
||||
_, ok := svc.sessionStore.Get(result.SessionID)
|
||||
if ok {
|
||||
t.Fatal("session 应在交换成功后被删除")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuthService_ExchangeCode_SetupToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client := &mockClaudeOAuthClient{
|
||||
exchangeCodeFunc: func(ctx context.Context, code, codeVerifier, state, proxyURL string, isSetupToken bool) (*oauth.TokenResponse, error) {
|
||||
if !isSetupToken {
|
||||
t.Error("isSetupToken 应为 true(ScopeInference)")
|
||||
}
|
||||
return &oauth.TokenResponse{
|
||||
AccessToken: "setup-token",
|
||||
TokenType: "Bearer",
|
||||
ExpiresIn: 3600,
|
||||
Scope: oauth.ScopeInference,
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
svc := NewOAuthService(&mockProxyRepoForOAuth{}, client)
|
||||
defer svc.Stop()
|
||||
|
||||
// 使用 SetupToken URL(inference scope)
|
||||
result, err := svc.GenerateSetupTokenURL(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateSetupTokenURL 返回错误: %v", err)
|
||||
}
|
||||
|
||||
tokenInfo, err := svc.ExchangeCode(context.Background(), &ExchangeCodeInput{
|
||||
SessionID: result.SessionID,
|
||||
Code: "setup-code",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("ExchangeCode 返回错误: %v", err)
|
||||
}
|
||||
if tokenInfo.AccessToken != "setup-token" {
|
||||
t.Fatalf("AccessToken 不匹配: got=%q", tokenInfo.AccessToken)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuthService_ExchangeCode_ClientError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client := &mockClaudeOAuthClient{
|
||||
exchangeCodeFunc: func(ctx context.Context, code, codeVerifier, state, proxyURL string, isSetupToken bool) (*oauth.TokenResponse, error) {
|
||||
return nil, fmt.Errorf("upstream error: invalid code")
|
||||
},
|
||||
}
|
||||
|
||||
svc := NewOAuthService(&mockProxyRepoForOAuth{}, client)
|
||||
defer svc.Stop()
|
||||
|
||||
result, _ := svc.GenerateAuthURL(context.Background(), nil)
|
||||
_, err := svc.ExchangeCode(context.Background(), &ExchangeCodeInput{
|
||||
SessionID: result.SessionID,
|
||||
Code: "bad-code",
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("ExchangeCode 应返回错误")
|
||||
}
|
||||
if err.Error() != "upstream error: invalid code" {
|
||||
t.Fatalf("错误信息不匹配: got=%q", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuthService_RefreshToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client := &mockClaudeOAuthClient{
|
||||
refreshTokenFunc: func(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) {
|
||||
if refreshToken != "my-refresh-token" {
|
||||
t.Errorf("refreshToken 不匹配: got=%q", refreshToken)
|
||||
}
|
||||
if proxyURL != "" {
|
||||
t.Errorf("proxyURL 应为空: got=%q", proxyURL)
|
||||
}
|
||||
return &oauth.TokenResponse{
|
||||
AccessToken: "new-access-token",
|
||||
TokenType: "Bearer",
|
||||
ExpiresIn: 7200,
|
||||
RefreshToken: "new-refresh-token",
|
||||
Scope: oauth.ScopeOAuth,
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
svc := NewOAuthService(&mockProxyRepoForOAuth{}, client)
|
||||
defer svc.Stop()
|
||||
|
||||
tokenInfo, err := svc.RefreshToken(context.Background(), "my-refresh-token", "")
|
||||
if err != nil {
|
||||
t.Fatalf("RefreshToken 返回错误: %v", err)
|
||||
}
|
||||
if tokenInfo.AccessToken != "new-access-token" {
|
||||
t.Fatalf("AccessToken 不匹配: got=%q", tokenInfo.AccessToken)
|
||||
}
|
||||
if tokenInfo.RefreshToken != "new-refresh-token" {
|
||||
t.Fatalf("RefreshToken 不匹配: got=%q", tokenInfo.RefreshToken)
|
||||
}
|
||||
if tokenInfo.ExpiresIn != 7200 {
|
||||
t.Fatalf("ExpiresIn 不匹配: got=%d", tokenInfo.ExpiresIn)
|
||||
}
|
||||
if tokenInfo.ExpiresAt == 0 {
|
||||
t.Fatal("ExpiresAt 不应为 0")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuthService_RefreshToken_Error(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client := &mockClaudeOAuthClient{
|
||||
refreshTokenFunc: func(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) {
|
||||
return nil, fmt.Errorf("invalid_grant: token expired")
|
||||
},
|
||||
}
|
||||
|
||||
svc := NewOAuthService(&mockProxyRepoForOAuth{}, client)
|
||||
defer svc.Stop()
|
||||
|
||||
_, err := svc.RefreshToken(context.Background(), "expired-token", "")
|
||||
if err == nil {
|
||||
t.Fatal("RefreshToken 应返回错误")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuthService_RefreshAccountToken_NoRefreshToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
svc := NewOAuthService(&mockProxyRepoForOAuth{}, &mockClaudeOAuthClient{})
|
||||
defer svc.Stop()
|
||||
|
||||
// 无 refresh_token 的账号
|
||||
account := &Account{
|
||||
ID: 1,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "some-token",
|
||||
},
|
||||
}
|
||||
_, err := svc.RefreshAccountToken(context.Background(), account)
|
||||
if err == nil {
|
||||
t.Fatal("RefreshAccountToken 应返回错误(无 refresh_token)")
|
||||
}
|
||||
if err.Error() != "no refresh token available" {
|
||||
t.Fatalf("错误信息不匹配: got=%q", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuthService_RefreshAccountToken_EmptyRefreshToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
svc := NewOAuthService(&mockProxyRepoForOAuth{}, &mockClaudeOAuthClient{})
|
||||
defer svc.Stop()
|
||||
|
||||
account := &Account{
|
||||
ID: 2,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "some-token",
|
||||
"refresh_token": "",
|
||||
},
|
||||
}
|
||||
_, err := svc.RefreshAccountToken(context.Background(), account)
|
||||
if err == nil {
|
||||
t.Fatal("RefreshAccountToken 应返回错误(refresh_token 为空)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuthService_RefreshAccountToken_Success(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client := &mockClaudeOAuthClient{
|
||||
refreshTokenFunc: func(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) {
|
||||
if refreshToken != "account-refresh-token" {
|
||||
t.Errorf("refreshToken 不匹配: got=%q", refreshToken)
|
||||
}
|
||||
return &oauth.TokenResponse{
|
||||
AccessToken: "refreshed-access",
|
||||
TokenType: "Bearer",
|
||||
ExpiresIn: 3600,
|
||||
RefreshToken: "new-refresh",
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
svc := NewOAuthService(&mockProxyRepoForOAuth{}, client)
|
||||
defer svc.Stop()
|
||||
|
||||
account := &Account{
|
||||
ID: 3,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "old-access",
|
||||
"refresh_token": "account-refresh-token",
|
||||
},
|
||||
}
|
||||
|
||||
tokenInfo, err := svc.RefreshAccountToken(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatalf("RefreshAccountToken 返回错误: %v", err)
|
||||
}
|
||||
if tokenInfo.AccessToken != "refreshed-access" {
|
||||
t.Fatalf("AccessToken 不匹配: got=%q", tokenInfo.AccessToken)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuthService_RefreshAccountToken_WithProxy(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
proxyRepo := &mockProxyRepoForOAuth{
|
||||
getByIDFunc: func(ctx context.Context, id int64) (*Proxy, error) {
|
||||
return &Proxy{
|
||||
Protocol: "socks5",
|
||||
Host: "socks.example.com",
|
||||
Port: 1080,
|
||||
Username: "user",
|
||||
Password: "pass",
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
client := &mockClaudeOAuthClient{
|
||||
refreshTokenFunc: func(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) {
|
||||
if proxyURL != "socks5://user:pass@socks.example.com:1080" {
|
||||
t.Errorf("proxyURL 不匹配: got=%q", proxyURL)
|
||||
}
|
||||
return &oauth.TokenResponse{
|
||||
AccessToken: "refreshed",
|
||||
ExpiresIn: 3600,
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
svc := NewOAuthService(proxyRepo, client)
|
||||
defer svc.Stop()
|
||||
|
||||
proxyID := int64(10)
|
||||
account := &Account{
|
||||
ID: 4,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
ProxyID: &proxyID,
|
||||
Credentials: map[string]any{
|
||||
"refresh_token": "rt-with-proxy",
|
||||
},
|
||||
}
|
||||
|
||||
_, err := svc.RefreshAccountToken(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatalf("RefreshAccountToken 返回错误: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuthService_ExchangeCode_NilOrg(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client := &mockClaudeOAuthClient{
|
||||
exchangeCodeFunc: func(ctx context.Context, code, codeVerifier, state, proxyURL string, isSetupToken bool) (*oauth.TokenResponse, error) {
|
||||
return &oauth.TokenResponse{
|
||||
AccessToken: "token-no-org",
|
||||
TokenType: "Bearer",
|
||||
ExpiresIn: 3600,
|
||||
Organization: nil,
|
||||
Account: nil,
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
svc := NewOAuthService(&mockProxyRepoForOAuth{}, client)
|
||||
defer svc.Stop()
|
||||
|
||||
result, _ := svc.GenerateAuthURL(context.Background(), nil)
|
||||
tokenInfo, err := svc.ExchangeCode(context.Background(), &ExchangeCodeInput{
|
||||
SessionID: result.SessionID,
|
||||
Code: "code",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("ExchangeCode 返回错误: %v", err)
|
||||
}
|
||||
if tokenInfo.OrgUUID != "" {
|
||||
t.Fatalf("OrgUUID 应为空: got=%q", tokenInfo.OrgUUID)
|
||||
}
|
||||
if tokenInfo.AccountUUID != "" {
|
||||
t.Fatalf("AccountUUID 应为空: got=%q", tokenInfo.AccountUUID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuthService_Stop_NoPanic(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
svc := NewOAuthService(&mockProxyRepoForOAuth{}, &mockClaudeOAuthClient{})
|
||||
|
||||
// 调用 Stop 不应 panic
|
||||
svc.Stop()
|
||||
|
||||
// 多次调用也不应 panic
|
||||
svc.Stop()
|
||||
}
|
||||
39
backend/internal/util/logredact/redact_test.go
Normal file
39
backend/internal/util/logredact/redact_test.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package logredact
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRedactText_JSONLike(t *testing.T) {
|
||||
in := `{"access_token":"ya29.a0AfH6SMDUMMY","refresh_token":"1//0gDUMMY","other":"ok"}`
|
||||
out := RedactText(in)
|
||||
if out == in {
|
||||
t.Fatalf("expected redaction, got unchanged")
|
||||
}
|
||||
if want := `"access_token":"***"`; !strings.Contains(out, want) {
|
||||
t.Fatalf("expected %q in %q", want, out)
|
||||
}
|
||||
if want := `"refresh_token":"***"`; !strings.Contains(out, want) {
|
||||
t.Fatalf("expected %q in %q", want, out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedactText_QueryLike(t *testing.T) {
|
||||
in := "access_token=ya29.a0AfH6SMDUMMY refresh_token=1//0gDUMMY"
|
||||
out := RedactText(in)
|
||||
if strings.Contains(out, "ya29") || strings.Contains(out, "1//0") {
|
||||
t.Fatalf("expected tokens redacted, got %q", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedactText_GOCSPX(t *testing.T) {
|
||||
in := "client_secret=GOCSPX-abcdefghijklmnopqrstuvwxyz_0123456789"
|
||||
out := RedactText(in)
|
||||
if strings.Contains(out, "abcdefghijklmnopqrstuvwxyz") {
|
||||
t.Fatalf("expected secret redacted, got %q", out)
|
||||
}
|
||||
if !strings.Contains(out, "client_secret=***") {
|
||||
t.Fatalf("expected key redacted, got %q", out)
|
||||
}
|
||||
}
|
||||
@@ -49,3 +49,27 @@ func TestValidateURLFormat(t *testing.T) {
|
||||
t.Fatalf("expected trailing slash to be removed from path, got %s", normalized)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateHTTPURL(t *testing.T) {
|
||||
if _, err := ValidateHTTPURL("http://example.com", false, ValidationOptions{}); err == nil {
|
||||
t.Fatalf("expected http to fail when allow_insecure_http is false")
|
||||
}
|
||||
if _, err := ValidateHTTPURL("http://example.com", true, ValidationOptions{}); err != nil {
|
||||
t.Fatalf("expected http to pass when allow_insecure_http is true, got %v", err)
|
||||
}
|
||||
if _, err := ValidateHTTPURL("https://example.com", false, ValidationOptions{RequireAllowlist: true}); err == nil {
|
||||
t.Fatalf("expected require allowlist to fail when empty")
|
||||
}
|
||||
if _, err := ValidateHTTPURL("https://example.com", false, ValidationOptions{AllowedHosts: []string{"api.example.com"}}); err == nil {
|
||||
t.Fatalf("expected host not in allowlist to fail")
|
||||
}
|
||||
if _, err := ValidateHTTPURL("https://api.example.com", false, ValidationOptions{AllowedHosts: []string{"api.example.com"}}); err != nil {
|
||||
t.Fatalf("expected allowlisted host to pass, got %v", err)
|
||||
}
|
||||
if _, err := ValidateHTTPURL("https://sub.api.example.com", false, ValidationOptions{AllowedHosts: []string{"*.example.com"}}); err != nil {
|
||||
t.Fatalf("expected wildcard allowlist to pass, got %v", err)
|
||||
}
|
||||
if _, err := ValidateHTTPURL("https://localhost", false, ValidationOptions{AllowPrivate: false}); err == nil {
|
||||
t.Fatalf("expected localhost to be blocked when allow_private_hosts is false")
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user