refactor: 调整项目结构为单向依赖
This commit is contained in:
@@ -17,7 +17,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
@@ -265,12 +264,12 @@ func (s *GatewayService) replaceModelInBody(body []byte, newModel string) []byte
|
||||
}
|
||||
|
||||
// SelectAccount 选择账号(粘性会话+优先级)
|
||||
func (s *GatewayService) SelectAccount(ctx context.Context, groupID *int64, sessionHash string) (*model.Account, error) {
|
||||
func (s *GatewayService) SelectAccount(ctx context.Context, groupID *int64, sessionHash string) (*Account, error) {
|
||||
return s.SelectAccountForModel(ctx, groupID, sessionHash, "")
|
||||
}
|
||||
|
||||
// SelectAccountForModel 选择支持指定模型的账号(粘性会话+优先级+模型映射)
|
||||
func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*model.Account, error) {
|
||||
func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*Account, error) {
|
||||
// 1. 查询粘性会话
|
||||
if sessionHash != "" {
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash)
|
||||
@@ -289,19 +288,19 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int
|
||||
}
|
||||
|
||||
// 2. 获取可调度账号列表(排除限流和过载的账号,仅限 Anthropic 平台)
|
||||
var accounts []model.Account
|
||||
var accounts []Account
|
||||
var err error
|
||||
if groupID != nil {
|
||||
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, model.PlatformAnthropic)
|
||||
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformAnthropic)
|
||||
} else {
|
||||
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, model.PlatformAnthropic)
|
||||
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformAnthropic)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query accounts failed: %w", err)
|
||||
}
|
||||
|
||||
// 3. 按优先级+最久未用选择(考虑模型支持)
|
||||
var selected *model.Account
|
||||
var selected *Account
|
||||
for i := range accounts {
|
||||
acc := &accounts[i]
|
||||
// 检查模型支持
|
||||
@@ -341,12 +340,12 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int
|
||||
}
|
||||
|
||||
// GetAccessToken 获取账号凭证
|
||||
func (s *GatewayService) GetAccessToken(ctx context.Context, account *model.Account) (string, string, error) {
|
||||
func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (string, string, error) {
|
||||
switch account.Type {
|
||||
case model.AccountTypeOAuth, model.AccountTypeSetupToken:
|
||||
case AccountTypeOAuth, AccountTypeSetupToken:
|
||||
// Both oauth and setup-token use OAuth token flow
|
||||
return s.getOAuthToken(ctx, account)
|
||||
case model.AccountTypeApiKey:
|
||||
case AccountTypeApiKey:
|
||||
apiKey := account.GetCredential("api_key")
|
||||
if apiKey == "" {
|
||||
return "", "", errors.New("api_key not found in credentials")
|
||||
@@ -357,7 +356,7 @@ func (s *GatewayService) GetAccessToken(ctx context.Context, account *model.Acco
|
||||
}
|
||||
}
|
||||
|
||||
func (s *GatewayService) getOAuthToken(ctx context.Context, account *model.Account) (string, string, error) {
|
||||
func (s *GatewayService) getOAuthToken(ctx context.Context, account *Account) (string, string, error) {
|
||||
accessToken := account.GetCredential("access_token")
|
||||
if accessToken == "" {
|
||||
return "", "", errors.New("access_token not found in credentials")
|
||||
@@ -372,10 +371,7 @@ const (
|
||||
retryDelay = 3 * time.Second // 重试等待时间
|
||||
)
|
||||
|
||||
// shouldRetryUpstreamError 判断是否应该重试上游错误
|
||||
// OAuth/Setup Token 账号:仅 403 重试
|
||||
// API Key 账号:未配置的错误码重试
|
||||
func (s *GatewayService) shouldRetryUpstreamError(account *model.Account, statusCode int) bool {
|
||||
func (s *GatewayService) shouldRetryUpstreamError(account *Account, statusCode int) bool {
|
||||
// OAuth/Setup Token 账号:仅 403 重试
|
||||
if account.IsOAuth() {
|
||||
return statusCode == 403
|
||||
@@ -386,7 +382,7 @@ func (s *GatewayService) shouldRetryUpstreamError(account *model.Account, status
|
||||
}
|
||||
|
||||
// Forward 转发请求到Claude API
|
||||
func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *model.Account, body []byte) (*ForwardResult, error) {
|
||||
func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
// 解析请求获取model和stream
|
||||
@@ -412,7 +408,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *m
|
||||
|
||||
// 应用模型映射(仅对apikey类型账号)
|
||||
originalModel := req.Model
|
||||
if account.Type == model.AccountTypeApiKey {
|
||||
if account.Type == AccountTypeApiKey {
|
||||
mappedModel := account.GetMappedModel(req.Model)
|
||||
if mappedModel != req.Model {
|
||||
// 替换请求体中的模型名
|
||||
@@ -504,10 +500,10 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *m
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *model.Account, body []byte, token, tokenType string) (*http.Request, error) {
|
||||
func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType string) (*http.Request, error) {
|
||||
// 确定目标URL
|
||||
targetURL := claudeAPIURL
|
||||
if account.Type == model.AccountTypeApiKey {
|
||||
if account.Type == AccountTypeApiKey {
|
||||
baseURL := account.GetBaseURL()
|
||||
targetURL = baseURL + "/v1/messages"
|
||||
}
|
||||
@@ -631,7 +627,7 @@ func (s *GatewayService) getBetaHeader(body []byte, clientBetaHeader string) str
|
||||
return claude.DefaultBetaHeader
|
||||
}
|
||||
|
||||
func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *model.Account) (*ForwardResult, error) {
|
||||
func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
|
||||
// 处理上游错误,标记账号状态
|
||||
@@ -686,7 +682,7 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res
|
||||
// handleRetryExhaustedError 处理重试耗尽后的错误
|
||||
// OAuth 403:标记账号异常
|
||||
// API Key 未配置错误码:仅返回错误,不标记账号
|
||||
func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *http.Response, c *gin.Context, account *model.Account) (*ForwardResult, error) {
|
||||
func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
statusCode := resp.StatusCode
|
||||
|
||||
@@ -717,7 +713,7 @@ type streamingResult struct {
|
||||
firstTokenMs *int
|
||||
}
|
||||
|
||||
func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *model.Account, startTime time.Time, originalModel, mappedModel string) (*streamingResult, error) {
|
||||
func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string) (*streamingResult, error) {
|
||||
// 更新5h窗口状态
|
||||
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
|
||||
|
||||
@@ -856,7 +852,7 @@ func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *model.Account, originalModel, mappedModel string) (*ClaudeUsage, error) {
|
||||
func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*ClaudeUsage, error) {
|
||||
// 更新5h窗口状态
|
||||
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
|
||||
|
||||
@@ -915,10 +911,10 @@ func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toMo
|
||||
// RecordUsageInput 记录使用量的输入参数
|
||||
type RecordUsageInput struct {
|
||||
Result *ForwardResult
|
||||
ApiKey *model.ApiKey
|
||||
User *model.User
|
||||
Account *model.Account
|
||||
Subscription *model.UserSubscription // 可选:订阅信息
|
||||
ApiKey *ApiKey
|
||||
User *User
|
||||
Account *Account
|
||||
Subscription *UserSubscription // 可选:订阅信息
|
||||
}
|
||||
|
||||
// RecordUsage 记录使用量并扣费(或更新订阅用量)
|
||||
@@ -952,14 +948,14 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
|
||||
// 判断计费方式:订阅模式 vs 余额模式
|
||||
isSubscriptionBilling := subscription != nil && apiKey.Group != nil && apiKey.Group.IsSubscriptionType()
|
||||
billingType := model.BillingTypeBalance
|
||||
billingType := BillingTypeBalance
|
||||
if isSubscriptionBilling {
|
||||
billingType = model.BillingTypeSubscription
|
||||
billingType = BillingTypeSubscription
|
||||
}
|
||||
|
||||
// 创建使用日志
|
||||
durationMs := int(result.Duration.Milliseconds())
|
||||
usageLog := &model.UsageLog{
|
||||
usageLog := &UsageLog{
|
||||
UserID: user.ID,
|
||||
ApiKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
@@ -1038,9 +1034,9 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
|
||||
// ForwardCountTokens 转发 count_tokens 请求到上游 API
|
||||
// 特点:不记录使用量、仅支持非流式响应
|
||||
func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, account *model.Account, body []byte) error {
|
||||
func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, account *Account, body []byte) error {
|
||||
// 应用模型映射(仅对 apikey 类型账号)
|
||||
if account.Type == model.AccountTypeApiKey {
|
||||
if account.Type == AccountTypeApiKey {
|
||||
var req struct {
|
||||
Model string `json:"model"`
|
||||
}
|
||||
@@ -1113,10 +1109,10 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
||||
}
|
||||
|
||||
// buildCountTokensRequest 构建 count_tokens 上游请求
|
||||
func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *model.Account, body []byte, token, tokenType string) (*http.Request, error) {
|
||||
func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType string) (*http.Request, error) {
|
||||
// 确定目标 URL
|
||||
targetURL := claudeAPICountTokensURL
|
||||
if account.Type == model.AccountTypeApiKey {
|
||||
if account.Type == AccountTypeApiKey {
|
||||
baseURL := account.GetBaseURL()
|
||||
targetURL = baseURL + "/v1/messages/count_tokens"
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user