Merge branch 'main' into feat/api-key-ip-restriction
This commit is contained in:
@@ -66,6 +66,7 @@ type AccountBulkUpdate struct {
|
||||
Concurrency *int
|
||||
Priority *int
|
||||
Status *string
|
||||
Schedulable *bool
|
||||
Credentials map[string]any
|
||||
Extra map[string]any
|
||||
}
|
||||
|
||||
@@ -661,13 +661,7 @@ func (s *AccountTestService) processGeminiStream(c *gin.Context, body io.Reader)
|
||||
}
|
||||
if candidates, ok := data["candidates"].([]any); ok && len(candidates) > 0 {
|
||||
if candidate, ok := candidates[0].(map[string]any); ok {
|
||||
// Check for completion
|
||||
if finishReason, ok := candidate["finishReason"].(string); ok && finishReason != "" {
|
||||
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
|
||||
return nil
|
||||
}
|
||||
|
||||
// Extract content
|
||||
// Extract content first (before checking completion)
|
||||
if content, ok := candidate["content"].(map[string]any); ok {
|
||||
if parts, ok := content["parts"].([]any); ok {
|
||||
for _, part := range parts {
|
||||
@@ -679,6 +673,12 @@ func (s *AccountTestService) processGeminiStream(c *gin.Context, body io.Reader)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check for completion after extracting content
|
||||
if finishReason, ok := candidate["finishReason"].(string); ok && finishReason != "" {
|
||||
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ type AdminService interface {
|
||||
GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error)
|
||||
|
||||
// Group management
|
||||
ListGroups(ctx context.Context, page, pageSize int, platform, status string, isExclusive *bool) ([]Group, int64, error)
|
||||
ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool) ([]Group, int64, error)
|
||||
GetAllGroups(ctx context.Context) ([]Group, error)
|
||||
GetAllGroupsByPlatform(ctx context.Context, platform string) ([]Group, error)
|
||||
GetGroup(ctx context.Context, id int64) (*Group, error)
|
||||
@@ -168,6 +168,7 @@ type BulkUpdateAccountsInput struct {
|
||||
Concurrency *int
|
||||
Priority *int
|
||||
Status string
|
||||
Schedulable *bool
|
||||
GroupIDs *[]int64
|
||||
Credentials map[string]any
|
||||
Extra map[string]any
|
||||
@@ -478,9 +479,9 @@ func (s *adminServiceImpl) GetUserUsageStats(ctx context.Context, userID int64,
|
||||
}
|
||||
|
||||
// Group management implementations
|
||||
func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, platform, status string, isExclusive *bool) ([]Group, int64, error) {
|
||||
func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool) ([]Group, int64, error) {
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
groups, result, err := s.groupRepo.ListWithFilters(ctx, params, platform, status, isExclusive)
|
||||
groups, result, err := s.groupRepo.ListWithFilters(ctx, params, platform, status, search, isExclusive)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
@@ -910,6 +911,9 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
|
||||
if input.Status != "" {
|
||||
repoUpdates.Status = &input.Status
|
||||
}
|
||||
if input.Schedulable != nil {
|
||||
repoUpdates.Schedulable = input.Schedulable
|
||||
}
|
||||
|
||||
// Run bulk update for column/jsonb fields first.
|
||||
if _, err := s.accountRepo.BulkUpdate(ctx, input.AccountIDs, repoUpdates); err != nil {
|
||||
|
||||
@@ -124,7 +124,7 @@ func (s *groupRepoStub) List(ctx context.Context, params pagination.PaginationPa
|
||||
panic("unexpected List call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStub) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error) {
|
||||
func (s *groupRepoStub) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error) {
|
||||
panic("unexpected ListWithFilters call")
|
||||
}
|
||||
|
||||
|
||||
@@ -16,6 +16,16 @@ type groupRepoStubForAdmin struct {
|
||||
updated *Group // 记录 Update 调用的参数
|
||||
getByID *Group // GetByID 返回值
|
||||
getErr error // GetByID 返回的错误
|
||||
|
||||
listWithFiltersCalls int
|
||||
listWithFiltersParams pagination.PaginationParams
|
||||
listWithFiltersPlatform string
|
||||
listWithFiltersStatus string
|
||||
listWithFiltersSearch string
|
||||
listWithFiltersIsExclusive *bool
|
||||
listWithFiltersGroups []Group
|
||||
listWithFiltersResult *pagination.PaginationResult
|
||||
listWithFiltersErr error
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForAdmin) Create(_ context.Context, g *Group) error {
|
||||
@@ -47,8 +57,28 @@ func (s *groupRepoStubForAdmin) List(_ context.Context, _ pagination.PaginationP
|
||||
panic("unexpected List call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForAdmin) ListWithFilters(_ context.Context, _ pagination.PaginationParams, _, _ string, _ *bool) ([]Group, *pagination.PaginationResult, error) {
|
||||
panic("unexpected ListWithFilters call")
|
||||
func (s *groupRepoStubForAdmin) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error) {
|
||||
s.listWithFiltersCalls++
|
||||
s.listWithFiltersParams = params
|
||||
s.listWithFiltersPlatform = platform
|
||||
s.listWithFiltersStatus = status
|
||||
s.listWithFiltersSearch = search
|
||||
s.listWithFiltersIsExclusive = isExclusive
|
||||
|
||||
if s.listWithFiltersErr != nil {
|
||||
return nil, nil, s.listWithFiltersErr
|
||||
}
|
||||
|
||||
result := s.listWithFiltersResult
|
||||
if result == nil {
|
||||
result = &pagination.PaginationResult{
|
||||
Total: int64(len(s.listWithFiltersGroups)),
|
||||
Page: params.Page,
|
||||
PageSize: params.PageSize,
|
||||
}
|
||||
}
|
||||
|
||||
return s.listWithFiltersGroups, result, nil
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForAdmin) ListActive(_ context.Context) ([]Group, error) {
|
||||
@@ -195,3 +225,68 @@ func TestAdminService_UpdateGroup_PartialImagePricing(t *testing.T) {
|
||||
require.InDelta(t, 0.15, *repo.updated.ImagePrice2K, 0.0001) // 原值保持
|
||||
require.Nil(t, repo.updated.ImagePrice4K)
|
||||
}
|
||||
|
||||
func TestAdminService_ListGroups_WithSearch(t *testing.T) {
|
||||
// 测试:
|
||||
// 1. search 参数正常传递到 repository 层
|
||||
// 2. search 为空字符串时的行为
|
||||
// 3. search 与其他过滤条件组合使用
|
||||
|
||||
t.Run("search 参数正常传递到 repository 层", func(t *testing.T) {
|
||||
repo := &groupRepoStubForAdmin{
|
||||
listWithFiltersGroups: []Group{{ID: 1, Name: "alpha"}},
|
||||
listWithFiltersResult: &pagination.PaginationResult{Total: 1},
|
||||
}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
groups, total, err := svc.ListGroups(context.Background(), 1, 20, "", "", "alpha", nil)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(1), total)
|
||||
require.Equal(t, []Group{{ID: 1, Name: "alpha"}}, groups)
|
||||
|
||||
require.Equal(t, 1, repo.listWithFiltersCalls)
|
||||
require.Equal(t, pagination.PaginationParams{Page: 1, PageSize: 20}, repo.listWithFiltersParams)
|
||||
require.Equal(t, "alpha", repo.listWithFiltersSearch)
|
||||
require.Nil(t, repo.listWithFiltersIsExclusive)
|
||||
})
|
||||
|
||||
t.Run("search 为空字符串时传递空字符串", func(t *testing.T) {
|
||||
repo := &groupRepoStubForAdmin{
|
||||
listWithFiltersGroups: []Group{},
|
||||
listWithFiltersResult: &pagination.PaginationResult{Total: 0},
|
||||
}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
groups, total, err := svc.ListGroups(context.Background(), 2, 10, "", "", "", nil)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, groups)
|
||||
require.Equal(t, int64(0), total)
|
||||
|
||||
require.Equal(t, 1, repo.listWithFiltersCalls)
|
||||
require.Equal(t, pagination.PaginationParams{Page: 2, PageSize: 10}, repo.listWithFiltersParams)
|
||||
require.Equal(t, "", repo.listWithFiltersSearch)
|
||||
require.Nil(t, repo.listWithFiltersIsExclusive)
|
||||
})
|
||||
|
||||
t.Run("search 与其他过滤条件组合使用", func(t *testing.T) {
|
||||
isExclusive := true
|
||||
repo := &groupRepoStubForAdmin{
|
||||
listWithFiltersGroups: []Group{{ID: 2, Name: "beta"}},
|
||||
listWithFiltersResult: &pagination.PaginationResult{Total: 42},
|
||||
}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
groups, total, err := svc.ListGroups(context.Background(), 3, 50, PlatformAntigravity, StatusActive, "beta", &isExclusive)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(42), total)
|
||||
require.Equal(t, []Group{{ID: 2, Name: "beta"}}, groups)
|
||||
|
||||
require.Equal(t, 1, repo.listWithFiltersCalls)
|
||||
require.Equal(t, pagination.PaginationParams{Page: 3, PageSize: 50}, repo.listWithFiltersParams)
|
||||
require.Equal(t, PlatformAntigravity, repo.listWithFiltersPlatform)
|
||||
require.Equal(t, StatusActive, repo.listWithFiltersStatus)
|
||||
require.Equal(t, "beta", repo.listWithFiltersSearch)
|
||||
require.NotNil(t, repo.listWithFiltersIsExclusive)
|
||||
require.True(t, *repo.listWithFiltersIsExclusive)
|
||||
})
|
||||
}
|
||||
|
||||
238
backend/internal/service/admin_service_search_test.go
Normal file
238
backend/internal/service/admin_service_search_test.go
Normal file
@@ -0,0 +1,238 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type accountRepoStubForAdminList struct {
|
||||
accountRepoStub
|
||||
|
||||
listWithFiltersCalls int
|
||||
listWithFiltersParams pagination.PaginationParams
|
||||
listWithFiltersPlatform string
|
||||
listWithFiltersType string
|
||||
listWithFiltersStatus string
|
||||
listWithFiltersSearch string
|
||||
listWithFiltersAccounts []Account
|
||||
listWithFiltersResult *pagination.PaginationResult
|
||||
listWithFiltersErr error
|
||||
}
|
||||
|
||||
func (s *accountRepoStubForAdminList) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error) {
|
||||
s.listWithFiltersCalls++
|
||||
s.listWithFiltersParams = params
|
||||
s.listWithFiltersPlatform = platform
|
||||
s.listWithFiltersType = accountType
|
||||
s.listWithFiltersStatus = status
|
||||
s.listWithFiltersSearch = search
|
||||
|
||||
if s.listWithFiltersErr != nil {
|
||||
return nil, nil, s.listWithFiltersErr
|
||||
}
|
||||
|
||||
result := s.listWithFiltersResult
|
||||
if result == nil {
|
||||
result = &pagination.PaginationResult{
|
||||
Total: int64(len(s.listWithFiltersAccounts)),
|
||||
Page: params.Page,
|
||||
PageSize: params.PageSize,
|
||||
}
|
||||
}
|
||||
|
||||
return s.listWithFiltersAccounts, result, nil
|
||||
}
|
||||
|
||||
type proxyRepoStubForAdminList struct {
|
||||
proxyRepoStub
|
||||
|
||||
listWithFiltersCalls int
|
||||
listWithFiltersParams pagination.PaginationParams
|
||||
listWithFiltersProtocol string
|
||||
listWithFiltersStatus string
|
||||
listWithFiltersSearch string
|
||||
listWithFiltersProxies []Proxy
|
||||
listWithFiltersResult *pagination.PaginationResult
|
||||
listWithFiltersErr error
|
||||
|
||||
listWithFiltersAndAccountCountCalls int
|
||||
listWithFiltersAndAccountCountParams pagination.PaginationParams
|
||||
listWithFiltersAndAccountCountProtocol string
|
||||
listWithFiltersAndAccountCountStatus string
|
||||
listWithFiltersAndAccountCountSearch string
|
||||
listWithFiltersAndAccountCountProxies []ProxyWithAccountCount
|
||||
listWithFiltersAndAccountCountResult *pagination.PaginationResult
|
||||
listWithFiltersAndAccountCountErr error
|
||||
}
|
||||
|
||||
func (s *proxyRepoStubForAdminList) ListWithFilters(_ context.Context, params pagination.PaginationParams, protocol, status, search string) ([]Proxy, *pagination.PaginationResult, error) {
|
||||
s.listWithFiltersCalls++
|
||||
s.listWithFiltersParams = params
|
||||
s.listWithFiltersProtocol = protocol
|
||||
s.listWithFiltersStatus = status
|
||||
s.listWithFiltersSearch = search
|
||||
|
||||
if s.listWithFiltersErr != nil {
|
||||
return nil, nil, s.listWithFiltersErr
|
||||
}
|
||||
|
||||
result := s.listWithFiltersResult
|
||||
if result == nil {
|
||||
result = &pagination.PaginationResult{
|
||||
Total: int64(len(s.listWithFiltersProxies)),
|
||||
Page: params.Page,
|
||||
PageSize: params.PageSize,
|
||||
}
|
||||
}
|
||||
|
||||
return s.listWithFiltersProxies, result, nil
|
||||
}
|
||||
|
||||
func (s *proxyRepoStubForAdminList) ListWithFiltersAndAccountCount(_ context.Context, params pagination.PaginationParams, protocol, status, search string) ([]ProxyWithAccountCount, *pagination.PaginationResult, error) {
|
||||
s.listWithFiltersAndAccountCountCalls++
|
||||
s.listWithFiltersAndAccountCountParams = params
|
||||
s.listWithFiltersAndAccountCountProtocol = protocol
|
||||
s.listWithFiltersAndAccountCountStatus = status
|
||||
s.listWithFiltersAndAccountCountSearch = search
|
||||
|
||||
if s.listWithFiltersAndAccountCountErr != nil {
|
||||
return nil, nil, s.listWithFiltersAndAccountCountErr
|
||||
}
|
||||
|
||||
result := s.listWithFiltersAndAccountCountResult
|
||||
if result == nil {
|
||||
result = &pagination.PaginationResult{
|
||||
Total: int64(len(s.listWithFiltersAndAccountCountProxies)),
|
||||
Page: params.Page,
|
||||
PageSize: params.PageSize,
|
||||
}
|
||||
}
|
||||
|
||||
return s.listWithFiltersAndAccountCountProxies, result, nil
|
||||
}
|
||||
|
||||
type redeemRepoStubForAdminList struct {
|
||||
redeemRepoStub
|
||||
|
||||
listWithFiltersCalls int
|
||||
listWithFiltersParams pagination.PaginationParams
|
||||
listWithFiltersType string
|
||||
listWithFiltersStatus string
|
||||
listWithFiltersSearch string
|
||||
listWithFiltersCodes []RedeemCode
|
||||
listWithFiltersResult *pagination.PaginationResult
|
||||
listWithFiltersErr error
|
||||
}
|
||||
|
||||
func (s *redeemRepoStubForAdminList) ListWithFilters(_ context.Context, params pagination.PaginationParams, codeType, status, search string) ([]RedeemCode, *pagination.PaginationResult, error) {
|
||||
s.listWithFiltersCalls++
|
||||
s.listWithFiltersParams = params
|
||||
s.listWithFiltersType = codeType
|
||||
s.listWithFiltersStatus = status
|
||||
s.listWithFiltersSearch = search
|
||||
|
||||
if s.listWithFiltersErr != nil {
|
||||
return nil, nil, s.listWithFiltersErr
|
||||
}
|
||||
|
||||
result := s.listWithFiltersResult
|
||||
if result == nil {
|
||||
result = &pagination.PaginationResult{
|
||||
Total: int64(len(s.listWithFiltersCodes)),
|
||||
Page: params.Page,
|
||||
PageSize: params.PageSize,
|
||||
}
|
||||
}
|
||||
|
||||
return s.listWithFiltersCodes, result, nil
|
||||
}
|
||||
|
||||
func TestAdminService_ListAccounts_WithSearch(t *testing.T) {
|
||||
t.Run("search 参数正常传递到 repository 层", func(t *testing.T) {
|
||||
repo := &accountRepoStubForAdminList{
|
||||
listWithFiltersAccounts: []Account{{ID: 1, Name: "acc"}},
|
||||
listWithFiltersResult: &pagination.PaginationResult{Total: 10},
|
||||
}
|
||||
svc := &adminServiceImpl{accountRepo: repo}
|
||||
|
||||
accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformGemini, AccountTypeOAuth, StatusActive, "acc")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(10), total)
|
||||
require.Equal(t, []Account{{ID: 1, Name: "acc"}}, accounts)
|
||||
|
||||
require.Equal(t, 1, repo.listWithFiltersCalls)
|
||||
require.Equal(t, pagination.PaginationParams{Page: 1, PageSize: 20}, repo.listWithFiltersParams)
|
||||
require.Equal(t, PlatformGemini, repo.listWithFiltersPlatform)
|
||||
require.Equal(t, AccountTypeOAuth, repo.listWithFiltersType)
|
||||
require.Equal(t, StatusActive, repo.listWithFiltersStatus)
|
||||
require.Equal(t, "acc", repo.listWithFiltersSearch)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAdminService_ListProxies_WithSearch(t *testing.T) {
|
||||
t.Run("search 参数正常传递到 repository 层", func(t *testing.T) {
|
||||
repo := &proxyRepoStubForAdminList{
|
||||
listWithFiltersProxies: []Proxy{{ID: 2, Name: "p1"}},
|
||||
listWithFiltersResult: &pagination.PaginationResult{Total: 7},
|
||||
}
|
||||
svc := &adminServiceImpl{proxyRepo: repo}
|
||||
|
||||
proxies, total, err := svc.ListProxies(context.Background(), 3, 50, "http", StatusActive, "p1")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(7), total)
|
||||
require.Equal(t, []Proxy{{ID: 2, Name: "p1"}}, proxies)
|
||||
|
||||
require.Equal(t, 1, repo.listWithFiltersCalls)
|
||||
require.Equal(t, pagination.PaginationParams{Page: 3, PageSize: 50}, repo.listWithFiltersParams)
|
||||
require.Equal(t, "http", repo.listWithFiltersProtocol)
|
||||
require.Equal(t, StatusActive, repo.listWithFiltersStatus)
|
||||
require.Equal(t, "p1", repo.listWithFiltersSearch)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAdminService_ListProxiesWithAccountCount_WithSearch(t *testing.T) {
|
||||
t.Run("search 参数正常传递到 repository 层", func(t *testing.T) {
|
||||
repo := &proxyRepoStubForAdminList{
|
||||
listWithFiltersAndAccountCountProxies: []ProxyWithAccountCount{{Proxy: Proxy{ID: 3, Name: "p2"}, AccountCount: 5}},
|
||||
listWithFiltersAndAccountCountResult: &pagination.PaginationResult{Total: 9},
|
||||
}
|
||||
svc := &adminServiceImpl{proxyRepo: repo}
|
||||
|
||||
proxies, total, err := svc.ListProxiesWithAccountCount(context.Background(), 2, 10, "socks5", StatusDisabled, "p2")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(9), total)
|
||||
require.Equal(t, []ProxyWithAccountCount{{Proxy: Proxy{ID: 3, Name: "p2"}, AccountCount: 5}}, proxies)
|
||||
|
||||
require.Equal(t, 1, repo.listWithFiltersAndAccountCountCalls)
|
||||
require.Equal(t, pagination.PaginationParams{Page: 2, PageSize: 10}, repo.listWithFiltersAndAccountCountParams)
|
||||
require.Equal(t, "socks5", repo.listWithFiltersAndAccountCountProtocol)
|
||||
require.Equal(t, StatusDisabled, repo.listWithFiltersAndAccountCountStatus)
|
||||
require.Equal(t, "p2", repo.listWithFiltersAndAccountCountSearch)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAdminService_ListRedeemCodes_WithSearch(t *testing.T) {
|
||||
t.Run("search 参数正常传递到 repository 层", func(t *testing.T) {
|
||||
repo := &redeemRepoStubForAdminList{
|
||||
listWithFiltersCodes: []RedeemCode{{ID: 4, Code: "ABC"}},
|
||||
listWithFiltersResult: &pagination.PaginationResult{Total: 3},
|
||||
}
|
||||
svc := &adminServiceImpl{redeemCodeRepo: repo}
|
||||
|
||||
codes, total, err := svc.ListRedeemCodes(context.Background(), 1, 20, RedeemTypeBalance, StatusUnused, "ABC")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(3), total)
|
||||
require.Equal(t, []RedeemCode{{ID: 4, Code: "ABC"}}, codes)
|
||||
|
||||
require.Equal(t, 1, repo.listWithFiltersCalls)
|
||||
require.Equal(t, pagination.PaginationParams{Page: 1, PageSize: 20}, repo.listWithFiltersParams)
|
||||
require.Equal(t, RedeemTypeBalance, repo.listWithFiltersType)
|
||||
require.Equal(t, StatusUnused, repo.listWithFiltersStatus)
|
||||
require.Equal(t, "ABC", repo.listWithFiltersSearch)
|
||||
})
|
||||
}
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"io"
|
||||
"log"
|
||||
mathrand "math/rand"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
@@ -27,6 +28,32 @@ const (
|
||||
antigravityRetryMaxDelay = 16 * time.Second
|
||||
)
|
||||
|
||||
// isAntigravityConnectionError 判断是否为连接错误(网络超时、DNS 失败、连接拒绝)
|
||||
func isAntigravityConnectionError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查超时错误
|
||||
var netErr net.Error
|
||||
if errors.As(err, &netErr) && netErr.Timeout() {
|
||||
return true
|
||||
}
|
||||
|
||||
// 检查连接错误(DNS 失败、连接拒绝)
|
||||
var opErr *net.OpError
|
||||
return errors.As(err, &opErr)
|
||||
}
|
||||
|
||||
// shouldAntigravityFallbackToNextURL 判断是否应切换到下一个 URL
|
||||
// 仅连接错误和 HTTP 429 触发 URL 降级
|
||||
func shouldAntigravityFallbackToNextURL(err error, statusCode int) bool {
|
||||
if isAntigravityConnectionError(err) {
|
||||
return true
|
||||
}
|
||||
return statusCode == http.StatusTooManyRequests
|
||||
}
|
||||
|
||||
// getSessionID 从 gin.Context 获取 session_id(用于日志追踪)
|
||||
func getSessionID(c *gin.Context) string {
|
||||
if c == nil {
|
||||
@@ -181,45 +208,70 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account
|
||||
return nil, fmt.Errorf("构建请求失败: %w", err)
|
||||
}
|
||||
|
||||
// 构建 HTTP 请求(总是使用流式 endpoint,与官方客户端一致)
|
||||
req, err := antigravity.NewAPIRequest(ctx, "streamGenerateContent", accessToken, requestBody)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 调试日志:Test 请求信息
|
||||
log.Printf("[antigravity-Test] account=%s request_size=%d url=%s", account.Name, len(requestBody), req.URL.String())
|
||||
|
||||
// 代理 URL
|
||||
proxyURL := ""
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
// 发送请求
|
||||
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("请求失败: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
// 读取响应
|
||||
respBody, err := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取响应失败: %w", err)
|
||||
// URL fallback 循环
|
||||
availableURLs := antigravity.DefaultURLAvailability.GetAvailableURLs()
|
||||
if len(availableURLs) == 0 {
|
||||
availableURLs = antigravity.BaseURLs // 所有 URL 都不可用时,重试所有
|
||||
}
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
return nil, fmt.Errorf("API 返回 %d: %s", resp.StatusCode, string(respBody))
|
||||
var lastErr error
|
||||
for urlIdx, baseURL := range availableURLs {
|
||||
// 构建 HTTP 请求(总是使用流式 endpoint,与官方客户端一致)
|
||||
req, err := antigravity.NewAPIRequestWithURL(ctx, baseURL, "streamGenerateContent", accessToken, requestBody)
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
continue
|
||||
}
|
||||
|
||||
// 调试日志:Test 请求信息
|
||||
log.Printf("[antigravity-Test] account=%s request_size=%d url=%s", account.Name, len(requestBody), req.URL.String())
|
||||
|
||||
// 发送请求
|
||||
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
|
||||
if err != nil {
|
||||
lastErr = fmt.Errorf("请求失败: %w", err)
|
||||
if shouldAntigravityFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 {
|
||||
antigravity.DefaultURLAvailability.MarkUnavailable(baseURL)
|
||||
log.Printf("[antigravity-Test] URL fallback: %s -> %s", baseURL, availableURLs[urlIdx+1])
|
||||
continue
|
||||
}
|
||||
return nil, lastErr
|
||||
}
|
||||
|
||||
// 读取响应
|
||||
respBody, err := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
_ = resp.Body.Close() // 立即关闭,避免循环内 defer 导致的资源泄漏
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取响应失败: %w", err)
|
||||
}
|
||||
|
||||
// 检查是否需要 URL 降级
|
||||
if shouldAntigravityFallbackToNextURL(nil, resp.StatusCode) && urlIdx < len(availableURLs)-1 {
|
||||
antigravity.DefaultURLAvailability.MarkUnavailable(baseURL)
|
||||
log.Printf("[antigravity-Test] URL fallback (HTTP %d): %s -> %s", resp.StatusCode, baseURL, availableURLs[urlIdx+1])
|
||||
continue
|
||||
}
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
return nil, fmt.Errorf("API 返回 %d: %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
// 解析流式响应,提取文本
|
||||
text := extractTextFromSSEResponse(respBody)
|
||||
|
||||
return &TestConnectionResult{
|
||||
Text: text,
|
||||
MappedModel: mappedModel,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 解析流式响应,提取文本
|
||||
text := extractTextFromSSEResponse(respBody)
|
||||
|
||||
return &TestConnectionResult{
|
||||
Text: text,
|
||||
MappedModel: mappedModel,
|
||||
}, nil
|
||||
return nil, lastErr
|
||||
}
|
||||
|
||||
// buildGeminiTestRequest 构建 Gemini 格式测试请求
|
||||
@@ -484,62 +536,86 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
// 如果客户端请求非流式,在响应处理阶段会收集完整流式响应后转换返回
|
||||
action := "streamGenerateContent"
|
||||
|
||||
// URL fallback 循环
|
||||
availableURLs := antigravity.DefaultURLAvailability.GetAvailableURLs()
|
||||
if len(availableURLs) == 0 {
|
||||
availableURLs = antigravity.BaseURLs // 所有 URL 都不可用时,重试所有
|
||||
}
|
||||
|
||||
// 重试循环
|
||||
var resp *http.Response
|
||||
for attempt := 1; attempt <= antigravityMaxRetries; attempt++ {
|
||||
// 检查 context 是否已取消(客户端断开连接)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Printf("%s status=context_canceled error=%v", prefix, ctx.Err())
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
urlFallbackLoop:
|
||||
for urlIdx, baseURL := range availableURLs {
|
||||
for attempt := 1; attempt <= antigravityMaxRetries; attempt++ {
|
||||
// 检查 context 是否已取消(客户端断开连接)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Printf("%s status=context_canceled error=%v", prefix, ctx.Err())
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
upstreamReq, err := antigravity.NewAPIRequest(ctx, action, accessToken, geminiBody)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
upstreamReq, err := antigravity.NewAPIRequestWithURL(ctx, baseURL, action, accessToken, geminiBody)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
||||
if err != nil {
|
||||
if attempt < antigravityMaxRetries {
|
||||
log.Printf("%s status=request_failed retry=%d/%d error=%v", prefix, attempt, antigravityMaxRetries, err)
|
||||
if !sleepAntigravityBackoffWithContext(ctx, attempt) {
|
||||
log.Printf("%s status=context_canceled_during_backoff", prefix)
|
||||
return nil, ctx.Err()
|
||||
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
||||
if err != nil {
|
||||
// 检查是否应触发 URL 降级
|
||||
if shouldAntigravityFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 {
|
||||
antigravity.DefaultURLAvailability.MarkUnavailable(baseURL)
|
||||
log.Printf("%s URL fallback (connection error): %s -> %s", prefix, baseURL, availableURLs[urlIdx+1])
|
||||
continue urlFallbackLoop
|
||||
}
|
||||
continue
|
||||
}
|
||||
log.Printf("%s status=request_failed retries_exhausted error=%v", prefix, err)
|
||||
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries")
|
||||
}
|
||||
|
||||
if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(resp.StatusCode) {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
_ = resp.Body.Close()
|
||||
|
||||
if attempt < antigravityMaxRetries {
|
||||
log.Printf("%s status=%d retry=%d/%d body=%s", prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 500))
|
||||
if !sleepAntigravityBackoffWithContext(ctx, attempt) {
|
||||
log.Printf("%s status=context_canceled_during_backoff", prefix)
|
||||
return nil, ctx.Err()
|
||||
if attempt < antigravityMaxRetries {
|
||||
log.Printf("%s status=request_failed retry=%d/%d error=%v", prefix, attempt, antigravityMaxRetries, err)
|
||||
if !sleepAntigravityBackoffWithContext(ctx, attempt) {
|
||||
log.Printf("%s status=context_canceled_during_backoff", prefix)
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
continue
|
||||
}
|
||||
continue
|
||||
log.Printf("%s status=request_failed retries_exhausted error=%v", prefix, err)
|
||||
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries")
|
||||
}
|
||||
// 所有重试都失败,标记限流状态
|
||||
if resp.StatusCode == 429 {
|
||||
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody)
|
||||
}
|
||||
// 最后一次尝试也失败
|
||||
resp = &http.Response{
|
||||
StatusCode: resp.StatusCode,
|
||||
Header: resp.Header.Clone(),
|
||||
Body: io.NopCloser(bytes.NewReader(respBody)),
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
break
|
||||
// 检查是否应触发 URL 降级(仅 429)
|
||||
if resp.StatusCode == http.StatusTooManyRequests && urlIdx < len(availableURLs)-1 {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
_ = resp.Body.Close()
|
||||
antigravity.DefaultURLAvailability.MarkUnavailable(baseURL)
|
||||
log.Printf("%s URL fallback (HTTP 429): %s -> %s body=%s", prefix, baseURL, availableURLs[urlIdx+1], truncateForLog(respBody, 200))
|
||||
continue urlFallbackLoop
|
||||
}
|
||||
|
||||
if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(resp.StatusCode) {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
_ = resp.Body.Close()
|
||||
|
||||
if attempt < antigravityMaxRetries {
|
||||
log.Printf("%s status=%d retry=%d/%d body=%s", prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 500))
|
||||
if !sleepAntigravityBackoffWithContext(ctx, attempt) {
|
||||
log.Printf("%s status=context_canceled_during_backoff", prefix)
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
continue
|
||||
}
|
||||
// 所有重试都失败,标记限流状态
|
||||
if resp.StatusCode == 429 {
|
||||
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody)
|
||||
}
|
||||
// 最后一次尝试也失败
|
||||
resp = &http.Response{
|
||||
StatusCode: resp.StatusCode,
|
||||
Header: resp.Header.Clone(),
|
||||
Body: io.NopCloser(bytes.NewReader(respBody)),
|
||||
}
|
||||
break urlFallbackLoop
|
||||
}
|
||||
|
||||
break urlFallbackLoop
|
||||
}
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
@@ -1003,61 +1079,85 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
|
||||
// 如果客户端请求非流式,在响应处理阶段会收集完整流式响应后返回
|
||||
upstreamAction := "streamGenerateContent"
|
||||
|
||||
// URL fallback 循环
|
||||
availableURLs := antigravity.DefaultURLAvailability.GetAvailableURLs()
|
||||
if len(availableURLs) == 0 {
|
||||
availableURLs = antigravity.BaseURLs // 所有 URL 都不可用时,重试所有
|
||||
}
|
||||
|
||||
// 重试循环
|
||||
var resp *http.Response
|
||||
for attempt := 1; attempt <= antigravityMaxRetries; attempt++ {
|
||||
// 检查 context 是否已取消(客户端断开连接)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Printf("%s status=context_canceled error=%v", prefix, ctx.Err())
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
urlFallbackLoop:
|
||||
for urlIdx, baseURL := range availableURLs {
|
||||
for attempt := 1; attempt <= antigravityMaxRetries; attempt++ {
|
||||
// 检查 context 是否已取消(客户端断开连接)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Printf("%s status=context_canceled error=%v", prefix, ctx.Err())
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
upstreamReq, err := antigravity.NewAPIRequest(ctx, upstreamAction, accessToken, wrappedBody)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
upstreamReq, err := antigravity.NewAPIRequestWithURL(ctx, baseURL, upstreamAction, accessToken, wrappedBody)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
||||
if err != nil {
|
||||
if attempt < antigravityMaxRetries {
|
||||
log.Printf("%s status=request_failed retry=%d/%d error=%v", prefix, attempt, antigravityMaxRetries, err)
|
||||
if !sleepAntigravityBackoffWithContext(ctx, attempt) {
|
||||
log.Printf("%s status=context_canceled_during_backoff", prefix)
|
||||
return nil, ctx.Err()
|
||||
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
||||
if err != nil {
|
||||
// 检查是否应触发 URL 降级
|
||||
if shouldAntigravityFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 {
|
||||
antigravity.DefaultURLAvailability.MarkUnavailable(baseURL)
|
||||
log.Printf("%s URL fallback (connection error): %s -> %s", prefix, baseURL, availableURLs[urlIdx+1])
|
||||
continue urlFallbackLoop
|
||||
}
|
||||
continue
|
||||
}
|
||||
log.Printf("%s status=request_failed retries_exhausted error=%v", prefix, err)
|
||||
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries")
|
||||
}
|
||||
|
||||
if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(resp.StatusCode) {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
_ = resp.Body.Close()
|
||||
|
||||
if attempt < antigravityMaxRetries {
|
||||
log.Printf("%s status=%d retry=%d/%d", prefix, resp.StatusCode, attempt, antigravityMaxRetries)
|
||||
if !sleepAntigravityBackoffWithContext(ctx, attempt) {
|
||||
log.Printf("%s status=context_canceled_during_backoff", prefix)
|
||||
return nil, ctx.Err()
|
||||
if attempt < antigravityMaxRetries {
|
||||
log.Printf("%s status=request_failed retry=%d/%d error=%v", prefix, attempt, antigravityMaxRetries, err)
|
||||
if !sleepAntigravityBackoffWithContext(ctx, attempt) {
|
||||
log.Printf("%s status=context_canceled_during_backoff", prefix)
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
continue
|
||||
}
|
||||
continue
|
||||
log.Printf("%s status=request_failed retries_exhausted error=%v", prefix, err)
|
||||
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries")
|
||||
}
|
||||
// 所有重试都失败,标记限流状态
|
||||
if resp.StatusCode == 429 {
|
||||
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody)
|
||||
}
|
||||
resp = &http.Response{
|
||||
StatusCode: resp.StatusCode,
|
||||
Header: resp.Header.Clone(),
|
||||
Body: io.NopCloser(bytes.NewReader(respBody)),
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
break
|
||||
// 检查是否应触发 URL 降级(仅 429)
|
||||
if resp.StatusCode == http.StatusTooManyRequests && urlIdx < len(availableURLs)-1 {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
_ = resp.Body.Close()
|
||||
antigravity.DefaultURLAvailability.MarkUnavailable(baseURL)
|
||||
log.Printf("%s URL fallback (HTTP 429): %s -> %s body=%s", prefix, baseURL, availableURLs[urlIdx+1], truncateForLog(respBody, 200))
|
||||
continue urlFallbackLoop
|
||||
}
|
||||
|
||||
if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(resp.StatusCode) {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
_ = resp.Body.Close()
|
||||
|
||||
if attempt < antigravityMaxRetries {
|
||||
log.Printf("%s status=%d retry=%d/%d", prefix, resp.StatusCode, attempt, antigravityMaxRetries)
|
||||
if !sleepAntigravityBackoffWithContext(ctx, attempt) {
|
||||
log.Printf("%s status=context_canceled_during_backoff", prefix)
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
continue
|
||||
}
|
||||
// 所有重试都失败,标记限流状态
|
||||
if resp.StatusCode == 429 {
|
||||
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody)
|
||||
}
|
||||
resp = &http.Response{
|
||||
StatusCode: resp.StatusCode,
|
||||
Header: resp.Header.Clone(),
|
||||
Body: io.NopCloser(bytes.NewReader(respBody)),
|
||||
}
|
||||
break urlFallbackLoop
|
||||
}
|
||||
|
||||
break urlFallbackLoop
|
||||
}
|
||||
}
|
||||
defer func() {
|
||||
if resp != nil && resp.Body != nil {
|
||||
|
||||
@@ -2,9 +2,13 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/mail"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
@@ -18,6 +22,7 @@ var (
|
||||
ErrInvalidCredentials = infraerrors.Unauthorized("INVALID_CREDENTIALS", "invalid email or password")
|
||||
ErrUserNotActive = infraerrors.Forbidden("USER_NOT_ACTIVE", "user is not active")
|
||||
ErrEmailExists = infraerrors.Conflict("EMAIL_EXISTS", "email already exists")
|
||||
ErrEmailReserved = infraerrors.BadRequest("EMAIL_RESERVED", "email is reserved")
|
||||
ErrInvalidToken = infraerrors.Unauthorized("INVALID_TOKEN", "invalid token")
|
||||
ErrTokenExpired = infraerrors.Unauthorized("TOKEN_EXPIRED", "token has expired")
|
||||
ErrTokenTooLarge = infraerrors.BadRequest("TOKEN_TOO_LARGE", "token too large")
|
||||
@@ -75,21 +80,30 @@ func (s *AuthService) Register(ctx context.Context, email, password string) (str
|
||||
|
||||
// RegisterWithVerification 用户注册(支持邮件验证),返回token和用户
|
||||
func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode string) (string, *User, error) {
|
||||
// 检查是否开放注册
|
||||
if s.settingService != nil && !s.settingService.IsRegistrationEnabled(ctx) {
|
||||
// 检查是否开放注册(默认关闭:settingService 未配置时不允许注册)
|
||||
if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
|
||||
return "", nil, ErrRegDisabled
|
||||
}
|
||||
|
||||
// 防止用户注册 LinuxDo OAuth 合成邮箱,避免第三方登录与本地账号发生碰撞。
|
||||
if isReservedEmail(email) {
|
||||
return "", nil, ErrEmailReserved
|
||||
}
|
||||
|
||||
// 检查是否需要邮件验证
|
||||
if s.settingService != nil && s.settingService.IsEmailVerifyEnabled(ctx) {
|
||||
// 如果邮件验证已开启但邮件服务未配置,拒绝注册
|
||||
// 这是一个配置错误,不应该允许绕过验证
|
||||
if s.emailService == nil {
|
||||
log.Println("[Auth] Email verification enabled but email service not configured, rejecting registration")
|
||||
return "", nil, ErrServiceUnavailable
|
||||
}
|
||||
if verifyCode == "" {
|
||||
return "", nil, ErrEmailVerifyRequired
|
||||
}
|
||||
// 验证邮箱验证码
|
||||
if s.emailService != nil {
|
||||
if err := s.emailService.VerifyCode(ctx, email, verifyCode); err != nil {
|
||||
return "", nil, fmt.Errorf("verify code: %w", err)
|
||||
}
|
||||
if err := s.emailService.VerifyCode(ctx, email, verifyCode); err != nil {
|
||||
return "", nil, fmt.Errorf("verify code: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -128,6 +142,10 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
|
||||
}
|
||||
|
||||
if err := s.userRepo.Create(ctx, user); err != nil {
|
||||
// 优先检查邮箱冲突错误(竞态条件下可能发生)
|
||||
if errors.Is(err, ErrEmailExists) {
|
||||
return "", nil, ErrEmailExists
|
||||
}
|
||||
log.Printf("[Auth] Database error creating user: %v", err)
|
||||
return "", nil, ErrServiceUnavailable
|
||||
}
|
||||
@@ -148,11 +166,15 @@ type SendVerifyCodeResult struct {
|
||||
|
||||
// SendVerifyCode 发送邮箱验证码(同步方式)
|
||||
func (s *AuthService) SendVerifyCode(ctx context.Context, email string) error {
|
||||
// 检查是否开放注册
|
||||
if s.settingService != nil && !s.settingService.IsRegistrationEnabled(ctx) {
|
||||
// 检查是否开放注册(默认关闭)
|
||||
if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
|
||||
return ErrRegDisabled
|
||||
}
|
||||
|
||||
if isReservedEmail(email) {
|
||||
return ErrEmailReserved
|
||||
}
|
||||
|
||||
// 检查邮箱是否已存在
|
||||
existsEmail, err := s.userRepo.ExistsByEmail(ctx, email)
|
||||
if err != nil {
|
||||
@@ -181,12 +203,16 @@ func (s *AuthService) SendVerifyCode(ctx context.Context, email string) error {
|
||||
func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*SendVerifyCodeResult, error) {
|
||||
log.Printf("[Auth] SendVerifyCodeAsync called for email: %s", email)
|
||||
|
||||
// 检查是否开放注册
|
||||
if s.settingService != nil && !s.settingService.IsRegistrationEnabled(ctx) {
|
||||
// 检查是否开放注册(默认关闭)
|
||||
if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
|
||||
log.Println("[Auth] Registration is disabled")
|
||||
return nil, ErrRegDisabled
|
||||
}
|
||||
|
||||
if isReservedEmail(email) {
|
||||
return nil, ErrEmailReserved
|
||||
}
|
||||
|
||||
// 检查邮箱是否已存在
|
||||
existsEmail, err := s.userRepo.ExistsByEmail(ctx, email)
|
||||
if err != nil {
|
||||
@@ -266,7 +292,7 @@ func (s *AuthService) IsTurnstileEnabled(ctx context.Context) bool {
|
||||
// IsRegistrationEnabled 检查是否开放注册
|
||||
func (s *AuthService) IsRegistrationEnabled(ctx context.Context) bool {
|
||||
if s.settingService == nil {
|
||||
return true
|
||||
return false // 安全默认:settingService 未配置时关闭注册
|
||||
}
|
||||
return s.settingService.IsRegistrationEnabled(ctx)
|
||||
}
|
||||
@@ -311,6 +337,102 @@ func (s *AuthService) Login(ctx context.Context, email, password string) (string
|
||||
return token, user, nil
|
||||
}
|
||||
|
||||
// LoginOrRegisterOAuth 用于第三方 OAuth/SSO 登录:
|
||||
// - 如果邮箱已存在:直接登录(不需要本地密码)
|
||||
// - 如果邮箱不存在:创建新用户并登录
|
||||
//
|
||||
// 注意:该函数用于“终端用户登录 Sub2API 本身”的场景(不同于上游账号的 OAuth,例如 OpenAI/Gemini)。
|
||||
// 为了满足现有数据库约束(需要密码哈希),新用户会生成随机密码并进行哈希保存。
|
||||
func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username string) (string, *User, error) {
|
||||
email = strings.TrimSpace(email)
|
||||
if email == "" || len(email) > 255 {
|
||||
return "", nil, infraerrors.BadRequest("INVALID_EMAIL", "invalid email")
|
||||
}
|
||||
if _, err := mail.ParseAddress(email); err != nil {
|
||||
return "", nil, infraerrors.BadRequest("INVALID_EMAIL", "invalid email")
|
||||
}
|
||||
|
||||
username = strings.TrimSpace(username)
|
||||
if len([]rune(username)) > 100 {
|
||||
username = string([]rune(username)[:100])
|
||||
}
|
||||
|
||||
user, err := s.userRepo.GetByEmail(ctx, email)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrUserNotFound) {
|
||||
// OAuth 首次登录视为注册。
|
||||
if s.settingService != nil && !s.settingService.IsRegistrationEnabled(ctx) {
|
||||
return "", nil, ErrRegDisabled
|
||||
}
|
||||
|
||||
randomPassword, err := randomHexString(32)
|
||||
if err != nil {
|
||||
log.Printf("[Auth] Failed to generate random password for oauth signup: %v", err)
|
||||
return "", nil, ErrServiceUnavailable
|
||||
}
|
||||
hashedPassword, err := s.HashPassword(randomPassword)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("hash password: %w", err)
|
||||
}
|
||||
|
||||
// 新用户默认值。
|
||||
defaultBalance := s.cfg.Default.UserBalance
|
||||
defaultConcurrency := s.cfg.Default.UserConcurrency
|
||||
if s.settingService != nil {
|
||||
defaultBalance = s.settingService.GetDefaultBalance(ctx)
|
||||
defaultConcurrency = s.settingService.GetDefaultConcurrency(ctx)
|
||||
}
|
||||
|
||||
newUser := &User{
|
||||
Email: email,
|
||||
Username: username,
|
||||
PasswordHash: hashedPassword,
|
||||
Role: RoleUser,
|
||||
Balance: defaultBalance,
|
||||
Concurrency: defaultConcurrency,
|
||||
Status: StatusActive,
|
||||
}
|
||||
|
||||
if err := s.userRepo.Create(ctx, newUser); err != nil {
|
||||
if errors.Is(err, ErrEmailExists) {
|
||||
// 并发场景:GetByEmail 与 Create 之间用户被创建。
|
||||
user, err = s.userRepo.GetByEmail(ctx, email)
|
||||
if err != nil {
|
||||
log.Printf("[Auth] Database error getting user after conflict: %v", err)
|
||||
return "", nil, ErrServiceUnavailable
|
||||
}
|
||||
} else {
|
||||
log.Printf("[Auth] Database error creating oauth user: %v", err)
|
||||
return "", nil, ErrServiceUnavailable
|
||||
}
|
||||
} else {
|
||||
user = newUser
|
||||
}
|
||||
} else {
|
||||
log.Printf("[Auth] Database error during oauth login: %v", err)
|
||||
return "", nil, ErrServiceUnavailable
|
||||
}
|
||||
}
|
||||
|
||||
if !user.IsActive() {
|
||||
return "", nil, ErrUserNotActive
|
||||
}
|
||||
|
||||
// 尽力补全:当用户名为空时,使用第三方返回的用户名回填。
|
||||
if user.Username == "" && username != "" {
|
||||
user.Username = username
|
||||
if err := s.userRepo.Update(ctx, user); err != nil {
|
||||
log.Printf("[Auth] Failed to update username after oauth login: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
token, err := s.GenerateToken(user)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("generate token: %w", err)
|
||||
}
|
||||
return token, user, nil
|
||||
}
|
||||
|
||||
// ValidateToken 验证JWT token并返回用户声明
|
||||
func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) {
|
||||
// 先做长度校验,尽早拒绝异常超长 token,降低 DoS 风险。
|
||||
@@ -336,6 +458,11 @@ func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) {
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, jwt.ErrTokenExpired) {
|
||||
// token 过期但仍返回 claims(用于 RefreshToken 等场景)
|
||||
// jwt-go 在解析时即使遇到过期错误,token.Claims 仍会被填充
|
||||
if claims, ok := token.Claims.(*JWTClaims); ok {
|
||||
return claims, ErrTokenExpired
|
||||
}
|
||||
return nil, ErrTokenExpired
|
||||
}
|
||||
return nil, ErrInvalidToken
|
||||
@@ -348,6 +475,22 @@ func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) {
|
||||
return nil, ErrInvalidToken
|
||||
}
|
||||
|
||||
func randomHexString(byteLength int) (string, error) {
|
||||
if byteLength <= 0 {
|
||||
byteLength = 16
|
||||
}
|
||||
buf := make([]byte, byteLength)
|
||||
if _, err := rand.Read(buf); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(buf), nil
|
||||
}
|
||||
|
||||
func isReservedEmail(email string) bool {
|
||||
normalized := strings.ToLower(strings.TrimSpace(email))
|
||||
return strings.HasSuffix(normalized, LinuxDoConnectSyntheticEmailDomain)
|
||||
}
|
||||
|
||||
// GenerateToken 生成JWT token
|
||||
func (s *AuthService) GenerateToken(user *User) (string, error) {
|
||||
now := time.Now()
|
||||
|
||||
@@ -113,13 +113,36 @@ func TestAuthService_Register_Disabled(t *testing.T) {
|
||||
require.ErrorIs(t, err, ErrRegDisabled)
|
||||
}
|
||||
|
||||
func TestAuthService_Register_EmailVerifyRequired(t *testing.T) {
|
||||
func TestAuthService_Register_DisabledByDefault(t *testing.T) {
|
||||
// 当 settings 为 nil(设置项不存在)时,注册应该默认关闭
|
||||
repo := &userRepoStub{}
|
||||
service := newAuthService(repo, nil, nil)
|
||||
|
||||
_, _, err := service.Register(context.Background(), "user@test.com", "password")
|
||||
require.ErrorIs(t, err, ErrRegDisabled)
|
||||
}
|
||||
|
||||
func TestAuthService_Register_EmailVerifyEnabledButServiceNotConfigured(t *testing.T) {
|
||||
repo := &userRepoStub{}
|
||||
// 邮件验证开启但 emailCache 为 nil(emailService 未配置)
|
||||
service := newAuthService(repo, map[string]string{
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
SettingKeyEmailVerifyEnabled: "true",
|
||||
}, nil)
|
||||
|
||||
// 应返回服务不可用错误,而不是允许绕过验证
|
||||
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "any-code")
|
||||
require.ErrorIs(t, err, ErrServiceUnavailable)
|
||||
}
|
||||
|
||||
func TestAuthService_Register_EmailVerifyRequired(t *testing.T) {
|
||||
repo := &userRepoStub{}
|
||||
cache := &emailCacheStub{} // 配置 emailService
|
||||
service := newAuthService(repo, map[string]string{
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
SettingKeyEmailVerifyEnabled: "true",
|
||||
}, cache)
|
||||
|
||||
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "")
|
||||
require.ErrorIs(t, err, ErrEmailVerifyRequired)
|
||||
}
|
||||
@@ -141,7 +164,9 @@ func TestAuthService_Register_EmailVerifyInvalid(t *testing.T) {
|
||||
|
||||
func TestAuthService_Register_EmailExists(t *testing.T) {
|
||||
repo := &userRepoStub{exists: true}
|
||||
service := newAuthService(repo, nil, nil)
|
||||
service := newAuthService(repo, map[string]string{
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
}, nil)
|
||||
|
||||
_, _, err := service.Register(context.Background(), "user@test.com", "password")
|
||||
require.ErrorIs(t, err, ErrEmailExists)
|
||||
@@ -149,23 +174,50 @@ func TestAuthService_Register_EmailExists(t *testing.T) {
|
||||
|
||||
func TestAuthService_Register_CheckEmailError(t *testing.T) {
|
||||
repo := &userRepoStub{existsErr: errors.New("db down")}
|
||||
service := newAuthService(repo, nil, nil)
|
||||
service := newAuthService(repo, map[string]string{
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
}, nil)
|
||||
|
||||
_, _, err := service.Register(context.Background(), "user@test.com", "password")
|
||||
require.ErrorIs(t, err, ErrServiceUnavailable)
|
||||
}
|
||||
|
||||
func TestAuthService_Register_ReservedEmail(t *testing.T) {
|
||||
repo := &userRepoStub{}
|
||||
service := newAuthService(repo, map[string]string{
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
}, nil)
|
||||
|
||||
_, _, err := service.Register(context.Background(), "linuxdo-123@linuxdo-connect.invalid", "password")
|
||||
require.ErrorIs(t, err, ErrEmailReserved)
|
||||
}
|
||||
|
||||
func TestAuthService_Register_CreateError(t *testing.T) {
|
||||
repo := &userRepoStub{createErr: errors.New("create failed")}
|
||||
service := newAuthService(repo, nil, nil)
|
||||
service := newAuthService(repo, map[string]string{
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
}, nil)
|
||||
|
||||
_, _, err := service.Register(context.Background(), "user@test.com", "password")
|
||||
require.ErrorIs(t, err, ErrServiceUnavailable)
|
||||
}
|
||||
|
||||
func TestAuthService_Register_CreateEmailExistsRace(t *testing.T) {
|
||||
// 模拟竞态条件:ExistsByEmail 返回 false,但 Create 时因唯一约束失败
|
||||
repo := &userRepoStub{createErr: ErrEmailExists}
|
||||
service := newAuthService(repo, map[string]string{
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
}, nil)
|
||||
|
||||
_, _, err := service.Register(context.Background(), "user@test.com", "password")
|
||||
require.ErrorIs(t, err, ErrEmailExists)
|
||||
}
|
||||
|
||||
func TestAuthService_Register_Success(t *testing.T) {
|
||||
repo := &userRepoStub{nextID: 5}
|
||||
service := newAuthService(repo, nil, nil)
|
||||
service := newAuthService(repo, map[string]string{
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
}, nil)
|
||||
|
||||
token, user, err := service.Register(context.Background(), "user@test.com", "password")
|
||||
require.NoError(t, err)
|
||||
@@ -180,3 +232,63 @@ func TestAuthService_Register_Success(t *testing.T) {
|
||||
require.Len(t, repo.created, 1)
|
||||
require.True(t, user.CheckPassword("password"))
|
||||
}
|
||||
|
||||
func TestAuthService_ValidateToken_ExpiredReturnsClaimsWithError(t *testing.T) {
|
||||
repo := &userRepoStub{}
|
||||
service := newAuthService(repo, nil, nil)
|
||||
|
||||
// 创建用户并生成 token
|
||||
user := &User{
|
||||
ID: 1,
|
||||
Email: "test@test.com",
|
||||
Role: RoleUser,
|
||||
Status: StatusActive,
|
||||
TokenVersion: 1,
|
||||
}
|
||||
token, err := service.GenerateToken(user)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 验证有效 token
|
||||
claims, err := service.ValidateToken(token)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, claims)
|
||||
require.Equal(t, int64(1), claims.UserID)
|
||||
|
||||
// 模拟过期 token(通过创建一个过期很久的 token)
|
||||
service.cfg.JWT.ExpireHour = -1 // 设置为负数使 token 立即过期
|
||||
expiredToken, err := service.GenerateToken(user)
|
||||
require.NoError(t, err)
|
||||
service.cfg.JWT.ExpireHour = 1 // 恢复
|
||||
|
||||
// 验证过期 token 应返回 claims 和 ErrTokenExpired
|
||||
claims, err = service.ValidateToken(expiredToken)
|
||||
require.ErrorIs(t, err, ErrTokenExpired)
|
||||
require.NotNil(t, claims, "claims should not be nil when token is expired")
|
||||
require.Equal(t, int64(1), claims.UserID)
|
||||
require.Equal(t, "test@test.com", claims.Email)
|
||||
}
|
||||
|
||||
func TestAuthService_RefreshToken_ExpiredTokenNoPanic(t *testing.T) {
|
||||
user := &User{
|
||||
ID: 1,
|
||||
Email: "test@test.com",
|
||||
Role: RoleUser,
|
||||
Status: StatusActive,
|
||||
TokenVersion: 1,
|
||||
}
|
||||
repo := &userRepoStub{user: user}
|
||||
service := newAuthService(repo, nil, nil)
|
||||
|
||||
// 创建过期 token
|
||||
service.cfg.JWT.ExpireHour = -1
|
||||
expiredToken, err := service.GenerateToken(user)
|
||||
require.NoError(t, err)
|
||||
service.cfg.JWT.ExpireHour = 1
|
||||
|
||||
// RefreshToken 使用过期 token 不应 panic
|
||||
require.NotPanics(t, func() {
|
||||
newToken, err := service.RefreshToken(context.Background(), expiredToken)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, newToken)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -105,7 +105,17 @@ const (
|
||||
// Request identity patch (Claude -> Gemini systemInstruction injection)
|
||||
SettingKeyEnableIdentityPatch = "enable_identity_patch"
|
||||
SettingKeyIdentityPatchPrompt = "identity_patch_prompt"
|
||||
|
||||
// LinuxDo Connect OAuth 登录(终端用户 SSO)
|
||||
SettingKeyLinuxDoConnectEnabled = "linuxdo_connect_enabled"
|
||||
SettingKeyLinuxDoConnectClientID = "linuxdo_connect_client_id"
|
||||
SettingKeyLinuxDoConnectClientSecret = "linuxdo_connect_client_secret"
|
||||
SettingKeyLinuxDoConnectRedirectURL = "linuxdo_connect_redirect_url"
|
||||
)
|
||||
|
||||
// LinuxDoConnectSyntheticEmailDomain 是 LinuxDo Connect 用户的合成邮箱后缀(RFC 保留域名)。
|
||||
// 目的:避免第三方登录返回的用户标识与本地真实邮箱发生碰撞,进而造成账号被接管的风险。
|
||||
const LinuxDoConnectSyntheticEmailDomain = "@linuxdo-connect.invalid"
|
||||
|
||||
// AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys).
|
||||
const AdminAPIKeyPrefix = "admin-"
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"log"
|
||||
"math/big"
|
||||
"net/smtp"
|
||||
"strconv"
|
||||
@@ -256,7 +257,9 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error
|
||||
// 验证码不匹配
|
||||
if data.Code != code {
|
||||
data.Attempts++
|
||||
_ = s.cache.SetVerificationCode(ctx, email, data, verifyCodeTTL)
|
||||
if err := s.cache.SetVerificationCode(ctx, email, data, verifyCodeTTL); err != nil {
|
||||
log.Printf("[Email] Failed to update verification attempt count: %v", err)
|
||||
}
|
||||
if data.Attempts >= maxVerifyCodeAttempts {
|
||||
return ErrVerifyCodeMaxAttempts
|
||||
}
|
||||
@@ -264,7 +267,9 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error
|
||||
}
|
||||
|
||||
// 验证成功,删除验证码
|
||||
_ = s.cache.DeleteVerificationCode(ctx, email)
|
||||
if err := s.cache.DeleteVerificationCode(ctx, email); err != nil {
|
||||
log.Printf("[Email] Failed to delete verification code after success: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -166,7 +166,7 @@ func (m *mockGroupRepoForGemini) DeleteCascade(ctx context.Context, id int64) ([
|
||||
func (m *mockGroupRepoForGemini) List(ctx context.Context, params pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (m *mockGroupRepoForGemini) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error) {
|
||||
func (m *mockGroupRepoForGemini) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (m *mockGroupRepoForGemini) ListActive(ctx context.Context) ([]Group, error) { return nil, nil }
|
||||
|
||||
@@ -120,15 +120,16 @@ func (s *GeminiOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64
|
||||
}
|
||||
|
||||
// OAuth client selection:
|
||||
// - code_assist: always use built-in Gemini CLI OAuth client (public), regardless of configured client_id/secret.
|
||||
// - google_one: uses configured OAuth client when provided; otherwise falls back to built-in client.
|
||||
// - ai_studio: requires a user-provided OAuth client.
|
||||
// - code_assist: always use built-in Gemini CLI OAuth client (public)
|
||||
// - google_one: always use built-in Gemini CLI OAuth client (public)
|
||||
// - ai_studio: requires a user-provided OAuth client
|
||||
oauthCfg := geminicli.OAuthConfig{
|
||||
ClientID: s.cfg.Gemini.OAuth.ClientID,
|
||||
ClientSecret: s.cfg.Gemini.OAuth.ClientSecret,
|
||||
Scopes: s.cfg.Gemini.OAuth.Scopes,
|
||||
}
|
||||
if oauthType == "code_assist" {
|
||||
if oauthType == "code_assist" || oauthType == "google_one" {
|
||||
// Force use of built-in Gemini CLI OAuth client
|
||||
oauthCfg.ClientID = ""
|
||||
oauthCfg.ClientSecret = ""
|
||||
}
|
||||
@@ -576,6 +577,20 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
|
||||
|
||||
case "google_one":
|
||||
log.Printf("[GeminiOAuth] Processing google_one OAuth type")
|
||||
|
||||
// Google One accounts use cloudaicompanion API, which requires a project_id.
|
||||
// For personal accounts, Google auto-assigns a project_id via the LoadCodeAssist API.
|
||||
if projectID == "" {
|
||||
log.Printf("[GeminiOAuth] No project_id provided, attempting to fetch from LoadCodeAssist API...")
|
||||
var err error
|
||||
projectID, _, err = s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL)
|
||||
if err != nil {
|
||||
log.Printf("[GeminiOAuth] ERROR: Failed to fetch project_id: %v", err)
|
||||
return nil, fmt.Errorf("google One accounts require a project_id, failed to auto-detect: %w", err)
|
||||
}
|
||||
log.Printf("[GeminiOAuth] Successfully fetched project_id: %s", projectID)
|
||||
}
|
||||
|
||||
log.Printf("[GeminiOAuth] Attempting to fetch Google One tier from Drive API...")
|
||||
// Attempt to fetch Drive storage tier
|
||||
var storageInfo *geminicli.DriveStorageInfo
|
||||
|
||||
@@ -40,7 +40,7 @@ func TestGeminiOAuthService_GenerateAuthURL_RedirectURIStrategy(t *testing.T) {
|
||||
wantProjectID: "",
|
||||
},
|
||||
{
|
||||
name: "google_one uses custom client when configured and redirects to localhost",
|
||||
name: "google_one always forces built-in client even when custom client configured",
|
||||
cfg: &config.Config{
|
||||
Gemini: config.GeminiConfig{
|
||||
OAuth: config.GeminiOAuthConfig{
|
||||
@@ -50,9 +50,9 @@ func TestGeminiOAuthService_GenerateAuthURL_RedirectURIStrategy(t *testing.T) {
|
||||
},
|
||||
},
|
||||
oauthType: "google_one",
|
||||
wantClientID: "custom-client-id",
|
||||
wantRedirect: geminicli.AIStudioOAuthRedirectURI,
|
||||
wantScope: geminicli.DefaultGoogleOneScopes,
|
||||
wantClientID: geminicli.GeminiCLIOAuthClientID,
|
||||
wantRedirect: geminicli.GeminiCLIRedirectURI,
|
||||
wantScope: geminicli.DefaultCodeAssistScopes,
|
||||
wantProjectID: "",
|
||||
},
|
||||
{
|
||||
|
||||
@@ -21,7 +21,7 @@ type GroupRepository interface {
|
||||
DeleteCascade(ctx context.Context, id int64) ([]int64, error)
|
||||
|
||||
List(ctx context.Context, params pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error)
|
||||
ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error)
|
||||
ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error)
|
||||
ListActive(ctx context.Context) ([]Group, error)
|
||||
ListActiveByPlatform(ctx context.Context, platform string) ([]Group, error)
|
||||
|
||||
|
||||
@@ -540,10 +540,19 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
bodyModified = true
|
||||
}
|
||||
|
||||
// For OAuth accounts using ChatGPT internal API, add store: false
|
||||
// For OAuth accounts using ChatGPT internal API:
|
||||
// 1. Add store: false
|
||||
// 2. Normalize input format for Codex API compatibility
|
||||
if account.Type == AccountTypeOAuth {
|
||||
reqBody["store"] = false
|
||||
bodyModified = true
|
||||
|
||||
// Normalize input format: convert AI SDK multi-part content format to simplified format
|
||||
// AI SDK sends: {"content": [{"type": "input_text", "text": "..."}]}
|
||||
// Codex API expects: {"content": "..."}
|
||||
if normalizeInputForCodexAPI(reqBody) {
|
||||
bodyModified = true
|
||||
}
|
||||
}
|
||||
|
||||
// Re-serialize body only if modified
|
||||
@@ -1085,6 +1094,101 @@ func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel
|
||||
return newBody
|
||||
}
|
||||
|
||||
// normalizeInputForCodexAPI converts AI SDK multi-part content format to simplified format
|
||||
// that the ChatGPT internal Codex API expects.
|
||||
//
|
||||
// AI SDK sends content as an array of typed objects:
|
||||
//
|
||||
// {"content": [{"type": "input_text", "text": "hello"}]}
|
||||
//
|
||||
// ChatGPT Codex API expects content as a simple string:
|
||||
//
|
||||
// {"content": "hello"}
|
||||
//
|
||||
// This function modifies reqBody in-place and returns true if any modification was made.
|
||||
func normalizeInputForCodexAPI(reqBody map[string]any) bool {
|
||||
input, ok := reqBody["input"]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
// Handle case where input is a simple string (already compatible)
|
||||
if _, isString := input.(string); isString {
|
||||
return false
|
||||
}
|
||||
|
||||
// Handle case where input is an array of messages
|
||||
inputArray, ok := input.([]any)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
modified := false
|
||||
for _, item := range inputArray {
|
||||
message, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
content, ok := message["content"]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
// If content is already a string, no conversion needed
|
||||
if _, isString := content.(string); isString {
|
||||
continue
|
||||
}
|
||||
|
||||
// If content is an array (AI SDK format), convert to string
|
||||
contentArray, ok := content.([]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
// Extract text from content array
|
||||
var textParts []string
|
||||
for _, part := range contentArray {
|
||||
partMap, ok := part.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
// Handle different content types
|
||||
partType, _ := partMap["type"].(string)
|
||||
switch partType {
|
||||
case "input_text", "text":
|
||||
// Extract text from input_text or text type
|
||||
if text, ok := partMap["text"].(string); ok {
|
||||
textParts = append(textParts, text)
|
||||
}
|
||||
case "input_image", "image":
|
||||
// For images, we need to preserve the original format
|
||||
// as ChatGPT Codex API may support images in a different way
|
||||
// For now, skip image parts (they will be lost in conversion)
|
||||
// TODO: Consider preserving image data or handling it separately
|
||||
continue
|
||||
case "input_file", "file":
|
||||
// Similar to images, file inputs may need special handling
|
||||
continue
|
||||
default:
|
||||
// For unknown types, try to extract text if available
|
||||
if text, ok := partMap["text"].(string); ok {
|
||||
textParts = append(textParts, text)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Convert content array to string
|
||||
if len(textParts) > 0 {
|
||||
message["content"] = strings.Join(textParts, "\n")
|
||||
modified = true
|
||||
}
|
||||
}
|
||||
|
||||
return modified
|
||||
}
|
||||
|
||||
// OpenAIRecordUsageInput input for recording usage
|
||||
type OpenAIRecordUsageInput struct {
|
||||
Result *OpenAIForwardResult
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
@@ -64,6 +65,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
|
||||
SettingKeyAPIBaseURL,
|
||||
SettingKeyContactInfo,
|
||||
SettingKeyDocURL,
|
||||
SettingKeyLinuxDoConnectEnabled,
|
||||
}
|
||||
|
||||
settings, err := s.settingRepo.GetMultiple(ctx, keys)
|
||||
@@ -71,6 +73,13 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
|
||||
return nil, fmt.Errorf("get public settings: %w", err)
|
||||
}
|
||||
|
||||
linuxDoEnabled := false
|
||||
if raw, ok := settings[SettingKeyLinuxDoConnectEnabled]; ok {
|
||||
linuxDoEnabled = raw == "true"
|
||||
} else {
|
||||
linuxDoEnabled = s.cfg != nil && s.cfg.LinuxDo.Enabled
|
||||
}
|
||||
|
||||
return &PublicSettings{
|
||||
RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true",
|
||||
EmailVerifyEnabled: settings[SettingKeyEmailVerifyEnabled] == "true",
|
||||
@@ -82,6 +91,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
|
||||
APIBaseURL: settings[SettingKeyAPIBaseURL],
|
||||
ContactInfo: settings[SettingKeyContactInfo],
|
||||
DocURL: settings[SettingKeyDocURL],
|
||||
LinuxDoOAuthEnabled: linuxDoEnabled,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -111,6 +121,14 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
|
||||
updates[SettingKeyTurnstileSecretKey] = settings.TurnstileSecretKey
|
||||
}
|
||||
|
||||
// LinuxDo Connect OAuth 登录(终端用户 SSO)
|
||||
updates[SettingKeyLinuxDoConnectEnabled] = strconv.FormatBool(settings.LinuxDoConnectEnabled)
|
||||
updates[SettingKeyLinuxDoConnectClientID] = settings.LinuxDoConnectClientID
|
||||
updates[SettingKeyLinuxDoConnectRedirectURL] = settings.LinuxDoConnectRedirectURL
|
||||
if settings.LinuxDoConnectClientSecret != "" {
|
||||
updates[SettingKeyLinuxDoConnectClientSecret] = settings.LinuxDoConnectClientSecret
|
||||
}
|
||||
|
||||
// OEM设置
|
||||
updates[SettingKeySiteName] = settings.SiteName
|
||||
updates[SettingKeySiteLogo] = settings.SiteLogo
|
||||
@@ -141,8 +159,8 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
|
||||
func (s *SettingService) IsRegistrationEnabled(ctx context.Context) bool {
|
||||
value, err := s.settingRepo.GetValue(ctx, SettingKeyRegistrationEnabled)
|
||||
if err != nil {
|
||||
// 默认开放注册
|
||||
return true
|
||||
// 安全默认:如果设置不存在或查询出错,默认关闭注册
|
||||
return false
|
||||
}
|
||||
return value == "true"
|
||||
}
|
||||
@@ -271,6 +289,38 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
|
||||
result.SMTPPassword = settings[SettingKeySMTPPassword]
|
||||
result.TurnstileSecretKey = settings[SettingKeyTurnstileSecretKey]
|
||||
|
||||
// LinuxDo Connect 设置:
|
||||
// - 兼容 config.yaml/env(避免老部署因为未迁移到数据库设置而被意外关闭)
|
||||
// - 支持在后台“系统设置”中覆盖并持久化(存储于 DB)
|
||||
linuxDoBase := config.LinuxDoConnectConfig{}
|
||||
if s.cfg != nil {
|
||||
linuxDoBase = s.cfg.LinuxDo
|
||||
}
|
||||
|
||||
if raw, ok := settings[SettingKeyLinuxDoConnectEnabled]; ok {
|
||||
result.LinuxDoConnectEnabled = raw == "true"
|
||||
} else {
|
||||
result.LinuxDoConnectEnabled = linuxDoBase.Enabled
|
||||
}
|
||||
|
||||
if v, ok := settings[SettingKeyLinuxDoConnectClientID]; ok && strings.TrimSpace(v) != "" {
|
||||
result.LinuxDoConnectClientID = strings.TrimSpace(v)
|
||||
} else {
|
||||
result.LinuxDoConnectClientID = linuxDoBase.ClientID
|
||||
}
|
||||
|
||||
if v, ok := settings[SettingKeyLinuxDoConnectRedirectURL]; ok && strings.TrimSpace(v) != "" {
|
||||
result.LinuxDoConnectRedirectURL = strings.TrimSpace(v)
|
||||
} else {
|
||||
result.LinuxDoConnectRedirectURL = linuxDoBase.RedirectURL
|
||||
}
|
||||
|
||||
result.LinuxDoConnectClientSecret = strings.TrimSpace(settings[SettingKeyLinuxDoConnectClientSecret])
|
||||
if result.LinuxDoConnectClientSecret == "" {
|
||||
result.LinuxDoConnectClientSecret = strings.TrimSpace(linuxDoBase.ClientSecret)
|
||||
}
|
||||
result.LinuxDoConnectClientSecretConfigured = result.LinuxDoConnectClientSecret != ""
|
||||
|
||||
// Model fallback settings
|
||||
result.EnableModelFallback = settings[SettingKeyEnableModelFallback] == "true"
|
||||
result.FallbackModelAnthropic = s.getStringOrDefault(settings, SettingKeyFallbackModelAnthropic, "claude-3-5-sonnet-20241022")
|
||||
@@ -289,6 +339,99 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
|
||||
return result
|
||||
}
|
||||
|
||||
// GetLinuxDoConnectOAuthConfig 返回用于登录的“最终生效” LinuxDo Connect 配置。
|
||||
//
|
||||
// 优先级:
|
||||
// - 若对应系统设置键存在,则覆盖 config.yaml/env 的值
|
||||
// - 否则回退到 config.yaml/env 的值
|
||||
func (s *SettingService) GetLinuxDoConnectOAuthConfig(ctx context.Context) (config.LinuxDoConnectConfig, error) {
|
||||
if s == nil || s.cfg == nil {
|
||||
return config.LinuxDoConnectConfig{}, infraerrors.ServiceUnavailable("CONFIG_NOT_READY", "config not loaded")
|
||||
}
|
||||
|
||||
effective := s.cfg.LinuxDo
|
||||
|
||||
keys := []string{
|
||||
SettingKeyLinuxDoConnectEnabled,
|
||||
SettingKeyLinuxDoConnectClientID,
|
||||
SettingKeyLinuxDoConnectClientSecret,
|
||||
SettingKeyLinuxDoConnectRedirectURL,
|
||||
}
|
||||
settings, err := s.settingRepo.GetMultiple(ctx, keys)
|
||||
if err != nil {
|
||||
return config.LinuxDoConnectConfig{}, fmt.Errorf("get linuxdo connect settings: %w", err)
|
||||
}
|
||||
|
||||
if raw, ok := settings[SettingKeyLinuxDoConnectEnabled]; ok {
|
||||
effective.Enabled = raw == "true"
|
||||
}
|
||||
if v, ok := settings[SettingKeyLinuxDoConnectClientID]; ok && strings.TrimSpace(v) != "" {
|
||||
effective.ClientID = strings.TrimSpace(v)
|
||||
}
|
||||
if v, ok := settings[SettingKeyLinuxDoConnectClientSecret]; ok && strings.TrimSpace(v) != "" {
|
||||
effective.ClientSecret = strings.TrimSpace(v)
|
||||
}
|
||||
if v, ok := settings[SettingKeyLinuxDoConnectRedirectURL]; ok && strings.TrimSpace(v) != "" {
|
||||
effective.RedirectURL = strings.TrimSpace(v)
|
||||
}
|
||||
|
||||
if !effective.Enabled {
|
||||
return config.LinuxDoConnectConfig{}, infraerrors.NotFound("OAUTH_DISABLED", "oauth login is disabled")
|
||||
}
|
||||
|
||||
// 基础健壮性校验(避免把用户重定向到一个必然失败或不安全的 OAuth 流程里)。
|
||||
if strings.TrimSpace(effective.ClientID) == "" {
|
||||
return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth client id not configured")
|
||||
}
|
||||
if strings.TrimSpace(effective.AuthorizeURL) == "" {
|
||||
return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth authorize url not configured")
|
||||
}
|
||||
if strings.TrimSpace(effective.TokenURL) == "" {
|
||||
return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth token url not configured")
|
||||
}
|
||||
if strings.TrimSpace(effective.UserInfoURL) == "" {
|
||||
return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth userinfo url not configured")
|
||||
}
|
||||
if strings.TrimSpace(effective.RedirectURL) == "" {
|
||||
return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth redirect url not configured")
|
||||
}
|
||||
if strings.TrimSpace(effective.FrontendRedirectURL) == "" {
|
||||
return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth frontend redirect url not configured")
|
||||
}
|
||||
|
||||
if err := config.ValidateAbsoluteHTTPURL(effective.AuthorizeURL); err != nil {
|
||||
return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth authorize url invalid")
|
||||
}
|
||||
if err := config.ValidateAbsoluteHTTPURL(effective.TokenURL); err != nil {
|
||||
return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth token url invalid")
|
||||
}
|
||||
if err := config.ValidateAbsoluteHTTPURL(effective.UserInfoURL); err != nil {
|
||||
return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth userinfo url invalid")
|
||||
}
|
||||
if err := config.ValidateAbsoluteHTTPURL(effective.RedirectURL); err != nil {
|
||||
return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth redirect url invalid")
|
||||
}
|
||||
if err := config.ValidateFrontendRedirectURL(effective.FrontendRedirectURL); err != nil {
|
||||
return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth frontend redirect url invalid")
|
||||
}
|
||||
|
||||
method := strings.ToLower(strings.TrimSpace(effective.TokenAuthMethod))
|
||||
switch method {
|
||||
case "", "client_secret_post", "client_secret_basic":
|
||||
if strings.TrimSpace(effective.ClientSecret) == "" {
|
||||
return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth client secret not configured")
|
||||
}
|
||||
case "none":
|
||||
if !effective.UsePKCE {
|
||||
return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth pkce must be enabled when token_auth_method=none")
|
||||
}
|
||||
default:
|
||||
return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth token_auth_method invalid")
|
||||
}
|
||||
|
||||
return effective, nil
|
||||
}
|
||||
|
||||
// getStringOrDefault 获取字符串值或默认值
|
||||
func (s *SettingService) getStringOrDefault(settings map[string]string, key, defaultValue string) string {
|
||||
if value, ok := settings[key]; ok && value != "" {
|
||||
|
||||
@@ -18,6 +18,13 @@ type SystemSettings struct {
|
||||
TurnstileSecretKey string
|
||||
TurnstileSecretKeyConfigured bool
|
||||
|
||||
// LinuxDo Connect OAuth 登录(终端用户 SSO)
|
||||
LinuxDoConnectEnabled bool
|
||||
LinuxDoConnectClientID string
|
||||
LinuxDoConnectClientSecret string
|
||||
LinuxDoConnectClientSecretConfigured bool
|
||||
LinuxDoConnectRedirectURL string
|
||||
|
||||
SiteName string
|
||||
SiteLogo string
|
||||
SiteSubtitle string
|
||||
@@ -51,5 +58,6 @@ type PublicSettings struct {
|
||||
APIBaseURL string
|
||||
ContactInfo string
|
||||
DocURL string
|
||||
LinuxDoOAuthEnabled bool
|
||||
Version string
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user