merge upstream main
This commit is contained in:
@@ -410,6 +410,22 @@ func (a *Account) GetExtraString(key string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (a *Account) GetClaudeUserID() string {
|
||||
if v := strings.TrimSpace(a.GetExtraString("claude_user_id")); v != "" {
|
||||
return v
|
||||
}
|
||||
if v := strings.TrimSpace(a.GetExtraString("anthropic_user_id")); v != "" {
|
||||
return v
|
||||
}
|
||||
if v := strings.TrimSpace(a.GetCredential("claude_user_id")); v != "" {
|
||||
return v
|
||||
}
|
||||
if v := strings.TrimSpace(a.GetCredential("anthropic_user_id")); v != "" {
|
||||
return v
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (a *Account) IsCustomErrorCodesEnabled() bool {
|
||||
if a.Type != AccountTypeAPIKey || a.Credentials == nil {
|
||||
return false
|
||||
|
||||
@@ -123,7 +123,7 @@ func createTestPayload(modelID string) (map[string]any, error) {
|
||||
"system": []map[string]any{
|
||||
{
|
||||
"type": "text",
|
||||
"text": "You are Claude Code, Anthropic's official CLI for Claude.",
|
||||
"text": claudeCodeSystemPrompt,
|
||||
"cache_control": map[string]string{
|
||||
"type": "ephemeral",
|
||||
},
|
||||
|
||||
@@ -115,6 +115,8 @@ type CreateGroupInput struct {
|
||||
MCPXMLInject *bool
|
||||
// 支持的模型系列(仅 antigravity 平台使用)
|
||||
SupportedModelScopes []string
|
||||
// 从指定分组复制账号(创建分组后在同一事务内绑定)
|
||||
CopyAccountsFromGroupIDs []int64
|
||||
}
|
||||
|
||||
type UpdateGroupInput struct {
|
||||
@@ -142,6 +144,8 @@ type UpdateGroupInput struct {
|
||||
MCPXMLInject *bool
|
||||
// 支持的模型系列(仅 antigravity 平台使用)
|
||||
SupportedModelScopes *[]string
|
||||
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
|
||||
CopyAccountsFromGroupIDs []int64
|
||||
}
|
||||
|
||||
type CreateAccountInput struct {
|
||||
@@ -598,6 +602,38 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
|
||||
mcpXMLInject = *input.MCPXMLInject
|
||||
}
|
||||
|
||||
// 如果指定了复制账号的源分组,先获取账号 ID 列表
|
||||
var accountIDsToCopy []int64
|
||||
if len(input.CopyAccountsFromGroupIDs) > 0 {
|
||||
// 去重源分组 IDs
|
||||
seen := make(map[int64]struct{})
|
||||
uniqueSourceGroupIDs := make([]int64, 0, len(input.CopyAccountsFromGroupIDs))
|
||||
for _, srcGroupID := range input.CopyAccountsFromGroupIDs {
|
||||
if _, exists := seen[srcGroupID]; !exists {
|
||||
seen[srcGroupID] = struct{}{}
|
||||
uniqueSourceGroupIDs = append(uniqueSourceGroupIDs, srcGroupID)
|
||||
}
|
||||
}
|
||||
|
||||
// 校验源分组的平台是否与新分组一致
|
||||
for _, srcGroupID := range uniqueSourceGroupIDs {
|
||||
srcGroup, err := s.groupRepo.GetByIDLite(ctx, srcGroupID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("source group %d not found: %w", srcGroupID, err)
|
||||
}
|
||||
if srcGroup.Platform != platform {
|
||||
return nil, fmt.Errorf("source group %d platform mismatch: expected %s, got %s", srcGroupID, platform, srcGroup.Platform)
|
||||
}
|
||||
}
|
||||
|
||||
// 获取所有源分组的账号(去重)
|
||||
var err error
|
||||
accountIDsToCopy, err = s.groupRepo.GetAccountIDsByGroupIDs(ctx, uniqueSourceGroupIDs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get accounts from source groups: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
group := &Group{
|
||||
Name: input.Name,
|
||||
Description: input.Description,
|
||||
@@ -622,6 +658,15 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
|
||||
if err := s.groupRepo.Create(ctx, group); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 如果有需要复制的账号,绑定到新分组
|
||||
if len(accountIDsToCopy) > 0 {
|
||||
if err := s.groupRepo.BindAccountsToGroup(ctx, group.ID, accountIDsToCopy); err != nil {
|
||||
return nil, fmt.Errorf("failed to bind accounts to new group: %w", err)
|
||||
}
|
||||
group.AccountCount = int64(len(accountIDsToCopy))
|
||||
}
|
||||
|
||||
return group, nil
|
||||
}
|
||||
|
||||
@@ -810,6 +855,54 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
|
||||
if err := s.groupRepo.Update(ctx, group); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 如果指定了复制账号的源分组,同步绑定(替换当前分组的账号)
|
||||
if len(input.CopyAccountsFromGroupIDs) > 0 {
|
||||
// 去重源分组 IDs
|
||||
seen := make(map[int64]struct{})
|
||||
uniqueSourceGroupIDs := make([]int64, 0, len(input.CopyAccountsFromGroupIDs))
|
||||
for _, srcGroupID := range input.CopyAccountsFromGroupIDs {
|
||||
// 校验:源分组不能是自身
|
||||
if srcGroupID == id {
|
||||
return nil, fmt.Errorf("cannot copy accounts from self")
|
||||
}
|
||||
// 去重
|
||||
if _, exists := seen[srcGroupID]; !exists {
|
||||
seen[srcGroupID] = struct{}{}
|
||||
uniqueSourceGroupIDs = append(uniqueSourceGroupIDs, srcGroupID)
|
||||
}
|
||||
}
|
||||
|
||||
// 校验源分组的平台是否与当前分组一致
|
||||
for _, srcGroupID := range uniqueSourceGroupIDs {
|
||||
srcGroup, err := s.groupRepo.GetByIDLite(ctx, srcGroupID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("source group %d not found: %w", srcGroupID, err)
|
||||
}
|
||||
if srcGroup.Platform != group.Platform {
|
||||
return nil, fmt.Errorf("source group %d platform mismatch: expected %s, got %s", srcGroupID, group.Platform, srcGroup.Platform)
|
||||
}
|
||||
}
|
||||
|
||||
// 获取所有源分组的账号(去重)
|
||||
accountIDsToCopy, err := s.groupRepo.GetAccountIDsByGroupIDs(ctx, uniqueSourceGroupIDs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get accounts from source groups: %w", err)
|
||||
}
|
||||
|
||||
// 先清空当前分组的所有账号绑定
|
||||
if _, err := s.groupRepo.DeleteAccountGroupsByGroupID(ctx, id); err != nil {
|
||||
return nil, fmt.Errorf("failed to clear existing account bindings: %w", err)
|
||||
}
|
||||
|
||||
// 再绑定源分组的账号
|
||||
if len(accountIDsToCopy) > 0 {
|
||||
if err := s.groupRepo.BindAccountsToGroup(ctx, id, accountIDsToCopy); err != nil {
|
||||
return nil, fmt.Errorf("failed to bind accounts to group: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if s.authCacheInvalidator != nil {
|
||||
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id)
|
||||
}
|
||||
|
||||
@@ -164,6 +164,14 @@ func (s *groupRepoStub) DeleteAccountGroupsByGroupID(ctx context.Context, groupI
|
||||
panic("unexpected DeleteAccountGroupsByGroupID call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStub) BindAccountsToGroup(ctx context.Context, groupID int64, accountIDs []int64) error {
|
||||
panic("unexpected BindAccountsToGroup call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStub) GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error) {
|
||||
panic("unexpected GetAccountIDsByGroupIDs call")
|
||||
}
|
||||
|
||||
type proxyRepoStub struct {
|
||||
deleteErr error
|
||||
countErr error
|
||||
|
||||
@@ -108,6 +108,14 @@ func (s *groupRepoStubForAdmin) DeleteAccountGroupsByGroupID(_ context.Context,
|
||||
panic("unexpected DeleteAccountGroupsByGroupID call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForAdmin) BindAccountsToGroup(_ context.Context, _ int64, _ []int64) error {
|
||||
panic("unexpected BindAccountsToGroup call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForAdmin) GetAccountIDsByGroupIDs(_ context.Context, _ []int64) ([]int64, error) {
|
||||
panic("unexpected GetAccountIDsByGroupIDs call")
|
||||
}
|
||||
|
||||
// TestAdminService_CreateGroup_WithImagePricing 测试创建分组时 ImagePrice 字段正确传递
|
||||
func TestAdminService_CreateGroup_WithImagePricing(t *testing.T) {
|
||||
repo := &groupRepoStubForAdmin{}
|
||||
@@ -379,6 +387,14 @@ func (s *groupRepoStubForFallbackCycle) DeleteAccountGroupsByGroupID(_ context.C
|
||||
panic("unexpected DeleteAccountGroupsByGroupID call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForFallbackCycle) BindAccountsToGroup(_ context.Context, _ int64, _ []int64) error {
|
||||
panic("unexpected BindAccountsToGroup call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForFallbackCycle) GetAccountIDsByGroupIDs(_ context.Context, _ []int64) ([]int64, error) {
|
||||
panic("unexpected GetAccountIDsByGroupIDs call")
|
||||
}
|
||||
|
||||
type groupRepoStubForInvalidRequestFallback struct {
|
||||
groups map[int64]*Group
|
||||
created *Group
|
||||
@@ -748,4 +764,4 @@ func TestAdminService_UpdateGroup_InvalidRequestFallbackAllowsAntigravity(t *tes
|
||||
require.NotNil(t, group)
|
||||
require.NotNil(t, repo.updated)
|
||||
require.Equal(t, fallbackID, *repo.updated.FallbackGroupIDOnInvalidRequest)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -302,13 +302,11 @@ func logPrefix(sessionID, accountName string) string {
|
||||
}
|
||||
|
||||
// Antigravity 直接支持的模型(精确匹配透传)
|
||||
// 注意:gemini-2.5 系列已移除,统一映射到 gemini-3 系列
|
||||
var antigravitySupportedModels = map[string]bool{
|
||||
"claude-opus-4-5-thinking": true,
|
||||
"claude-sonnet-4-5": true,
|
||||
"claude-sonnet-4-5-thinking": true,
|
||||
"gemini-2.5-flash": true,
|
||||
"gemini-2.5-flash-lite": true,
|
||||
"gemini-2.5-flash-thinking": true,
|
||||
"gemini-3-flash": true,
|
||||
"gemini-3-pro-low": true,
|
||||
"gemini-3-pro-high": true,
|
||||
@@ -317,23 +315,32 @@ var antigravitySupportedModels = map[string]bool{
|
||||
|
||||
// Antigravity 前缀映射表(按前缀长度降序排列,确保最长匹配优先)
|
||||
// 用于处理模型版本号变化(如 -20251111, -thinking, -preview 等后缀)
|
||||
// gemini-2.5 系列统一映射到 gemini-3 系列(Antigravity 上游不再支持 2.5)
|
||||
var antigravityPrefixMapping = []struct {
|
||||
prefix string
|
||||
target string
|
||||
}{
|
||||
// 长前缀优先
|
||||
{"gemini-2.5-flash-image", "gemini-3-pro-image"}, // gemini-2.5-flash-image → 3-pro-image
|
||||
{"gemini-3-pro-image", "gemini-3-pro-image"}, // gemini-3-pro-image-preview 等
|
||||
{"gemini-3-flash", "gemini-3-flash"}, // gemini-3-flash-preview 等 → gemini-3-flash
|
||||
{"claude-3-5-sonnet", "claude-sonnet-4-5"}, // 旧版 claude-3-5-sonnet-xxx
|
||||
{"claude-sonnet-4-5", "claude-sonnet-4-5"}, // claude-sonnet-4-5-xxx
|
||||
{"claude-haiku-4-5", "claude-sonnet-4-5"}, // claude-haiku-4-5-xxx → sonnet
|
||||
// gemini-2.5 → gemini-3 映射(长前缀优先)
|
||||
{"gemini-2.5-flash-thinking", "gemini-3-flash"}, // gemini-2.5-flash-thinking → gemini-3-flash
|
||||
{"gemini-2.5-flash-image", "gemini-3-pro-image"}, // gemini-2.5-flash-image → gemini-3-pro-image
|
||||
{"gemini-2.5-flash-lite", "gemini-3-flash"}, // gemini-2.5-flash-lite → gemini-3-flash
|
||||
{"gemini-2.5-flash", "gemini-3-flash"}, // gemini-2.5-flash → gemini-3-flash
|
||||
{"gemini-2.5-pro-preview", "gemini-3-pro-high"}, // gemini-2.5-pro-preview → gemini-3-pro-high
|
||||
{"gemini-2.5-pro-exp", "gemini-3-pro-high"}, // gemini-2.5-pro-exp → gemini-3-pro-high
|
||||
{"gemini-2.5-pro", "gemini-3-pro-high"}, // gemini-2.5-pro → gemini-3-pro-high
|
||||
// gemini-3 前缀映射
|
||||
{"gemini-3-pro-image", "gemini-3-pro-image"}, // gemini-3-pro-image-preview 等
|
||||
{"gemini-3-flash", "gemini-3-flash"}, // gemini-3-flash-preview 等 → gemini-3-flash
|
||||
{"gemini-3-pro", "gemini-3-pro-high"}, // gemini-3-pro, gemini-3-pro-preview 等
|
||||
// Claude 映射
|
||||
{"claude-3-5-sonnet", "claude-sonnet-4-5"}, // 旧版 claude-3-5-sonnet-xxx
|
||||
{"claude-sonnet-4-5", "claude-sonnet-4-5"}, // claude-sonnet-4-5-xxx
|
||||
{"claude-haiku-4-5", "claude-sonnet-4-5"}, // claude-haiku-4-5-xxx → sonnet
|
||||
{"claude-opus-4-5", "claude-opus-4-5-thinking"},
|
||||
{"claude-3-haiku", "claude-sonnet-4-5"}, // 旧版 claude-3-haiku-xxx → sonnet
|
||||
{"claude-sonnet-4", "claude-sonnet-4-5"},
|
||||
{"claude-haiku-4", "claude-sonnet-4-5"}, // → sonnet
|
||||
{"claude-opus-4", "claude-opus-4-5-thinking"},
|
||||
{"gemini-3-pro", "gemini-3-pro-high"}, // gemini-3-pro, gemini-3-pro-preview 等
|
||||
}
|
||||
|
||||
// AntigravityGatewayService 处理 Antigravity 平台的 API 转发
|
||||
|
||||
@@ -103,6 +103,10 @@ func (s *httpUpstreamStub) Do(_ *http.Request, _ string, _ int64, _ int) (*http.
|
||||
return s.resp, s.err
|
||||
}
|
||||
|
||||
func (s *httpUpstreamStub) DoWithTLS(_ *http.Request, _ string, _ int64, _ int, _ bool) (*http.Response, error) {
|
||||
return s.resp, s.err
|
||||
}
|
||||
|
||||
func TestAntigravityGatewayService_Forward_PromptTooLong(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
writer := httptest.NewRecorder()
|
||||
|
||||
@@ -134,18 +134,18 @@ func TestAntigravityGatewayService_GetMappedModel(t *testing.T) {
|
||||
expected: "claude-sonnet-4-5",
|
||||
},
|
||||
|
||||
// 3. Gemini 透传
|
||||
// 3. Gemini 2.5 → 3 映射
|
||||
{
|
||||
name: "Gemini透传 - gemini-2.5-flash",
|
||||
name: "Gemini映射 - gemini-2.5-flash → gemini-3-flash",
|
||||
requestedModel: "gemini-2.5-flash",
|
||||
accountMapping: nil,
|
||||
expected: "gemini-2.5-flash",
|
||||
expected: "gemini-3-flash",
|
||||
},
|
||||
{
|
||||
name: "Gemini透传 - gemini-2.5-pro",
|
||||
name: "Gemini映射 - gemini-2.5-pro → gemini-3-pro-high",
|
||||
requestedModel: "gemini-2.5-pro",
|
||||
accountMapping: nil,
|
||||
expected: "gemini-2.5-pro",
|
||||
expected: "gemini-3-pro-high",
|
||||
},
|
||||
{
|
||||
name: "Gemini透传 - gemini-future-model",
|
||||
|
||||
@@ -19,17 +19,19 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidCredentials = infraerrors.Unauthorized("INVALID_CREDENTIALS", "invalid email or password")
|
||||
ErrUserNotActive = infraerrors.Forbidden("USER_NOT_ACTIVE", "user is not active")
|
||||
ErrEmailExists = infraerrors.Conflict("EMAIL_EXISTS", "email already exists")
|
||||
ErrEmailReserved = infraerrors.BadRequest("EMAIL_RESERVED", "email is reserved")
|
||||
ErrInvalidToken = infraerrors.Unauthorized("INVALID_TOKEN", "invalid token")
|
||||
ErrTokenExpired = infraerrors.Unauthorized("TOKEN_EXPIRED", "token has expired")
|
||||
ErrTokenTooLarge = infraerrors.BadRequest("TOKEN_TOO_LARGE", "token too large")
|
||||
ErrTokenRevoked = infraerrors.Unauthorized("TOKEN_REVOKED", "token has been revoked")
|
||||
ErrEmailVerifyRequired = infraerrors.BadRequest("EMAIL_VERIFY_REQUIRED", "email verification is required")
|
||||
ErrRegDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled")
|
||||
ErrServiceUnavailable = infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "service temporarily unavailable")
|
||||
ErrInvalidCredentials = infraerrors.Unauthorized("INVALID_CREDENTIALS", "invalid email or password")
|
||||
ErrUserNotActive = infraerrors.Forbidden("USER_NOT_ACTIVE", "user is not active")
|
||||
ErrEmailExists = infraerrors.Conflict("EMAIL_EXISTS", "email already exists")
|
||||
ErrEmailReserved = infraerrors.BadRequest("EMAIL_RESERVED", "email is reserved")
|
||||
ErrInvalidToken = infraerrors.Unauthorized("INVALID_TOKEN", "invalid token")
|
||||
ErrTokenExpired = infraerrors.Unauthorized("TOKEN_EXPIRED", "token has expired")
|
||||
ErrTokenTooLarge = infraerrors.BadRequest("TOKEN_TOO_LARGE", "token too large")
|
||||
ErrTokenRevoked = infraerrors.Unauthorized("TOKEN_REVOKED", "token has been revoked")
|
||||
ErrEmailVerifyRequired = infraerrors.BadRequest("EMAIL_VERIFY_REQUIRED", "email verification is required")
|
||||
ErrRegDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled")
|
||||
ErrServiceUnavailable = infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "service temporarily unavailable")
|
||||
ErrInvitationCodeRequired = infraerrors.BadRequest("INVITATION_CODE_REQUIRED", "invitation code is required")
|
||||
ErrInvitationCodeInvalid = infraerrors.BadRequest("INVITATION_CODE_INVALID", "invalid or used invitation code")
|
||||
)
|
||||
|
||||
// maxTokenLength 限制 token 大小,避免超长 header 触发解析时的异常内存分配。
|
||||
@@ -47,6 +49,7 @@ type JWTClaims struct {
|
||||
// AuthService 认证服务
|
||||
type AuthService struct {
|
||||
userRepo UserRepository
|
||||
redeemRepo RedeemCodeRepository
|
||||
cfg *config.Config
|
||||
settingService *SettingService
|
||||
emailService *EmailService
|
||||
@@ -58,6 +61,7 @@ type AuthService struct {
|
||||
// NewAuthService 创建认证服务实例
|
||||
func NewAuthService(
|
||||
userRepo UserRepository,
|
||||
redeemRepo RedeemCodeRepository,
|
||||
cfg *config.Config,
|
||||
settingService *SettingService,
|
||||
emailService *EmailService,
|
||||
@@ -67,6 +71,7 @@ func NewAuthService(
|
||||
) *AuthService {
|
||||
return &AuthService{
|
||||
userRepo: userRepo,
|
||||
redeemRepo: redeemRepo,
|
||||
cfg: cfg,
|
||||
settingService: settingService,
|
||||
emailService: emailService,
|
||||
@@ -78,11 +83,11 @@ func NewAuthService(
|
||||
|
||||
// Register 用户注册,返回token和用户
|
||||
func (s *AuthService) Register(ctx context.Context, email, password string) (string, *User, error) {
|
||||
return s.RegisterWithVerification(ctx, email, password, "", "")
|
||||
return s.RegisterWithVerification(ctx, email, password, "", "", "")
|
||||
}
|
||||
|
||||
// RegisterWithVerification 用户注册(支持邮件验证和优惠码),返回token和用户
|
||||
func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode, promoCode string) (string, *User, error) {
|
||||
// RegisterWithVerification 用户注册(支持邮件验证、优惠码和邀请码),返回token和用户
|
||||
func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode, promoCode, invitationCode string) (string, *User, error) {
|
||||
// 检查是否开放注册(默认关闭:settingService 未配置时不允许注册)
|
||||
if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
|
||||
return "", nil, ErrRegDisabled
|
||||
@@ -93,6 +98,26 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
|
||||
return "", nil, ErrEmailReserved
|
||||
}
|
||||
|
||||
// 检查是否需要邀请码
|
||||
var invitationRedeemCode *RedeemCode
|
||||
if s.settingService != nil && s.settingService.IsInvitationCodeEnabled(ctx) {
|
||||
if invitationCode == "" {
|
||||
return "", nil, ErrInvitationCodeRequired
|
||||
}
|
||||
// 验证邀请码
|
||||
redeemCode, err := s.redeemRepo.GetByCode(ctx, invitationCode)
|
||||
if err != nil {
|
||||
log.Printf("[Auth] Invalid invitation code: %s, error: %v", invitationCode, err)
|
||||
return "", nil, ErrInvitationCodeInvalid
|
||||
}
|
||||
// 检查类型和状态
|
||||
if redeemCode.Type != RedeemTypeInvitation || redeemCode.Status != StatusUnused {
|
||||
log.Printf("[Auth] Invitation code invalid: type=%s, status=%s", redeemCode.Type, redeemCode.Status)
|
||||
return "", nil, ErrInvitationCodeInvalid
|
||||
}
|
||||
invitationRedeemCode = redeemCode
|
||||
}
|
||||
|
||||
// 检查是否需要邮件验证
|
||||
if s.settingService != nil && s.settingService.IsEmailVerifyEnabled(ctx) {
|
||||
// 如果邮件验证已开启但邮件服务未配置,拒绝注册
|
||||
@@ -153,6 +178,13 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
|
||||
return "", nil, ErrServiceUnavailable
|
||||
}
|
||||
|
||||
// 标记邀请码为已使用(如果使用了邀请码)
|
||||
if invitationRedeemCode != nil {
|
||||
if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil {
|
||||
// 邀请码标记失败不影响注册,只记录日志
|
||||
log.Printf("[Auth] Failed to mark invitation code as used for user %d: %v", user.ID, err)
|
||||
}
|
||||
}
|
||||
// 应用优惠码(如果提供且功能已启用)
|
||||
if promoCode != "" && s.promoService != nil && s.settingService != nil && s.settingService.IsPromoCodeEnabled(ctx) {
|
||||
if err := s.promoService.ApplyPromoCode(ctx, user.ID, promoCode); err != nil {
|
||||
|
||||
@@ -115,6 +115,7 @@ func newAuthService(repo *userRepoStub, settings map[string]string, emailCache E
|
||||
|
||||
return NewAuthService(
|
||||
repo,
|
||||
nil, // redeemRepo
|
||||
cfg,
|
||||
settingService,
|
||||
emailService,
|
||||
@@ -152,7 +153,7 @@ func TestAuthService_Register_EmailVerifyEnabledButServiceNotConfigured(t *testi
|
||||
}, nil)
|
||||
|
||||
// 应返回服务不可用错误,而不是允许绕过验证
|
||||
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "any-code", "")
|
||||
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "any-code", "", "")
|
||||
require.ErrorIs(t, err, ErrServiceUnavailable)
|
||||
}
|
||||
|
||||
@@ -164,7 +165,7 @@ func TestAuthService_Register_EmailVerifyRequired(t *testing.T) {
|
||||
SettingKeyEmailVerifyEnabled: "true",
|
||||
}, cache)
|
||||
|
||||
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "", "")
|
||||
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "", "", "")
|
||||
require.ErrorIs(t, err, ErrEmailVerifyRequired)
|
||||
}
|
||||
|
||||
@@ -178,7 +179,7 @@ func TestAuthService_Register_EmailVerifyInvalid(t *testing.T) {
|
||||
SettingKeyEmailVerifyEnabled: "true",
|
||||
}, cache)
|
||||
|
||||
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "wrong", "")
|
||||
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "wrong", "", "")
|
||||
require.ErrorIs(t, err, ErrInvalidVerifyCode)
|
||||
require.ErrorContains(t, err, "verify code")
|
||||
}
|
||||
|
||||
@@ -241,6 +241,76 @@ func (s *BillingService) CalculateCostWithConfig(model string, tokens UsageToken
|
||||
return s.CalculateCost(model, tokens, multiplier)
|
||||
}
|
||||
|
||||
// CalculateCostWithLongContext 计算费用,支持长上下文双倍计费
|
||||
// threshold: 阈值(如 200000),超过此值的部分按 extraMultiplier 倍计费
|
||||
// extraMultiplier: 超出部分的倍率(如 2.0 表示双倍)
|
||||
//
|
||||
// 示例:缓存 210k + 输入 10k = 220k,阈值 200k,倍率 2.0
|
||||
// 拆分为:范围内 (200k, 0) + 范围外 (10k, 10k)
|
||||
// 范围内正常计费,范围外 × 2 计费
|
||||
func (s *BillingService) CalculateCostWithLongContext(model string, tokens UsageTokens, rateMultiplier float64, threshold int, extraMultiplier float64) (*CostBreakdown, error) {
|
||||
// 未启用长上下文计费,直接走正常计费
|
||||
if threshold <= 0 || extraMultiplier <= 1 {
|
||||
return s.CalculateCost(model, tokens, rateMultiplier)
|
||||
}
|
||||
|
||||
// 计算总输入 token(缓存读取 + 新输入)
|
||||
total := tokens.CacheReadTokens + tokens.InputTokens
|
||||
if total <= threshold {
|
||||
return s.CalculateCost(model, tokens, rateMultiplier)
|
||||
}
|
||||
|
||||
// 拆分成范围内和范围外
|
||||
var inRangeCacheTokens, inRangeInputTokens int
|
||||
var outRangeCacheTokens, outRangeInputTokens int
|
||||
|
||||
if tokens.CacheReadTokens >= threshold {
|
||||
// 缓存已超过阈值:范围内只有缓存,范围外是超出的缓存+全部输入
|
||||
inRangeCacheTokens = threshold
|
||||
inRangeInputTokens = 0
|
||||
outRangeCacheTokens = tokens.CacheReadTokens - threshold
|
||||
outRangeInputTokens = tokens.InputTokens
|
||||
} else {
|
||||
// 缓存未超过阈值:范围内是全部缓存+部分输入,范围外是剩余输入
|
||||
inRangeCacheTokens = tokens.CacheReadTokens
|
||||
inRangeInputTokens = threshold - tokens.CacheReadTokens
|
||||
outRangeCacheTokens = 0
|
||||
outRangeInputTokens = tokens.InputTokens - inRangeInputTokens
|
||||
}
|
||||
|
||||
// 范围内部分:正常计费
|
||||
inRangeTokens := UsageTokens{
|
||||
InputTokens: inRangeInputTokens,
|
||||
OutputTokens: tokens.OutputTokens, // 输出只算一次
|
||||
CacheCreationTokens: tokens.CacheCreationTokens,
|
||||
CacheReadTokens: inRangeCacheTokens,
|
||||
}
|
||||
inRangeCost, err := s.CalculateCost(model, inRangeTokens, rateMultiplier)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 范围外部分:× extraMultiplier 计费
|
||||
outRangeTokens := UsageTokens{
|
||||
InputTokens: outRangeInputTokens,
|
||||
CacheReadTokens: outRangeCacheTokens,
|
||||
}
|
||||
outRangeCost, err := s.CalculateCost(model, outRangeTokens, rateMultiplier*extraMultiplier)
|
||||
if err != nil {
|
||||
return inRangeCost, nil // 出错时返回范围内成本
|
||||
}
|
||||
|
||||
// 合并成本
|
||||
return &CostBreakdown{
|
||||
InputCost: inRangeCost.InputCost + outRangeCost.InputCost,
|
||||
OutputCost: inRangeCost.OutputCost,
|
||||
CacheCreationCost: inRangeCost.CacheCreationCost,
|
||||
CacheReadCost: inRangeCost.CacheReadCost + outRangeCost.CacheReadCost,
|
||||
TotalCost: inRangeCost.TotalCost + outRangeCost.TotalCost,
|
||||
ActualCost: inRangeCost.ActualCost + outRangeCost.ActualCost,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ListSupportedModels 列出所有支持的模型(现在总是返回true,因为有模糊匹配)
|
||||
func (s *BillingService) ListSupportedModels() []string {
|
||||
models := make([]string, 0)
|
||||
|
||||
@@ -39,6 +39,7 @@ const (
|
||||
RedeemTypeBalance = domain.RedeemTypeBalance
|
||||
RedeemTypeConcurrency = domain.RedeemTypeConcurrency
|
||||
RedeemTypeSubscription = domain.RedeemTypeSubscription
|
||||
RedeemTypeInvitation = domain.RedeemTypeInvitation
|
||||
)
|
||||
|
||||
// PromoCode status constants
|
||||
@@ -72,10 +73,11 @@ const LinuxDoConnectSyntheticEmailDomain = "@linuxdo-connect.invalid"
|
||||
// Setting keys
|
||||
const (
|
||||
// 注册设置
|
||||
SettingKeyRegistrationEnabled = "registration_enabled" // 是否开放注册
|
||||
SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证
|
||||
SettingKeyPromoCodeEnabled = "promo_code_enabled" // 是否启用优惠码功能
|
||||
SettingKeyPasswordResetEnabled = "password_reset_enabled" // 是否启用忘记密码功能(需要先开启邮件验证)
|
||||
SettingKeyRegistrationEnabled = "registration_enabled" // 是否开放注册
|
||||
SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证
|
||||
SettingKeyPromoCodeEnabled = "promo_code_enabled" // 是否启用优惠码功能
|
||||
SettingKeyPasswordResetEnabled = "password_reset_enabled" // 是否启用忘记密码功能(需要先开启邮件验证)
|
||||
SettingKeyInvitationCodeEnabled = "invitation_code_enabled" // 是否启用邀请码注册
|
||||
|
||||
// 邮件服务设置
|
||||
SettingKeySMTPHost = "smtp_host" // SMTP服务器地址
|
||||
|
||||
23
backend/internal/service/gateway_beta_test.go
Normal file
23
backend/internal/service/gateway_beta_test.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMergeAnthropicBeta(t *testing.T) {
|
||||
got := mergeAnthropicBeta(
|
||||
[]string{"oauth-2025-04-20", "interleaved-thinking-2025-05-14"},
|
||||
"foo, oauth-2025-04-20,bar, foo",
|
||||
)
|
||||
require.Equal(t, "oauth-2025-04-20,interleaved-thinking-2025-05-14,foo,bar", got)
|
||||
}
|
||||
|
||||
func TestMergeAnthropicBeta_EmptyIncoming(t *testing.T) {
|
||||
got := mergeAnthropicBeta(
|
||||
[]string{"oauth-2025-04-20", "interleaved-thinking-2025-05-14"},
|
||||
"",
|
||||
)
|
||||
require.Equal(t, "oauth-2025-04-20,interleaved-thinking-2025-05-14", got)
|
||||
}
|
||||
@@ -266,6 +266,14 @@ func (m *mockGroupRepoForGateway) DeleteAccountGroupsByGroupID(ctx context.Conte
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *mockGroupRepoForGateway) BindAccountsToGroup(ctx context.Context, groupID int64, accountIDs []int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockGroupRepoForGateway) GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func ptr[T any](v T) *T {
|
||||
return &v
|
||||
}
|
||||
|
||||
62
backend/internal/service/gateway_oauth_metadata_test.go
Normal file
62
backend/internal/service/gateway_oauth_metadata_test.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBuildOAuthMetadataUserID_FallbackWithoutAccountUUID(t *testing.T) {
|
||||
svc := &GatewayService{}
|
||||
|
||||
parsed := &ParsedRequest{
|
||||
Model: "claude-sonnet-4-5",
|
||||
Stream: true,
|
||||
MetadataUserID: "",
|
||||
System: nil,
|
||||
Messages: nil,
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
ID: 123,
|
||||
Type: AccountTypeOAuth,
|
||||
Extra: map[string]any{}, // intentionally missing account_uuid / claude_user_id
|
||||
}
|
||||
|
||||
fp := &Fingerprint{ClientID: "deadbeef"} // should be used as user id in legacy format
|
||||
|
||||
got := svc.buildOAuthMetadataUserID(parsed, account, fp)
|
||||
require.NotEmpty(t, got)
|
||||
|
||||
// Legacy format: user_{client}_account__session_{uuid}
|
||||
re := regexp.MustCompile(`^user_[a-zA-Z0-9]+_account__session_[a-f0-9-]{36}$`)
|
||||
require.True(t, re.MatchString(got), "unexpected user_id format: %s", got)
|
||||
}
|
||||
|
||||
func TestBuildOAuthMetadataUserID_UsesAccountUUIDWhenPresent(t *testing.T) {
|
||||
svc := &GatewayService{}
|
||||
|
||||
parsed := &ParsedRequest{
|
||||
Model: "claude-sonnet-4-5",
|
||||
Stream: true,
|
||||
MetadataUserID: "",
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
ID: 123,
|
||||
Type: AccountTypeOAuth,
|
||||
Extra: map[string]any{
|
||||
"account_uuid": "acc-uuid",
|
||||
"claude_user_id": "clientid123",
|
||||
"anthropic_user_id": "",
|
||||
},
|
||||
}
|
||||
|
||||
got := svc.buildOAuthMetadataUserID(parsed, account, nil)
|
||||
require.NotEmpty(t, got)
|
||||
|
||||
// New format: user_{client}_account_{account_uuid}_session_{uuid}
|
||||
re := regexp.MustCompile(`^user_clientid123_account_acc-uuid_session_[a-f0-9-]{36}$`)
|
||||
require.True(t, re.MatchString(got), "unexpected user_id format: %s", got)
|
||||
}
|
||||
@@ -2,6 +2,7 @@ package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -134,6 +135,8 @@ func TestSystemIncludesClaudeCodePrompt(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestInjectClaudeCodePrompt(t *testing.T) {
|
||||
claudePrefix := strings.TrimSpace(claudeCodeSystemPrompt)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
@@ -162,7 +165,7 @@ func TestInjectClaudeCodePrompt(t *testing.T) {
|
||||
system: "Custom prompt",
|
||||
wantSystemLen: 2,
|
||||
wantFirstText: claudeCodeSystemPrompt,
|
||||
wantSecondText: "Custom prompt",
|
||||
wantSecondText: claudePrefix + "\n\nCustom prompt",
|
||||
},
|
||||
{
|
||||
name: "string system equals Claude Code prompt",
|
||||
@@ -178,7 +181,7 @@ func TestInjectClaudeCodePrompt(t *testing.T) {
|
||||
// Claude Code + Custom = 2
|
||||
wantSystemLen: 2,
|
||||
wantFirstText: claudeCodeSystemPrompt,
|
||||
wantSecondText: "Custom",
|
||||
wantSecondText: claudePrefix + "\n\nCustom",
|
||||
},
|
||||
{
|
||||
name: "array system with existing Claude Code prompt (should dedupe)",
|
||||
@@ -190,7 +193,7 @@ func TestInjectClaudeCodePrompt(t *testing.T) {
|
||||
// Claude Code at start + Other = 2 (deduped)
|
||||
wantSystemLen: 2,
|
||||
wantFirstText: claudeCodeSystemPrompt,
|
||||
wantSecondText: "Other",
|
||||
wantSecondText: claudePrefix + "\n\nOther",
|
||||
},
|
||||
{
|
||||
name: "empty array",
|
||||
|
||||
21
backend/internal/service/gateway_sanitize_test.go
Normal file
21
backend/internal/service/gateway_sanitize_test.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSanitizeOpenCodeText_RewritesCanonicalSentence(t *testing.T) {
|
||||
in := "You are OpenCode, the best coding agent on the planet."
|
||||
got := sanitizeSystemText(in)
|
||||
require.Equal(t, strings.TrimSpace(claudeCodeSystemPrompt), got)
|
||||
}
|
||||
|
||||
func TestSanitizeToolDescription_DoesNotRewriteKeywords(t *testing.T) {
|
||||
in := "OpenCode and opencode are mentioned."
|
||||
got := sanitizeToolDescription(in)
|
||||
// We no longer rewrite tool descriptions; only redact obvious path leaks.
|
||||
require.Equal(t, in, got)
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -36,6 +36,11 @@ const (
|
||||
geminiRetryMaxDelay = 16 * time.Second
|
||||
)
|
||||
|
||||
// Gemini tool calling now requires `thoughtSignature` in parts that include `functionCall`.
|
||||
// Many clients don't send it; we inject a known dummy signature to satisfy the validator.
|
||||
// Ref: https://ai.google.dev/gemini-api/docs/thought-signatures
|
||||
const geminiDummyThoughtSignature = "skip_thought_signature_validator"
|
||||
|
||||
type GeminiMessagesCompatService struct {
|
||||
accountRepo AccountRepository
|
||||
groupRepo GroupRepository
|
||||
@@ -528,6 +533,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
||||
if err != nil {
|
||||
return nil, s.writeClaudeError(c, http.StatusBadRequest, "invalid_request_error", err.Error())
|
||||
}
|
||||
geminiReq = ensureGeminiFunctionCallThoughtSignatures(geminiReq)
|
||||
originalClaudeBody := body
|
||||
|
||||
proxyURL := ""
|
||||
@@ -983,6 +989,10 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
||||
return nil, s.writeGoogleError(c, http.StatusNotFound, "Unsupported action: "+action)
|
||||
}
|
||||
|
||||
// Some Gemini upstreams validate tool call parts strictly; ensure any `functionCall` part includes a
|
||||
// `thoughtSignature` to avoid frequent INVALID_ARGUMENT 400s.
|
||||
body = ensureGeminiFunctionCallThoughtSignatures(body)
|
||||
|
||||
mappedModel := originalModel
|
||||
if account.Type == AccountTypeAPIKey {
|
||||
mappedModel = account.GetMappedModel(originalModel)
|
||||
@@ -2662,6 +2672,58 @@ func nextGeminiDailyResetUnix() *int64 {
|
||||
return &ts
|
||||
}
|
||||
|
||||
func ensureGeminiFunctionCallThoughtSignatures(body []byte) []byte {
|
||||
// Fast path: only run when functionCall is present.
|
||||
if !bytes.Contains(body, []byte(`"functionCall"`)) {
|
||||
return body
|
||||
}
|
||||
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal(body, &payload); err != nil {
|
||||
return body
|
||||
}
|
||||
|
||||
contentsAny, ok := payload["contents"].([]any)
|
||||
if !ok || len(contentsAny) == 0 {
|
||||
return body
|
||||
}
|
||||
|
||||
modified := false
|
||||
for _, c := range contentsAny {
|
||||
cm, ok := c.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
partsAny, ok := cm["parts"].([]any)
|
||||
if !ok || len(partsAny) == 0 {
|
||||
continue
|
||||
}
|
||||
for _, p := range partsAny {
|
||||
pm, ok := p.(map[string]any)
|
||||
if !ok || pm == nil {
|
||||
continue
|
||||
}
|
||||
if fc, ok := pm["functionCall"].(map[string]any); !ok || fc == nil {
|
||||
continue
|
||||
}
|
||||
ts, _ := pm["thoughtSignature"].(string)
|
||||
if strings.TrimSpace(ts) == "" {
|
||||
pm["thoughtSignature"] = geminiDummyThoughtSignature
|
||||
modified = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !modified {
|
||||
return body
|
||||
}
|
||||
b, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return body
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func extractGeminiFinishReason(geminiResp map[string]any) string {
|
||||
if candidates, ok := geminiResp["candidates"].([]any); ok && len(candidates) > 0 {
|
||||
if cand, ok := candidates[0].(map[string]any); ok {
|
||||
@@ -2861,7 +2923,13 @@ func convertClaudeMessagesToGeminiContents(messages any, toolUseIDToName map[str
|
||||
if strings.TrimSpace(id) != "" && strings.TrimSpace(name) != "" {
|
||||
toolUseIDToName[id] = name
|
||||
}
|
||||
signature, _ := bm["signature"].(string)
|
||||
signature = strings.TrimSpace(signature)
|
||||
if signature == "" {
|
||||
signature = geminiDummyThoughtSignature
|
||||
}
|
||||
parts = append(parts, map[string]any{
|
||||
"thoughtSignature": signature,
|
||||
"functionCall": map[string]any{
|
||||
"name": name,
|
||||
"args": bm["input"],
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@@ -126,3 +128,78 @@ func TestConvertClaudeToolsToGeminiTools_CustomType(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeMessagesToGeminiGenerateContent_AddsThoughtSignatureForToolUse(t *testing.T) {
|
||||
claudeReq := map[string]any{
|
||||
"model": "claude-haiku-4-5-20251001",
|
||||
"max_tokens": 10,
|
||||
"messages": []any{
|
||||
map[string]any{
|
||||
"role": "user",
|
||||
"content": []any{
|
||||
map[string]any{"type": "text", "text": "hi"},
|
||||
},
|
||||
},
|
||||
map[string]any{
|
||||
"role": "assistant",
|
||||
"content": []any{
|
||||
map[string]any{"type": "text", "text": "ok"},
|
||||
map[string]any{
|
||||
"type": "tool_use",
|
||||
"id": "toolu_123",
|
||||
"name": "default_api:write_file",
|
||||
"input": map[string]any{"path": "a.txt", "content": "x"},
|
||||
// no signature on purpose
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"tools": []any{
|
||||
map[string]any{
|
||||
"name": "default_api:write_file",
|
||||
"description": "write file",
|
||||
"input_schema": map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{"path": map[string]any{"type": "string"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
b, _ := json.Marshal(claudeReq)
|
||||
|
||||
out, err := convertClaudeMessagesToGeminiGenerateContent(b)
|
||||
if err != nil {
|
||||
t.Fatalf("convert failed: %v", err)
|
||||
}
|
||||
s := string(out)
|
||||
if !strings.Contains(s, "\"functionCall\"") {
|
||||
t.Fatalf("expected functionCall in output, got: %s", s)
|
||||
}
|
||||
if !strings.Contains(s, "\"thoughtSignature\":\""+geminiDummyThoughtSignature+"\"") {
|
||||
t.Fatalf("expected injected thoughtSignature %q, got: %s", geminiDummyThoughtSignature, s)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsureGeminiFunctionCallThoughtSignatures_InsertsWhenMissing(t *testing.T) {
|
||||
geminiReq := map[string]any{
|
||||
"contents": []any{
|
||||
map[string]any{
|
||||
"role": "user",
|
||||
"parts": []any{
|
||||
map[string]any{
|
||||
"functionCall": map[string]any{
|
||||
"name": "default_api:write_file",
|
||||
"args": map[string]any{"path": "a.txt"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
b, _ := json.Marshal(geminiReq)
|
||||
out := ensureGeminiFunctionCallThoughtSignatures(b)
|
||||
s := string(out)
|
||||
if !strings.Contains(s, "\"thoughtSignature\":\""+geminiDummyThoughtSignature+"\"") {
|
||||
t.Fatalf("expected injected thoughtSignature %q, got: %s", geminiDummyThoughtSignature, s)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -218,6 +218,14 @@ func (m *mockGroupRepoForGemini) DeleteAccountGroupsByGroupID(ctx context.Contex
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *mockGroupRepoForGemini) BindAccountsToGroup(ctx context.Context, groupID int64, accountIDs []int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockGroupRepoForGemini) GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var _ GroupRepository = (*mockGroupRepoForGemini)(nil)
|
||||
|
||||
// mockGatewayCacheForGemini Gemini 测试用的 cache mock
|
||||
|
||||
@@ -29,6 +29,10 @@ type GroupRepository interface {
|
||||
ExistsByName(ctx context.Context, name string) (bool, error)
|
||||
GetAccountCount(ctx context.Context, groupID int64) (int64, error)
|
||||
DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error)
|
||||
// GetAccountIDsByGroupIDs 获取多个分组的所有账号 ID(去重)
|
||||
GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error)
|
||||
// BindAccountsToGroup 将多个账号绑定到指定分组
|
||||
BindAccountsToGroup(ctx context.Context, groupID int64, accountIDs []int64) error
|
||||
}
|
||||
|
||||
// CreateGroupRequest 创建分组请求
|
||||
|
||||
@@ -26,13 +26,13 @@ var (
|
||||
|
||||
// 默认指纹值(当客户端未提供时使用)
|
||||
var defaultFingerprint = Fingerprint{
|
||||
UserAgent: "claude-cli/2.0.62 (external, cli)",
|
||||
UserAgent: "claude-cli/2.1.22 (external, cli)",
|
||||
StainlessLang: "js",
|
||||
StainlessPackageVersion: "0.52.0",
|
||||
StainlessPackageVersion: "0.70.0",
|
||||
StainlessOS: "Linux",
|
||||
StainlessArch: "x64",
|
||||
StainlessArch: "arm64",
|
||||
StainlessRuntime: "node",
|
||||
StainlessRuntimeVersion: "v22.14.0",
|
||||
StainlessRuntimeVersion: "v24.13.0",
|
||||
}
|
||||
|
||||
// Fingerprint represents account fingerprint data
|
||||
@@ -327,7 +327,7 @@ func generateUUIDFromSeed(seed string) string {
|
||||
}
|
||||
|
||||
// parseUserAgentVersion 解析user-agent版本号
|
||||
// 例如:claude-cli/2.0.62 -> (2, 0, 62)
|
||||
// 例如:claude-cli/2.1.2 -> (2, 1, 2)
|
||||
func parseUserAgentVersion(ua string) (major, minor, patch int, ok bool) {
|
||||
// 匹配 xxx/x.y.z 格式
|
||||
matches := userAgentVersionRegex.FindStringSubmatch(ua)
|
||||
|
||||
@@ -156,12 +156,15 @@ type OpenAIUsage struct {
|
||||
|
||||
// OpenAIForwardResult represents the result of forwarding
|
||||
type OpenAIForwardResult struct {
|
||||
RequestID string
|
||||
Usage OpenAIUsage
|
||||
Model string
|
||||
Stream bool
|
||||
Duration time.Duration
|
||||
FirstTokenMs *int
|
||||
RequestID string
|
||||
Usage OpenAIUsage
|
||||
Model string
|
||||
// ReasoningEffort is extracted from request body (reasoning.effort) or derived from model suffix.
|
||||
// Stored for usage records display; nil means not provided / not applicable.
|
||||
ReasoningEffort *string
|
||||
Stream bool
|
||||
Duration time.Duration
|
||||
FirstTokenMs *int
|
||||
}
|
||||
|
||||
// OpenAIGatewayService handles OpenAI API gateway operations
|
||||
@@ -958,13 +961,16 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
}
|
||||
}
|
||||
|
||||
reasoningEffort := extractOpenAIReasoningEffort(reqBody, originalModel)
|
||||
|
||||
return &OpenAIForwardResult{
|
||||
RequestID: resp.Header.Get("x-request-id"),
|
||||
Usage: *usage,
|
||||
Model: originalModel,
|
||||
Stream: reqStream,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: firstTokenMs,
|
||||
RequestID: resp.Header.Get("x-request-id"),
|
||||
Usage: *usage,
|
||||
Model: originalModel,
|
||||
ReasoningEffort: reasoningEffort,
|
||||
Stream: reqStream,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: firstTokenMs,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -1260,15 +1266,29 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
||||
// 记录上次收到上游数据的时间,用于控制 keepalive 发送频率
|
||||
lastDataAt := time.Now()
|
||||
|
||||
// 仅发送一次错误事件,避免多次写入导致协议混乱(写失败时尽力通知客户端)
|
||||
// 仅发送一次错误事件,避免多次写入导致协议混乱。
|
||||
// 注意:OpenAI `/v1/responses` streaming 事件必须符合 OpenAI Responses schema;
|
||||
// 否则下游 SDK(例如 OpenCode)会因为类型校验失败而报错。
|
||||
errorEventSent := false
|
||||
clientDisconnected := false // 客户端断开后继续 drain 上游以收集 usage
|
||||
sendErrorEvent := func(reason string) {
|
||||
if errorEventSent {
|
||||
if errorEventSent || clientDisconnected {
|
||||
return
|
||||
}
|
||||
errorEventSent = true
|
||||
_, _ = fmt.Fprintf(w, "event: error\ndata: {\"error\":\"%s\"}\n\n", reason)
|
||||
flusher.Flush()
|
||||
payload := map[string]any{
|
||||
"type": "error",
|
||||
"sequence_number": 0,
|
||||
"error": map[string]any{
|
||||
"type": "upstream_error",
|
||||
"message": reason,
|
||||
"code": reason,
|
||||
},
|
||||
}
|
||||
if b, err := json.Marshal(payload); err == nil {
|
||||
_, _ = fmt.Fprintf(w, "data: %s\n\n", b)
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
needModelReplace := originalModel != mappedModel
|
||||
@@ -1280,6 +1300,17 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
||||
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
|
||||
}
|
||||
if ev.err != nil {
|
||||
// 客户端断开/取消请求时,上游读取往往会返回 context canceled。
|
||||
// /v1/responses 的 SSE 事件必须符合 OpenAI 协议;这里不注入自定义 error event,避免下游 SDK 解析失败。
|
||||
if errors.Is(ev.err, context.Canceled) || errors.Is(ev.err, context.DeadlineExceeded) {
|
||||
log.Printf("Context canceled during streaming, returning collected usage")
|
||||
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
|
||||
}
|
||||
// 客户端已断开时,上游出错仅影响体验,不影响计费;返回已收集 usage
|
||||
if clientDisconnected {
|
||||
log.Printf("Upstream read error after client disconnect: %v, returning collected usage", ev.err)
|
||||
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
|
||||
}
|
||||
if errors.Is(ev.err, bufio.ErrTooLong) {
|
||||
log.Printf("SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err)
|
||||
sendErrorEvent("response_too_large")
|
||||
@@ -1303,15 +1334,19 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
||||
|
||||
// Correct Codex tool calls if needed (apply_patch -> edit, etc.)
|
||||
if correctedData, corrected := s.toolCorrector.CorrectToolCallsInSSEData(data); corrected {
|
||||
data = correctedData
|
||||
line = "data: " + correctedData
|
||||
}
|
||||
|
||||
// Forward line
|
||||
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
|
||||
sendErrorEvent("write_failed")
|
||||
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
|
||||
// 写入客户端(客户端断开后继续 drain 上游)
|
||||
if !clientDisconnected {
|
||||
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
|
||||
clientDisconnected = true
|
||||
log.Printf("Client disconnected during streaming, continuing to drain upstream for billing")
|
||||
} else {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
flusher.Flush()
|
||||
|
||||
// Record first token time
|
||||
if firstTokenMs == nil && data != "" && data != "[DONE]" {
|
||||
@@ -1321,11 +1356,14 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
||||
s.parseSSEUsage(data, usage)
|
||||
} else {
|
||||
// Forward non-data lines as-is
|
||||
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
|
||||
sendErrorEvent("write_failed")
|
||||
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
|
||||
if !clientDisconnected {
|
||||
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
|
||||
clientDisconnected = true
|
||||
log.Printf("Client disconnected during streaming, continuing to drain upstream for billing")
|
||||
} else {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
case <-intervalCh:
|
||||
@@ -1333,6 +1371,10 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
||||
if time.Since(lastRead) < streamInterval {
|
||||
continue
|
||||
}
|
||||
if clientDisconnected {
|
||||
log.Printf("Upstream timeout after client disconnect, returning collected usage")
|
||||
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
|
||||
}
|
||||
log.Printf("Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval)
|
||||
// 处理流超时,可能标记账户为临时不可调度或错误状态
|
||||
if s.rateLimitService != nil {
|
||||
@@ -1342,11 +1384,16 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
||||
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
|
||||
|
||||
case <-keepaliveCh:
|
||||
if clientDisconnected {
|
||||
continue
|
||||
}
|
||||
if time.Since(lastDataAt) < keepaliveInterval {
|
||||
continue
|
||||
}
|
||||
if _, err := fmt.Fprint(w, ":\n\n"); err != nil {
|
||||
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
|
||||
clientDisconnected = true
|
||||
log.Printf("Client disconnected during streaming, continuing to drain upstream for billing")
|
||||
continue
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
@@ -1687,6 +1734,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
||||
AccountID: account.ID,
|
||||
RequestID: result.RequestID,
|
||||
Model: result.Model,
|
||||
ReasoningEffort: result.ReasoningEffort,
|
||||
InputTokens: actualInputTokens,
|
||||
OutputTokens: result.Usage.OutputTokens,
|
||||
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
||||
@@ -1881,3 +1929,86 @@ func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, acc
|
||||
_ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates)
|
||||
}()
|
||||
}
|
||||
|
||||
func getOpenAIReasoningEffortFromReqBody(reqBody map[string]any) (value string, present bool) {
|
||||
if reqBody == nil {
|
||||
return "", false
|
||||
}
|
||||
|
||||
// Primary: reasoning.effort
|
||||
if reasoning, ok := reqBody["reasoning"].(map[string]any); ok {
|
||||
if effort, ok := reasoning["effort"].(string); ok {
|
||||
return normalizeOpenAIReasoningEffort(effort), true
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: some clients may use a flat field.
|
||||
if effort, ok := reqBody["reasoning_effort"].(string); ok {
|
||||
return normalizeOpenAIReasoningEffort(effort), true
|
||||
}
|
||||
|
||||
return "", false
|
||||
}
|
||||
|
||||
func deriveOpenAIReasoningEffortFromModel(model string) string {
|
||||
if strings.TrimSpace(model) == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
modelID := strings.TrimSpace(model)
|
||||
if strings.Contains(modelID, "/") {
|
||||
parts := strings.Split(modelID, "/")
|
||||
modelID = parts[len(parts)-1]
|
||||
}
|
||||
|
||||
parts := strings.FieldsFunc(strings.ToLower(modelID), func(r rune) bool {
|
||||
switch r {
|
||||
case '-', '_', ' ':
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
})
|
||||
if len(parts) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
return normalizeOpenAIReasoningEffort(parts[len(parts)-1])
|
||||
}
|
||||
|
||||
func extractOpenAIReasoningEffort(reqBody map[string]any, requestedModel string) *string {
|
||||
if value, present := getOpenAIReasoningEffortFromReqBody(reqBody); present {
|
||||
if value == "" {
|
||||
return nil
|
||||
}
|
||||
return &value
|
||||
}
|
||||
|
||||
value := deriveOpenAIReasoningEffortFromModel(requestedModel)
|
||||
if value == "" {
|
||||
return nil
|
||||
}
|
||||
return &value
|
||||
}
|
||||
|
||||
func normalizeOpenAIReasoningEffort(raw string) string {
|
||||
value := strings.ToLower(strings.TrimSpace(raw))
|
||||
if value == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Normalize separators for "x-high"/"x_high" variants.
|
||||
value = strings.NewReplacer("-", "", "_", "", " ", "").Replace(value)
|
||||
|
||||
switch value {
|
||||
case "none", "minimal":
|
||||
return ""
|
||||
case "low", "medium", "high":
|
||||
return value
|
||||
case "xhigh", "extrahigh":
|
||||
return "xhigh"
|
||||
default:
|
||||
// Only store known effort levels for now to keep UI consistent.
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
@@ -59,6 +59,25 @@ type stubConcurrencyCache struct {
|
||||
skipDefaultLoad bool
|
||||
}
|
||||
|
||||
type cancelReadCloser struct{}
|
||||
|
||||
func (c cancelReadCloser) Read(p []byte) (int, error) { return 0, context.Canceled }
|
||||
func (c cancelReadCloser) Close() error { return nil }
|
||||
|
||||
type failingGinWriter struct {
|
||||
gin.ResponseWriter
|
||||
failAfter int
|
||||
writes int
|
||||
}
|
||||
|
||||
func (w *failingGinWriter) Write(p []byte) (int, error) {
|
||||
if w.writes >= w.failAfter {
|
||||
return 0, errors.New("write failed")
|
||||
}
|
||||
w.writes++
|
||||
return w.ResponseWriter.Write(p)
|
||||
}
|
||||
|
||||
func (c stubConcurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||
if c.acquireResults != nil {
|
||||
if result, ok := c.acquireResults[accountID]; ok {
|
||||
@@ -814,8 +833,85 @@ func TestOpenAIStreamingTimeout(t *testing.T) {
|
||||
if err == nil || !strings.Contains(err.Error(), "stream data interval timeout") {
|
||||
t.Fatalf("expected stream timeout error, got %v", err)
|
||||
}
|
||||
if !strings.Contains(rec.Body.String(), "stream_timeout") {
|
||||
t.Fatalf("expected stream_timeout SSE error, got %q", rec.Body.String())
|
||||
if !strings.Contains(rec.Body.String(), "\"type\":\"error\"") || !strings.Contains(rec.Body.String(), "stream_timeout") {
|
||||
t.Fatalf("expected OpenAI-compatible error SSE event, got %q", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIStreamingContextCanceledDoesNotInjectErrorEvent(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
cfg := &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
StreamDataIntervalTimeout: 0,
|
||||
StreamKeepaliveInterval: 0,
|
||||
MaxLineSize: defaultMaxLineSize,
|
||||
},
|
||||
}
|
||||
svc := &OpenAIGatewayService{cfg: cfg}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/", nil).WithContext(ctx)
|
||||
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: cancelReadCloser{},
|
||||
Header: http.Header{},
|
||||
}
|
||||
|
||||
_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model")
|
||||
if err != nil {
|
||||
t.Fatalf("expected nil error, got %v", err)
|
||||
}
|
||||
if strings.Contains(rec.Body.String(), "event: error") || strings.Contains(rec.Body.String(), "stream_read_error") {
|
||||
t.Fatalf("expected no injected SSE error event, got %q", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIStreamingClientDisconnectDrainsUpstreamUsage(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
cfg := &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
StreamDataIntervalTimeout: 0,
|
||||
StreamKeepaliveInterval: 0,
|
||||
MaxLineSize: defaultMaxLineSize,
|
||||
},
|
||||
}
|
||||
svc := &OpenAIGatewayService{cfg: cfg}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
c.Writer = &failingGinWriter{ResponseWriter: c.Writer, failAfter: 0}
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: pr,
|
||||
Header: http.Header{},
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer func() { _ = pw.Close() }()
|
||||
_, _ = pw.Write([]byte("data: {\"type\":\"response.in_progress\",\"response\":{}}\n\n"))
|
||||
_, _ = pw.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"usage\":{\"input_tokens\":3,\"output_tokens\":5,\"input_tokens_details\":{\"cached_tokens\":1}}}}\n\n"))
|
||||
}()
|
||||
|
||||
result, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model")
|
||||
_ = pr.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("expected nil error, got %v", err)
|
||||
}
|
||||
if result == nil || result.usage == nil {
|
||||
t.Fatalf("expected usage result")
|
||||
}
|
||||
if result.usage.InputTokens != 3 || result.usage.OutputTokens != 5 || result.usage.CacheReadInputTokens != 1 {
|
||||
t.Fatalf("unexpected usage: %+v", *result.usage)
|
||||
}
|
||||
if strings.Contains(rec.Body.String(), "event: error") || strings.Contains(rec.Body.String(), "write_failed") {
|
||||
t.Fatalf("expected no injected SSE error event, got %q", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -854,8 +950,8 @@ func TestOpenAIStreamingTooLong(t *testing.T) {
|
||||
if !errors.Is(err, bufio.ErrTooLong) {
|
||||
t.Fatalf("expected ErrTooLong, got %v", err)
|
||||
}
|
||||
if !strings.Contains(rec.Body.String(), "response_too_large") {
|
||||
t.Fatalf("expected response_too_large SSE error, got %q", rec.Body.String())
|
||||
if !strings.Contains(rec.Body.String(), "\"type\":\"error\"") || !strings.Contains(rec.Body.String(), "response_too_large") {
|
||||
t.Fatalf("expected OpenAI-compatible error SSE event, got %q", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -126,7 +126,8 @@ func (s *RedeemService) GenerateCodes(ctx context.Context, req GenerateCodesRequ
|
||||
return nil, errors.New("count must be greater than 0")
|
||||
}
|
||||
|
||||
if req.Value <= 0 {
|
||||
// 邀请码类型不需要数值,其他类型需要
|
||||
if req.Type != RedeemTypeInvitation && req.Value <= 0 {
|
||||
return nil, errors.New("value must be greater than 0")
|
||||
}
|
||||
|
||||
@@ -139,6 +140,12 @@ func (s *RedeemService) GenerateCodes(ctx context.Context, req GenerateCodesRequ
|
||||
codeType = RedeemTypeBalance
|
||||
}
|
||||
|
||||
// 邀请码类型的 value 设为 0
|
||||
value := req.Value
|
||||
if codeType == RedeemTypeInvitation {
|
||||
value = 0
|
||||
}
|
||||
|
||||
codes := make([]RedeemCode, 0, req.Count)
|
||||
for i := 0; i < req.Count; i++ {
|
||||
code, err := s.GenerateRandomCode()
|
||||
@@ -149,7 +156,7 @@ func (s *RedeemService) GenerateCodes(ctx context.Context, req GenerateCodesRequ
|
||||
codes = append(codes, RedeemCode{
|
||||
Code: code,
|
||||
Type: codeType,
|
||||
Value: req.Value,
|
||||
Value: value,
|
||||
Status: StatusUnused,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -62,6 +62,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
|
||||
SettingKeyEmailVerifyEnabled,
|
||||
SettingKeyPromoCodeEnabled,
|
||||
SettingKeyPasswordResetEnabled,
|
||||
SettingKeyInvitationCodeEnabled,
|
||||
SettingKeyTotpEnabled,
|
||||
SettingKeyTurnstileEnabled,
|
||||
SettingKeyTurnstileSiteKey,
|
||||
@@ -99,6 +100,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
|
||||
EmailVerifyEnabled: emailVerifyEnabled,
|
||||
PromoCodeEnabled: settings[SettingKeyPromoCodeEnabled] != "false", // 默认启用
|
||||
PasswordResetEnabled: passwordResetEnabled,
|
||||
InvitationCodeEnabled: settings[SettingKeyInvitationCodeEnabled] == "true",
|
||||
TotpEnabled: settings[SettingKeyTotpEnabled] == "true",
|
||||
TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true",
|
||||
TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey],
|
||||
@@ -141,6 +143,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
||||
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
|
||||
TotpEnabled bool `json:"totp_enabled"`
|
||||
TurnstileEnabled bool `json:"turnstile_enabled"`
|
||||
TurnstileSiteKey string `json:"turnstile_site_key,omitempty"`
|
||||
@@ -161,6 +164,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
|
||||
EmailVerifyEnabled: settings.EmailVerifyEnabled,
|
||||
PromoCodeEnabled: settings.PromoCodeEnabled,
|
||||
PasswordResetEnabled: settings.PasswordResetEnabled,
|
||||
InvitationCodeEnabled: settings.InvitationCodeEnabled,
|
||||
TotpEnabled: settings.TotpEnabled,
|
||||
TurnstileEnabled: settings.TurnstileEnabled,
|
||||
TurnstileSiteKey: settings.TurnstileSiteKey,
|
||||
@@ -188,6 +192,7 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
|
||||
updates[SettingKeyEmailVerifyEnabled] = strconv.FormatBool(settings.EmailVerifyEnabled)
|
||||
updates[SettingKeyPromoCodeEnabled] = strconv.FormatBool(settings.PromoCodeEnabled)
|
||||
updates[SettingKeyPasswordResetEnabled] = strconv.FormatBool(settings.PasswordResetEnabled)
|
||||
updates[SettingKeyInvitationCodeEnabled] = strconv.FormatBool(settings.InvitationCodeEnabled)
|
||||
updates[SettingKeyTotpEnabled] = strconv.FormatBool(settings.TotpEnabled)
|
||||
|
||||
// 邮件服务设置(只有非空才更新密码)
|
||||
@@ -286,6 +291,14 @@ func (s *SettingService) IsPromoCodeEnabled(ctx context.Context) bool {
|
||||
return value != "false"
|
||||
}
|
||||
|
||||
// IsInvitationCodeEnabled 检查是否启用邀请码注册功能
|
||||
func (s *SettingService) IsInvitationCodeEnabled(ctx context.Context) bool {
|
||||
value, err := s.settingRepo.GetValue(ctx, SettingKeyInvitationCodeEnabled)
|
||||
if err != nil {
|
||||
return false // 默认关闭
|
||||
}
|
||||
return value == "true"
|
||||
}
|
||||
// IsPasswordResetEnabled 检查是否启用密码重置功能
|
||||
// 要求:必须同时开启邮件验证
|
||||
func (s *SettingService) IsPasswordResetEnabled(ctx context.Context) bool {
|
||||
@@ -401,6 +414,7 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
|
||||
EmailVerifyEnabled: emailVerifyEnabled,
|
||||
PromoCodeEnabled: settings[SettingKeyPromoCodeEnabled] != "false", // 默认启用
|
||||
PasswordResetEnabled: emailVerifyEnabled && settings[SettingKeyPasswordResetEnabled] == "true",
|
||||
InvitationCodeEnabled: settings[SettingKeyInvitationCodeEnabled] == "true",
|
||||
TotpEnabled: settings[SettingKeyTotpEnabled] == "true",
|
||||
SMTPHost: settings[SettingKeySMTPHost],
|
||||
SMTPUsername: settings[SettingKeySMTPUsername],
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
package service
|
||||
|
||||
type SystemSettings struct {
|
||||
RegistrationEnabled bool
|
||||
EmailVerifyEnabled bool
|
||||
PromoCodeEnabled bool
|
||||
PasswordResetEnabled bool
|
||||
TotpEnabled bool // TOTP 双因素认证
|
||||
RegistrationEnabled bool
|
||||
EmailVerifyEnabled bool
|
||||
PromoCodeEnabled bool
|
||||
PasswordResetEnabled bool
|
||||
InvitationCodeEnabled bool
|
||||
TotpEnabled bool // TOTP 双因素认证
|
||||
|
||||
SMTPHost string
|
||||
SMTPPort int
|
||||
@@ -61,21 +62,22 @@ type SystemSettings struct {
|
||||
}
|
||||
|
||||
type PublicSettings struct {
|
||||
RegistrationEnabled bool
|
||||
EmailVerifyEnabled bool
|
||||
PromoCodeEnabled bool
|
||||
PasswordResetEnabled bool
|
||||
TotpEnabled bool // TOTP 双因素认证
|
||||
TurnstileEnabled bool
|
||||
TurnstileSiteKey string
|
||||
SiteName string
|
||||
SiteLogo string
|
||||
SiteSubtitle string
|
||||
APIBaseURL string
|
||||
ContactInfo string
|
||||
DocURL string
|
||||
HomeContent string
|
||||
HideCcsImportButton bool
|
||||
RegistrationEnabled bool
|
||||
EmailVerifyEnabled bool
|
||||
PromoCodeEnabled bool
|
||||
PasswordResetEnabled bool
|
||||
InvitationCodeEnabled bool
|
||||
TotpEnabled bool // TOTP 双因素认证
|
||||
TurnstileEnabled bool
|
||||
TurnstileSiteKey string
|
||||
SiteName string
|
||||
SiteLogo string
|
||||
SiteSubtitle string
|
||||
APIBaseURL string
|
||||
ContactInfo string
|
||||
DocURL string
|
||||
HomeContent string
|
||||
HideCcsImportButton bool
|
||||
|
||||
PurchaseSubscriptionEnabled bool
|
||||
PurchaseSubscriptionURL string
|
||||
|
||||
@@ -14,6 +14,9 @@ type UsageLog struct {
|
||||
AccountID int64
|
||||
RequestID string
|
||||
Model string
|
||||
// ReasoningEffort is the request's reasoning effort level (OpenAI Responses API),
|
||||
// e.g. "low" / "medium" / "high" / "xhigh". Nil means not provided / not applicable.
|
||||
ReasoningEffort *string
|
||||
|
||||
GroupID *int64
|
||||
SubscriptionID *int64
|
||||
|
||||
Reference in New Issue
Block a user