package proxy import ( "encoding/json" "fmt" "io" "kiro-go/auth" "kiro-go/config" "kiro-go/pool" "net/http" "strings" "sync" "sync/atomic" "time" "github.com/google/uuid" ) // Handler HTTP 处理器 type Handler struct { pool *pool.AccountPool // 运行时统计 (使用原子操作) totalRequests int64 successRequests int64 failedRequests int64 totalTokens int64 totalCredits float64 // float64 需要用锁保护 creditsMu sync.RWMutex startTime int64 stopRefresh chan struct{} stopStatsSaver chan struct{} // 模型缓存 cachedModels []ModelInfo modelsCacheMu sync.RWMutex modelsCacheTime int64 promptCache *promptCacheTracker } type thinkingStreamSource int const ( thinkingSourceUnknown thinkingStreamSource = iota thinkingSourceReasoningEvent thinkingSourceTagBlock ) func allowReasoningSource(source *thinkingStreamSource) bool { if *source == thinkingSourceTagBlock { return false } *source = thinkingSourceReasoningEvent return true } func allowTagSource(source *thinkingStreamSource) bool { if *source == thinkingSourceReasoningEvent { return false } if *source == thinkingSourceUnknown { *source = thinkingSourceTagBlock } return *source == thinkingSourceTagBlock } func validateClaudeRequestShape(req *ClaudeRequest) string { if len(req.Messages) == 0 { return "messages must not be empty" } if msg := validateClaudeThinkingConfig(req.Thinking, req.MaxTokens); msg != "" { return msg } hasUserContext := false lastRole := "" for _, msg := range req.Messages { role := strings.TrimSpace(msg.Role) if role == "" { continue } lastRole = role if role != "user" { continue } text, images, toolResults := extractClaudeUserContent(msg.Content) if normalizeUserContent(text, len(images) > 0) != "" || len(toolResults) > 0 { hasUserContext = true } } if lastRole == "assistant" { return "assistant-prefill final message is not supported; last message must be user" } if !hasUserContext { return "at least one non-empty user message is required" } return "" } func validateClaudeThinkingConfig(thinking *ClaudeThinkingConfig, maxTokens int) 
string { if thinking == nil { return "" } kind := strings.ToLower(strings.TrimSpace(thinking.Type)) switch kind { case "enabled": if maxTokens == 0 { return "thinking.type enabled cannot be used with max_tokens=0" } if thinking.BudgetTokens <= 0 { return "thinking.budget_tokens is required when thinking.type is enabled" } if thinking.BudgetTokens < 1024 { return "thinking.budget_tokens must be at least 1024" } if maxTokens > 0 && thinking.BudgetTokens >= maxTokens { return "thinking.budget_tokens must be less than max_tokens" } case "adaptive": if thinking.BudgetTokens != 0 { return "thinking.budget_tokens is not supported when thinking.type is adaptive" } case "disabled": if thinking.BudgetTokens != 0 { return "thinking.budget_tokens is not supported when thinking.type is disabled" } default: return "thinking.type must be one of: enabled, adaptive, disabled" } display := strings.ToLower(strings.TrimSpace(thinking.Display)) if display != "" && display != "summarized" && display != "omitted" { return "thinking.display must be one of: summarized, omitted" } if kind == "disabled" && display != "" { return "thinking.display is not supported when thinking.type is disabled" } return "" } type claudeThinkingResponseOptions struct { Format string OmitDisplay bool } func resolveClaudeThinkingResponseOptions(thinking *ClaudeThinkingConfig, defaultFormat string) claudeThinkingResponseOptions { opts := claudeThinkingResponseOptions{Format: defaultFormat} if opts.Format == "" { opts.Format = "thinking" } if thinking == nil { return opts } display := strings.ToLower(strings.TrimSpace(thinking.Display)) switch display { case "summarized": opts.Format = "thinking" case "omitted": opts.Format = "thinking" opts.OmitDisplay = true } return opts } func validateOpenAIRequestShape(req *OpenAIRequest) string { if len(req.Messages) == 0 { return "messages must not be empty" } hasNonSystem := false hasUserContext := false lastRole := "" for _, msg := range req.Messages { role := 
strings.TrimSpace(msg.Role) if role == "" { continue } if role != "system" { hasNonSystem = true lastRole = role } if role != "user" { continue } text, images := extractOpenAIUserContent(msg.Content) if normalizeUserContent(text, len(images) > 0) != "" { hasUserContext = true } } if !hasNonSystem { return "at least one non-system message is required" } if lastRole == "assistant" { return "assistant-prefill final message is not supported; last message must be user or tool" } if !hasUserContext { return "at least one non-empty user message is required" } return "" } func NewHandler() *Handler { // 启动时应用代理配置 applyProxyConfig(config.GetProxyURL()) totalReq, successReq, failedReq, totalTokens, totalCredits := config.GetStats() h := &Handler{ pool: pool.GetPool(), totalRequests: int64(totalReq), successRequests: int64(successReq), failedRequests: int64(failedReq), totalTokens: int64(totalTokens), totalCredits: totalCredits, startTime: time.Now().Unix(), stopRefresh: make(chan struct{}), stopStatsSaver: make(chan struct{}), promptCache: newPromptCacheTracker(defaultPromptCacheTTL), } // 启动后台刷新 go h.backgroundRefresh() // 启动后台统计保存 (每30秒保存一次) go h.backgroundStatsSaver() return h } // backgroundRefresh 后台定时刷新账户信息 func (h *Handler) backgroundRefresh() { ticker := time.NewTicker(30 * time.Minute) // 每 30 分钟刷新一次 defer ticker.Stop() // 启动时延迟 10 秒后执行一次 time.Sleep(10 * time.Second) h.refreshModelsCache() h.refreshAllAccounts() for { select { case <-ticker.C: h.refreshModelsCache() h.refreshAllAccounts() case <-h.stopRefresh: return } } } // refreshAllAccounts 刷新所有账户信息 func (h *Handler) refreshAllAccounts() { accounts := config.GetAccounts() for i := range accounts { account := &accounts[i] if !account.Enabled || account.AccessToken == "" { continue } // 检查 token 是否需要刷新 if account.ExpiresAt > 0 && time.Now().Unix() > account.ExpiresAt-300 { newAccessToken, newRefreshToken, newExpiresAt, err := auth.RefreshToken(account) if err != nil { fmt.Printf("[BackgroundRefresh] Token refresh 
failed for %s: %v\n", account.Email, err) continue } account.AccessToken = newAccessToken if newRefreshToken != "" { account.RefreshToken = newRefreshToken } account.ExpiresAt = newExpiresAt config.UpdateAccountToken(account.ID, newAccessToken, newRefreshToken, newExpiresAt) h.pool.UpdateToken(account.ID, newAccessToken, newRefreshToken, newExpiresAt) } // 刷新账户信息 info, err := RefreshAccountInfo(account) if err != nil { fmt.Printf("[BackgroundRefresh] Failed to refresh %s: %v\n", account.Email, err) continue } config.UpdateAccountInfo(account.ID, *info) fmt.Printf("[BackgroundRefresh] Refreshed %s: %s %.1f/%.1f\n", account.Email, info.SubscriptionType, info.UsageCurrent, info.UsageLimit) } h.pool.Reload() } // validateApiKey 验证 API Key func (h *Handler) validateApiKey(r *http.Request) bool { if !config.IsApiKeyRequired() { return true } expectedKey := config.GetApiKey() if expectedKey == "" { return true } // 从 Authorization 头或 X-Api-Key 头获取 authHeader := r.Header.Get("Authorization") apiKeyHeader := r.Header.Get("X-Api-Key") var providedKey string if strings.HasPrefix(authHeader, "Bearer ") { providedKey = strings.TrimPrefix(authHeader, "Bearer ") } else if apiKeyHeader != "" { providedKey = apiKeyHeader } return providedKey == expectedKey } // ServeHTTP 路由分发 func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { path := r.URL.Path // CORS - 完整的头部支持 w.Header().Set("Access-Control-Allow-Origin", "*") w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-Api-Key, anthropic-version, anthropic-beta, x-api-key, x-stainless-os, x-stainless-lang, x-stainless-package-version, x-stainless-runtime, x-stainless-runtime-version, x-stainless-arch") w.Header().Set("Access-Control-Expose-Headers", "x-request-id, x-ratelimit-limit-requests, x-ratelimit-limit-tokens, x-ratelimit-remaining-requests, x-ratelimit-remaining-tokens, x-ratelimit-reset-requests, 
x-ratelimit-reset-tokens") if r.Method == "OPTIONS" { w.WriteHeader(204) return } // 路由 switch { // API 端点(需要验证 API Key) case path == "/v1/messages" || path == "/messages" || path == "/anthropic/v1/messages": if !h.validateApiKey(r) { h.sendClaudeError(w, 401, "authentication_error", "Invalid or missing API key") return } h.handleClaudeMessages(w, r) case path == "/v1/messages/count_tokens" || path == "/messages/count_tokens": if !h.validateApiKey(r) { h.sendClaudeError(w, 401, "authentication_error", "Invalid or missing API key") return } h.handleCountTokens(w, r) case path == "/v1/chat/completions" || path == "/chat/completions": if !h.validateApiKey(r) { h.sendOpenAIError(w, 401, "authentication_error", "Invalid or missing API key") return } h.handleOpenAIChat(w, r) case path == "/v1/models" || path == "/models": h.handleModels(w, r) case path == "/api/event_logging/batch": // Claude Code 遥测端点 - 直接返回 200 OK w.Header().Set("Content-Type", "application/json; charset=utf-8") w.Write([]byte(`{"status":"ok"}`)) // 管理端点 case path == "/admin" || path == "/admin/": h.serveAdminPage(w, r) case strings.HasPrefix(path, "/admin/api/"): h.handleAdminAPI(w, r) case strings.HasPrefix(path, "/admin/"): h.serveStaticFile(w, r) // 健康检查 case path == "/health" || path == "/": h.handleHealth(w, r) // 统计端点(需要 API Key 鉴权) case path == "/v1/stats": if !h.validateApiKey(r) { w.Header().Set("Content-Type", "application/json; charset=utf-8") w.WriteHeader(401) json.NewEncoder(w).Encode(map[string]string{"error": "Invalid or missing API key"}) return } h.handleStats(w, r) default: http.Error(w, "Not Found", 404) } } // handleHealth 健康检查(不暴露统计数据) func (h *Handler) handleHealth(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json; charset=utf-8") json.NewEncoder(w).Encode(map[string]interface{}{ "status": "ok", "version": config.Version, "uptime": time.Now().Unix() - h.startTime, }) } // handleStats 统计数据(需要 API Key 鉴权) func (h *Handler) handleStats(w 
http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json; charset=utf-8") json.NewEncoder(w).Encode(map[string]interface{}{ "status": "ok", "version": config.Version, "accounts": h.pool.Count(), "available": h.pool.AvailableCount(), "totalRequests": atomic.LoadInt64(&h.totalRequests), "successRequests": atomic.LoadInt64(&h.successRequests), "failedRequests": atomic.LoadInt64(&h.failedRequests), "totalTokens": atomic.LoadInt64(&h.totalTokens), "totalCredits": h.getCredits(), "uptime": time.Now().Unix() - h.startTime, }) } // handleModels 模型列表 func (h *Handler) handleModels(w http.ResponseWriter, r *http.Request) { // 尝试用缓存的真实模型列表 h.modelsCacheMu.RLock() cached := h.cachedModels h.modelsCacheMu.RUnlock() if len(cached) == 0 { h.refreshModelsCache() h.modelsCacheMu.RLock() cached = h.cachedModels h.modelsCacheMu.RUnlock() } thinkingSuffix := config.GetThinkingConfig().Suffix models := buildAnthropicModelsResponse(cached, thinkingSuffix) if len(models) == 0 { models = fallbackAnthropicModels(thinkingSuffix) } // 添加别名模型 models = append(models, buildModelInfo("auto", "kiro-proxy", true), buildModelInfo("gpt-4o", "kiro-proxy", true), buildModelInfo("gpt-4", "kiro-proxy", true), ) w.Header().Set("Content-Type", "application/json; charset=utf-8") json.NewEncoder(w).Encode(map[string]interface{}{ "object": "list", "data": models, }) return } func buildAnthropicModelsResponse(cached []ModelInfo, thinkingSuffix string) []map[string]interface{} { if len(cached) == 0 { return nil } models := make([]map[string]interface{}, 0, len(cached)*2) if len(cached) > 0 { for _, m := range cached { supportsImage := modelSupportsImage(m.InputTypes) models = append(models, buildModelInfo(m.ModelId, "anthropic", supportsImage)) // 自动生成 thinking 变体 models = append(models, buildModelInfo(m.ModelId+thinkingSuffix, "anthropic", supportsImage)) } } return models } func fallbackAnthropicModels(thinkingSuffix string) []map[string]interface{} { return 
[]map[string]interface{}{ buildModelInfo("claude-sonnet-4.6", "anthropic", true), buildModelInfo("claude-sonnet-4.6"+thinkingSuffix, "anthropic", true), buildModelInfo("claude-opus-4.6", "anthropic", true), buildModelInfo("claude-opus-4.6"+thinkingSuffix, "anthropic", true), buildModelInfo("claude-opus-4.7", "anthropic", true), buildModelInfo("claude-opus-4.7"+thinkingSuffix, "anthropic", true), buildModelInfo("claude-sonnet-4.5", "anthropic", true), buildModelInfo("claude-sonnet-4.5"+thinkingSuffix, "anthropic", true), buildModelInfo("claude-sonnet-4", "anthropic", true), buildModelInfo("claude-sonnet-4"+thinkingSuffix, "anthropic", true), buildModelInfo("claude-haiku-4.5", "anthropic", true), buildModelInfo("claude-haiku-4.5"+thinkingSuffix, "anthropic", true), buildModelInfo("claude-opus-4.5", "anthropic", true), buildModelInfo("claude-opus-4.5"+thinkingSuffix, "anthropic", true), } } func modelSupportsImage(inputTypes []string) bool { for _, t := range inputTypes { lt := strings.ToLower(t) if strings.Contains(lt, "image") || strings.Contains(lt, "vision") { return true } } return false } func buildModelInfo(id, ownedBy string, supportsImage bool) map[string]interface{} { modalities := []string{"text"} if supportsImage { modalities = append(modalities, "image") } modalitiesMap := map[string][]string{ "input": modalities, "output": []string{"text"}, } return map[string]interface{}{ "id": id, "object": "model", "owned_by": ownedBy, "supports_image": supportsImage, "input_modalities": modalities, "modalities": modalitiesMap, "capabilities": map[string]bool{ "vision": supportsImage, "image": supportsImage, "image_vision": supportsImage, }, "info": map[string]interface{}{ "meta": map[string]interface{}{ "capabilities": map[string]bool{ "vision": supportsImage, "image_vision": supportsImage, }, }, }, } } // refreshModelsCache 从 Kiro API 拉取模型列表并缓存 func (h *Handler) refreshModelsCache() { accounts := config.GetEnabledAccounts() if len(accounts) == 0 { return } 
aggregated := make([]ModelInfo, 0) for i := range accounts { account := &accounts[i] if err := h.ensureValidToken(account); err != nil { fmt.Printf("[ModelsCache] Skip %s token refresh failed: %v\n", account.Email, err) continue } models, err := ListAvailableModels(account) if err != nil { fmt.Printf("[ModelsCache] Failed to refresh for %s: %v\n", account.Email, err) continue } aggregated = mergeUniqueModels(aggregated, models) } if len(aggregated) > 0 { h.modelsCacheMu.Lock() h.cachedModels = aggregated h.modelsCacheTime = time.Now().Unix() h.modelsCacheMu.Unlock() fmt.Printf("[ModelsCache] Cached %d models\n", len(aggregated)) } } func mergeUniqueModels(existing []ModelInfo, incoming []ModelInfo) []ModelInfo { if len(incoming) == 0 { return existing } indexByID := make(map[string]int, len(existing)) merged := make([]ModelInfo, len(existing)) copy(merged, existing) for i, model := range merged { indexByID[strings.ToLower(strings.TrimSpace(model.ModelId))] = i } for _, model := range incoming { key := strings.ToLower(strings.TrimSpace(model.ModelId)) if key == "" { continue } if idx, ok := indexByID[key]; ok { merged[idx] = mergeModelInfo(merged[idx], model) continue } indexByID[key] = len(merged) merged = append(merged, model) } return merged } func mergeModelInfo(base ModelInfo, extra ModelInfo) ModelInfo { if base.ModelName == "" { base.ModelName = extra.ModelName } if base.Description == "" { base.Description = extra.Description } if base.RateMultiplier == 0 { base.RateMultiplier = extra.RateMultiplier } if base.TokenLimits == nil { base.TokenLimits = extra.TokenLimits } base.InputTypes = mergeStringLists(base.InputTypes, extra.InputTypes) return base } func mergeStringLists(base []string, extra []string) []string { if len(extra) == 0 { return base } seen := make(map[string]bool, len(base)+len(extra)) merged := make([]string, 0, len(base)+len(extra)) for _, item := range base { key := strings.ToLower(strings.TrimSpace(item)) if key == "" || seen[key] { continue 
} seen[key] = true merged = append(merged, item) } for _, item := range extra { key := strings.ToLower(strings.TrimSpace(item)) if key == "" || seen[key] { continue } seen[key] = true merged = append(merged, item) } return merged } // handleCountTokens Token 计数(Claude Code 会调用) func (h *Handler) handleCountTokens(w http.ResponseWriter, r *http.Request) { if r.Method != "POST" { http.Error(w, "Method Not Allowed", 405) return } body, err := io.ReadAll(r.Body) if err != nil { h.sendClaudeError(w, 400, "invalid_request_error", "Failed to read request body") return } var req ClaudeRequest if err := json.Unmarshal(body, &req); err != nil { h.sendClaudeError(w, 400, "invalid_request_error", "Invalid JSON") return } if msg := validateClaudeThinkingConfig(req.Thinking, req.MaxTokens); msg != "" { h.sendClaudeError(w, 400, "invalid_request_error", msg) return } thinkingCfg := config.GetThinkingConfig() actualModel, thinking := resolveClaudeThinkingMode(req.Model, req.Thinking, thinkingCfg.Suffix) req.Model = actualModel effectiveReq := cloneClaudeRequestForThinking(&req, thinking) estimatedTokens := estimateClaudeRequestInputTokens(effectiveReq) if estimatedTokens < 1 { estimatedTokens = 1 } w.Header().Set("Content-Type", "application/json; charset=utf-8") json.NewEncoder(w).Encode(map[string]int{"input_tokens": estimatedTokens}) } // handleClaudeMessages Claude API 处理 func (h *Handler) handleClaudeMessages(w http.ResponseWriter, r *http.Request) { h.handleClaudeMessagesInternal(w, r) } func (h *Handler) handleClaudeMessagesInternal(w http.ResponseWriter, r *http.Request) { if r.Method != "POST" { http.Error(w, "Method Not Allowed", 405) return } // 读取请求 body, err := io.ReadAll(r.Body) if err != nil { h.sendClaudeError(w, 400, "invalid_request_error", "Failed to read request body") return } var req ClaudeRequest if err := json.Unmarshal(body, &req); err != nil { h.sendClaudeError(w, 400, "invalid_request_error", "Invalid JSON: "+err.Error()) return } if msg := 
validateClaudeRequestShape(&req); msg != "" { h.sendClaudeError(w, 400, "invalid_request_error", msg) return } // 获取账号 account := h.pool.GetNext() if account == nil { h.sendClaudeError(w, 503, "api_error", "No available accounts") return } // 检查并刷新 token if err := h.ensureValidToken(account); err != nil { h.sendClaudeError(w, 503, "api_error", "Token refresh failed: "+err.Error()) return } // 解析模型和 thinking 模式 thinkingCfg := config.GetThinkingConfig() actualModel, thinking := resolveClaudeThinkingMode(req.Model, req.Thinking, thinkingCfg.Suffix) req.Model = actualModel effectiveReq := cloneClaudeRequestForThinking(&req, thinking) thinkingResponseOpts := resolveClaudeThinkingResponseOptions(req.Thinking, thinkingCfg.ClaudeFormat) estimatedInputTokens := estimateClaudeRequestInputTokens(effectiveReq) cacheProfile := h.promptCache.BuildClaudeProfile(effectiveReq, estimatedInputTokens) cacheUsage := h.promptCache.Compute(account.ID, cacheProfile) // 转换请求 kiroPayload := ClaudeToKiro(&req, thinking) // Stream or non-stream if req.Stream { h.handleClaudeStream(w, account, kiroPayload, req.Model, thinking, thinkingResponseOpts, estimatedInputTokens, cacheUsage, cacheProfile) } else { h.handleClaudeNonStream(w, account, kiroPayload, req.Model, thinking, thinkingResponseOpts, estimatedInputTokens, cacheUsage, cacheProfile) } } // handleClaudeStream Claude 流式响应 func (h *Handler) handleClaudeStream(w http.ResponseWriter, account *config.Account, payload *KiroPayload, model string, thinking bool, thinkingOpts claudeThinkingResponseOptions, estimatedInputTokens int, cacheUsage promptCacheUsage, cacheProfile *promptCacheProfile) { w.Header().Set("Content-Type", "text/event-stream; charset=utf-8") w.Header().Set("Cache-Control", "no-cache") w.Header().Set("Connection", "keep-alive") flusher, ok := w.(http.Flusher) if !ok { h.sendClaudeError(w, 500, "api_error", "Streaming not supported") return } // 获取 thinking 输出格式配置 thinkingFormat := thinkingOpts.Format msgID := "msg_" + 
uuid.New().String() var inputTokens, outputTokens int var credits float64 var realInputTokens int var toolUses []KiroToolUse var nextContentIndex int var rawContentBuilder strings.Builder var rawThinkingBuilder strings.Builder activeBlockIndex := -1 activeBlockType := "" startInputTokens := estimatedInputTokens closeActiveBlock := func() { if activeBlockIndex < 0 { return } h.sendSSE(w, flusher, "content_block_stop", map[string]interface{}{ "type": "content_block_stop", "index": activeBlockIndex, }) activeBlockIndex = -1 activeBlockType = "" } startContentBlock := func(blockType string) { if activeBlockType == blockType { return } closeActiveBlock() idx := nextContentIndex nextContentIndex++ if blockType == "thinking" { h.sendSSE(w, flusher, "content_block_start", map[string]interface{}{ "type": "content_block_start", "index": idx, "content_block": map[string]string{ "type": "thinking", "thinking": "", }, }) } else { h.sendSSE(w, flusher, "content_block_start", map[string]interface{}{ "type": "content_block_start", "index": idx, "content_block": map[string]string{ "type": "text", "text": "", }, }) } activeBlockIndex = idx activeBlockType = blockType } // Thinking 标签解析状态 var textBuffer string var inThinkingBlock bool var dropTagThinking bool var thinkingSource thinkingStreamSource // 发送文本的辅助函数 // thinkingState: 0=普通内容, 1=thinking开始, 2=thinking中间, 3=thinking结束 sendText := func(text string, thinkingState int) { if thinkingState == 0 { // 普通内容 if text == "" { return } startContentBlock("text") h.sendSSE(w, flusher, "content_block_delta", map[string]interface{}{ "type": "content_block_delta", "index": activeBlockIndex, "delta": map[string]string{"type": "text_delta", "text": text}, }) return } if !thinking { return } switch thinkingFormat { case "think": var outputText string switch thinkingState { case 1: outputText = "" + text case 2: outputText = text case 3: outputText = text + "" } if outputText == "" { return } startContentBlock("text") h.sendSSE(w, flusher, 
"content_block_delta", map[string]interface{}{ "type": "content_block_delta", "index": activeBlockIndex, "delta": map[string]string{"type": "text_delta", "text": outputText}, }) case "reasoning_content": if text == "" { return } startContentBlock("text") h.sendSSE(w, flusher, "content_block_delta", map[string]interface{}{ "type": "content_block_delta", "index": activeBlockIndex, "delta": map[string]string{"type": "text_delta", "text": text}, }) default: if thinkingOpts.OmitDisplay { if thinkingState == 1 { startContentBlock("thinking") return } if thinkingState == 3 { if activeBlockType != "thinking" { startContentBlock("thinking") } closeActiveBlock() } return } if thinkingState == 3 && text == "" { if activeBlockType == "thinking" { closeActiveBlock() } return } if text != "" { startContentBlock("thinking") h.sendSSE(w, flusher, "content_block_delta", map[string]interface{}{ "type": "content_block_delta", "index": activeBlockIndex, "delta": map[string]string{"type": "thinking_delta", "thinking": text}, }) } if thinkingState == 3 && activeBlockType == "thinking" { closeActiveBlock() } } } // 处理文本,解析 标签 var thinkingStarted bool var eventThinkingOpen bool processClaudeText := func(text string, isThinking bool, forceFlush bool) { if isThinking && !thinking { return } // 如果是 reasoningContentEvent,直接输出 if isThinking { if !allowReasoningSource(&thinkingSource) { return } if !thinkingStarted { sendText(text, 1) thinkingStarted = true eventThinkingOpen = true } else { sendText(text, 2) } return } if eventThinkingOpen { sendText("", 3) eventThinkingOpen = false thinkingStarted = false } textBuffer += text for { if !inThinkingBlock { thinkingStart := strings.Index(textBuffer, "") if thinkingStart != -1 { if thinkingStart > 0 { sendText(textBuffer[:thinkingStart], 0) } textBuffer = textBuffer[thinkingStart+10:] inThinkingBlock = true dropTagThinking = !allowTagSource(&thinkingSource) thinkingStarted = false } else if forceFlush || len([]rune(textBuffer)) > 50 { // 使用 rune 
切片来正确处理 Unicode 字符 runes := []rune(textBuffer) safeLen := len(runes) if !forceFlush { safeLen = max(0, len(runes)-15) } if safeLen > 0 { sendText(string(runes[:safeLen]), 0) textBuffer = string(runes[safeLen:]) } break } else { break } } else { thinkingEnd := strings.Index(textBuffer, "") if thinkingEnd != -1 { content := textBuffer[:thinkingEnd] if !dropTagThinking { if !thinkingStarted { sendText(content, 1) sendText("", 3) } else { sendText(content, 3) } } textBuffer = textBuffer[thinkingEnd+11:] inThinkingBlock = false dropTagThinking = false thinkingStarted = false } else if forceFlush { if textBuffer != "" { if !dropTagThinking { if !thinkingStarted { sendText(textBuffer, 1) sendText("", 3) } else { sendText(textBuffer, 3) } } textBuffer = "" } inThinkingBlock = false dropTagThinking = false thinkingStarted = false break } else { // 流式输出 thinking 块内的内容 runes := []rune(textBuffer) if len(runes) > 20 { safeLen := len(runes) - 15 if safeLen > 0 { if !dropTagThinking { if !thinkingStarted { sendText(string(runes[:safeLen]), 1) thinkingStarted = true } else { sendText(string(runes[:safeLen]), 2) } } textBuffer = string(runes[safeLen:]) } } break } } } } // 发送 message_start h.sendSSE(w, flusher, "message_start", map[string]interface{}{ "type": "message_start", "message": map[string]interface{}{ "id": msgID, "type": "message", "role": "assistant", "content": []interface{}{}, "model": model, "stop_reason": nil, "stop_sequence": nil, "usage": buildClaudeUsageMap(startInputTokens, 0, cacheUsage, cacheProfile != nil), }, }) callback := &KiroStreamCallback{ OnText: func(text string, isThinking bool) { if text == "" { return } if isThinking { rawThinkingBuilder.WriteString(text) } else { rawContentBuilder.WriteString(text) } processClaudeText(text, isThinking, false) }, OnToolUse: func(tu KiroToolUse) { // 先刷新缓冲区 processClaudeText("", false, true) rawContentBuilder.WriteString(tu.Name) if b, err := json.Marshal(tu.Input); err == nil { rawContentBuilder.Write(b) } toolUses 
= append(toolUses, tu) closeActiveBlock() idx := nextContentIndex nextContentIndex++ h.sendSSE(w, flusher, "content_block_start", map[string]interface{}{ "type": "content_block_start", "index": idx, "content_block": map[string]interface{}{ "type": "tool_use", "id": tu.ToolUseID, "name": tu.Name, "input": map[string]interface{}{}, }, }) inputJSON, _ := json.Marshal(tu.Input) h.sendSSE(w, flusher, "content_block_delta", map[string]interface{}{ "type": "content_block_delta", "index": idx, "delta": map[string]interface{}{ "type": "input_json_delta", "partial_json": string(inputJSON), }, }) h.sendSSE(w, flusher, "content_block_stop", map[string]interface{}{ "type": "content_block_stop", "index": idx, }) }, OnComplete: func(inTok, outTok int) { inputTokens = inTok outputTokens = outTok }, OnError: func(err error) { h.pool.RecordError(account.ID, strings.Contains(err.Error(), "429") || strings.Contains(err.Error(), "quota")) }, OnCredits: func(c float64) { credits = c }, OnContextUsage: func(pct float64) { realInputTokens = int(pct * float64(getContextWindowSize(model)) / 100.0) }, } err := CallKiroAPI(account, payload, callback) if err != nil { h.recordFailure() h.pool.RecordError(account.ID, strings.Contains(err.Error(), "429") || strings.Contains(err.Error(), "quota")) h.sendSSE(w, flusher, "error", map[string]interface{}{ "type": "error", "error": map[string]string{"type": "api_error", "message": err.Error()}, }) return } // 刷新剩余缓冲区 processClaudeText("", false, true) if eventThinkingOpen { sendText("", 3) eventThinkingOpen = false } closeActiveBlock() if realInputTokens > 0 { inputTokens = realInputTokens } else if inputTokens <= 0 { inputTokens = estimatedInputTokens } outputContent, extractedReasoning := extractThinkingFromContent(rawContentBuilder.String()) thinkingOutput := rawThinkingBuilder.String() if thinking && thinkingOutput == "" && extractedReasoning != "" { thinkingOutput = extractedReasoning } if !thinking { thinkingOutput = "" } outputTokens = 
estimateClaudeOutputTokens(outputContent, thinkingOutput, toolUses) h.recordSuccess(inputTokens, outputTokens, credits) h.pool.RecordSuccess(account.ID) h.pool.UpdateStats(account.ID, inputTokens+outputTokens, credits) h.promptCache.Update(account.ID, cacheProfile) // 发送 message_delta stopReason := "end_turn" if len(toolUses) > 0 { stopReason = "tool_use" } h.sendSSE(w, flusher, "message_delta", map[string]interface{}{ "type": "message_delta", "delta": map[string]interface{}{ "stop_reason": stopReason, }, "usage": buildClaudeUsageMap(inputTokens, outputTokens, cacheUsage, cacheProfile != nil), }) h.sendSSE(w, flusher, "message_stop", map[string]interface{}{ "type": "message_stop", }) } func (h *Handler) sendSSE(w http.ResponseWriter, flusher http.Flusher, event string, data interface{}) { jsonData, _ := json.Marshal(data) fmt.Fprintf(w, "event: %s\ndata: %s\n\n", event, string(jsonData)) flusher.Flush() } // backgroundStatsSaver 后台定时保存统计数据 func (h *Handler) backgroundStatsSaver() { ticker := time.NewTicker(30 * time.Second) defer ticker.Stop() for { select { case <-ticker.C: h.saveStats() case <-h.stopStatsSaver: h.saveStats() // 退出前保存一次 return } } } // saveStats 保存统计到配置文件 func (h *Handler) saveStats() { config.UpdateStats( int(atomic.LoadInt64(&h.totalRequests)), int(atomic.LoadInt64(&h.successRequests)), int(atomic.LoadInt64(&h.failedRequests)), int(atomic.LoadInt64(&h.totalTokens)), h.getCredits(), ) } // getCredits 线程安全获取 credits func (h *Handler) getCredits() float64 { h.creditsMu.RLock() defer h.creditsMu.RUnlock() return h.totalCredits } // addCredits 线程安全增加 credits func (h *Handler) addCredits(credits float64) { h.creditsMu.Lock() h.totalCredits += credits h.creditsMu.Unlock() } // 统计记录 (使用原子操作) func (h *Handler) recordSuccess(inputTokens, outputTokens int, credits float64) { atomic.AddInt64(&h.totalRequests, 1) atomic.AddInt64(&h.successRequests, 1) atomic.AddInt64(&h.totalTokens, int64(inputTokens+outputTokens)) h.addCredits(credits) } func (h *Handler) 
recordFailure() { atomic.AddInt64(&h.totalRequests, 1) atomic.AddInt64(&h.failedRequests, 1) } // handleClaudeNonStream Claude 非流式响应 func (h *Handler) handleClaudeNonStream(w http.ResponseWriter, account *config.Account, payload *KiroPayload, model string, thinking bool, thinkingOpts claudeThinkingResponseOptions, estimatedInputTokens int, cacheUsage promptCacheUsage, cacheProfile *promptCacheProfile) { var content string var thinkingContent string var toolUses []KiroToolUse var inputTokens, outputTokens int var credits float64 var realInputTokens int callback := &KiroStreamCallback{ OnText: func(text string, isThinking bool) { if isThinking { thinkingContent += text } else { content += text } }, OnToolUse: func(tu KiroToolUse) { toolUses = append(toolUses, tu) }, OnComplete: func(inTok, outTok int) { inputTokens = inTok outputTokens = outTok }, OnError: func(err error) { h.pool.RecordError(account.ID, strings.Contains(err.Error(), "429")) }, OnCredits: func(c float64) { credits = c }, OnContextUsage: func(pct float64) { realInputTokens = int(pct * float64(getContextWindowSize(model)) / 100.0) }, } err := CallKiroAPI(account, payload, callback) if err != nil { h.recordFailure() h.pool.RecordError(account.ID, strings.Contains(err.Error(), "429")) h.sendClaudeError(w, 500, "api_error", err.Error()) return } // 合并 thinking 内容(如果有 reasoningContentEvent 的内容) thinkingFormat := thinkingOpts.Format finalContent, extractedReasoning := extractThinkingFromContent(content) rawThinkingContent := thinkingContent if thinking && rawThinkingContent == "" && extractedReasoning != "" { rawThinkingContent = extractedReasoning } if !thinking { rawThinkingContent = "" } if realInputTokens > 0 { inputTokens = realInputTokens } else if inputTokens <= 0 { inputTokens = estimatedInputTokens } outputTokens = estimateClaudeOutputTokens(finalContent, rawThinkingContent, toolUses) h.recordSuccess(inputTokens, outputTokens, credits) h.pool.RecordSuccess(account.ID) h.pool.UpdateStats(account.ID, 
inputTokens+outputTokens, credits)
	h.promptCache.Update(account.ID, cacheProfile)
	responseThinkingContent := rawThinkingContent
	includeEmptyThinkingBlock := thinking && thinkingOpts.OmitDisplay && rawThinkingContent != ""
	if includeEmptyThinkingBlock {
		responseThinkingContent = ""
	}
	if thinking && responseThinkingContent != "" {
		switch thinkingFormat {
		case "think":
			finalContent = "" + responseThinkingContent + "" + finalContent
			responseThinkingContent = ""
		case "reasoning_content":
			finalContent = responseThinkingContent + finalContent // Claude 格式不支持 reasoning_content,直接拼接
			responseThinkingContent = ""
		default:
		}
	}
	resp := KiroToClaudeResponse(finalContent, responseThinkingContent, includeEmptyThinkingBlock, toolUses, inputTokens, outputTokens, model)
	resp.Usage.InputTokens = billedClaudeInputTokens(inputTokens, cacheUsage)
	resp.Usage.CacheCreationInputTokens = cacheUsage.CacheCreationInputTokens
	resp.Usage.CacheReadInputTokens = cacheUsage.CacheReadInputTokens
	if cacheProfile != nil {
		resp.Usage.CacheCreation = &ClaudeCacheCreationUsage{
			Ephemeral5mInputTokens: cacheUsage.CacheCreation5mInputTokens,
			Ephemeral1hInputTokens: cacheUsage.CacheCreation1hInputTokens,
		}
	}
	w.Header().Set("Content-Type", "application/json; charset=utf-8")
	json.NewEncoder(w).Encode(resp)
}

// sendClaudeError writes an Anthropic-style error body
// ({"type":"error","error":{"type":…,"message":…}}) with the given HTTP status.
func (h *Handler) sendClaudeError(w http.ResponseWriter, status int, errType, message string) {
	w.Header().Set("Content-Type", "application/json; charset=utf-8")
	w.WriteHeader(status)
	json.NewEncoder(w).Encode(map[string]interface{}{
		"type": "error",
		"error": map[string]string{
			"type":    errType,
			"message": message,
		},
	})
}

// handleOpenAIChat serves the OpenAI-compatible chat completions endpoint.
//
// It accepts only POST, validates the request shape, takes the next account
// from the pool (refreshing its token when needed), resolves the model name
// and thinking mode from the configured model suffix, converts the request
// into a Kiro payload and delegates to the streaming or non-streaming handler.
func (h *Handler) handleOpenAIChat(w http.ResponseWriter, r *http.Request) {
	if r.Method != "POST" {
		http.Error(w, "Method Not Allowed", 405)
		return
	}
	body, err := io.ReadAll(r.Body)
	if err != nil {
		h.sendOpenAIError(w, 400, "invalid_request_error", "Failed to read request body")
		return
	}
	var req OpenAIRequest
	if err := json.Unmarshal(body, &req); err != nil {
		h.sendOpenAIError(w, 400, "invalid_request_error", "Invalid JSON")
		return
	}
	if msg := validateOpenAIRequestShape(&req); msg != "" {
		h.sendOpenAIError(w, 400, "invalid_request_error", msg)
		return
	}
	account := h.pool.GetNext()
	if account == nil {
		h.sendOpenAIError(w, 503, "server_error", "No available accounts")
		return
	}
	if err := h.ensureValidToken(account); err != nil {
		h.sendOpenAIError(w, 503, "server_error", "Token refresh failed")
		return
	}
	// Split the requested model into the real model and the thinking flag.
	thinkingCfg := config.GetThinkingConfig()
	actualModel, thinking := ParseModelAndThinking(req.Model, thinkingCfg.Suffix)
	req.Model = actualModel
	estimatedInputTokens := estimateOpenAIRequestInputTokens(&req)
	kiroPayload := OpenAIToKiro(&req, thinking)
	if req.Stream {
		h.handleOpenAIStream(w, account, kiroPayload, req.Model, thinking, estimatedInputTokens)
	} else {
		h.handleOpenAINonStream(w, account, kiroPayload, req.Model, thinking, estimatedInputTokens)
	}
}

// handleOpenAIStream streams a Kiro response to the client as OpenAI
// "chat.completion.chunk" server-sent events.
//
// Regular text becomes "content" deltas. Thinking text is emitted only when
// thinking is true. It can come either from reasoning events or from
// <thinking>…</thinking> blocks embedded in the text. Its shape depends on
// the configured OpenAIFormat:
//   - "thinking" wraps it in <thinking> tags inside "content".
//   - "think" wraps it in <think> tags inside "content".
//   - any other value sends it as "reasoning_content" deltas.
//
// Tool uses become "tool_calls" deltas. A final chunk carries finish_reason
// and usage, followed by "data: [DONE]".
//
// If CallKiroAPI fails, the failure is recorded and the function returns
// without writing an error payload or [DONE].
func (h *Handler) handleOpenAIStream(w http.ResponseWriter, account *config.Account, payload *KiroPayload, model string, thinking bool, estimatedInputTokens int) {
	const (
		// Tags delimiting thinking text embedded in the upstream content.
		thinkingOpenTag  = "<thinking>"
		thinkingCloseTag = "</thinking>"
		// Tags used by the "think" output format.
		thinkOpenTag  = "<think>"
		thinkCloseTag = "</think>"
		// Runes held back at the end of the buffer so that a tag split across
		// upstream fragments is not emitted early. Must be >= the longest tag
		// searched for (len(thinkingCloseTag) == 11).
		tagHoldback = 15
	)

	w.Header().Set("Content-Type", "text/event-stream; charset=utf-8")
	w.Header().Set("Cache-Control", "no-cache")
	w.Header().Set("Connection", "keep-alive")
	flusher, ok := w.(http.Flusher)
	if !ok {
		h.sendOpenAIError(w, 500, "server_error", "Streaming not supported")
		return
	}

	// Output format for thinking content ("thinking", "think" or "reasoning_content").
	thinkingFormat := config.GetThinkingConfig().OpenAIFormat
	chatID := "chatcmpl-" + uuid.New().String()

	var toolCalls []ToolCall
	var toolCallIndex int
	var inputTokens, outputTokens int
	var credits float64
	var realInputTokens int
	// Raw upstream text, kept only for output-token estimation at the end.
	var rawContentBuilder strings.Builder
	var rawReasoningBuilder strings.Builder

	// <thinking> tag parsing state.
	var textBuffer string
	var inThinkingBlock bool
	var dropTagThinking bool
	var thinkingSource thinkingStreamSource

	// newChunk builds a chat.completion.chunk envelope with a single choice.
	newChunk := func(delta interface{}, finishReason interface{}) map[string]interface{} {
		return map[string]interface{}{
			"id":      chatID,
			"object":  "chat.completion.chunk",
			"created": time.Now().Unix(),
			"model":   model,
			"choices": []map[string]interface{}{{
				"index":         0,
				"delta":         delta,
				"finish_reason": finishReason,
			}},
		}
	}
	// writeChunk sends one chunk as an SSE "data:" event and flushes it.
	writeChunk := func(chunk map[string]interface{}) {
		data, _ := json.Marshal(chunk)
		fmt.Fprintf(w, "data: %s\n\n", string(data))
		flusher.Flush()
	}

	// sendChunk emits one piece of text.
	// thinkingState: 0 = regular content, 1 = thinking start,
	// 2 = thinking middle, 3 = thinking end.
	sendChunk := func(content string, thinkingState int) {
		if content == "" && thinkingState == 2 {
			return
		}
		if thinkingState <= 0 {
			if content != "" {
				writeChunk(newChunk(map[string]string{"content": content}, nil))
			}
			return
		}
		if !thinking {
			return
		}
		var openTag, closeTag string
		switch thinkingFormat {
		case "thinking":
			openTag, closeTag = thinkingOpenTag, thinkingCloseTag
		case "think":
			openTag, closeTag = thinkOpenTag, thinkCloseTag
		default: // "reasoning_content"
			if content != "" {
				writeChunk(newChunk(map[string]string{"reasoning_content": content}, nil))
			}
			return
		}
		// Inline formats: thinking text travels in "content", wrapped in tags.
		var text string
		switch thinkingState {
		case 1:
			text = openTag + content
		case 2:
			text = content
		case 3:
			text = content + closeTag
		}
		if text != "" {
			writeChunk(newChunk(map[string]string{"content": text}, nil))
		}
	}

	// thinkingStarted: the opening of the current thinking section has been sent.
	// eventThinkingOpen: a section opened by reasoning events is still open.
	var thinkingStarted bool
	var eventThinkingOpen bool
	// processText routes one upstream fragment. Reasoning-event text is sent
	// directly. Regular text is scanned for <thinking>…</thinking> blocks,
	// holding back a short tail so tags split across fragments are still found.
	// forceFlush drains the buffer completely.
	processText := func(text string, isThinking bool, forceFlush bool) {
		if isThinking && !thinking {
			return
		}
		// Reasoning-event text: forward directly, unless this response's
		// thinking already came from tag blocks.
		if isThinking {
			if !allowReasoningSource(&thinkingSource) {
				return
			}
			if !thinkingStarted {
				sendChunk(text, 1) // start
				thinkingStarted = true
				eventThinkingOpen = true
			} else {
				sendChunk(text, 2) // middle
			}
			return
		}
		// Regular text closes any section opened by reasoning events.
		if eventThinkingOpen {
			sendChunk("", 3)
			eventThinkingOpen = false
			thinkingStarted = false
		}

		textBuffer += text
		for {
			if !inThinkingBlock {
				thinkingStart := strings.Index(textBuffer, thinkingOpenTag)
				if thinkingStart != -1 {
					// Emit the text before the tag, then enter the block.
					if thinkingStart > 0 {
						sendChunk(textBuffer[:thinkingStart], 0)
					}
					textBuffer = textBuffer[thinkingStart+len(thinkingOpenTag):]
					inThinkingBlock = true
					// Drop tag-block thinking if reasoning events already supplied it.
					dropTagThinking = !allowTagSource(&thinkingSource)
					thinkingStarted = false // the next thinking output re-opens
				} else if forceFlush || len([]rune(textBuffer)) > 50 {
					// No tag found: emit what is safe, keeping a tail that may
					// hold the start of a split tag (everything when flushing).
					runes := []rune(textBuffer)
					safeLen := len(runes)
					if !forceFlush {
						safeLen = max(0, len(runes)-tagHoldback)
					}
					if safeLen > 0 {
						sendChunk(string(runes[:safeLen]), 0)
						textBuffer = string(runes[safeLen:])
					}
					break
				} else {
					break
				}
			} else {
				thinkingEnd := strings.Index(textBuffer, thinkingCloseTag)
				if thinkingEnd != -1 {
					content := textBuffer[:thinkingEnd]
					if !dropTagThinking {
						if !thinkingStarted {
							// Whole block seen at once: open+content, then close.
							sendChunk(content, 1)
							sendChunk("", 3)
						} else {
							// Already opened: send the remainder and close.
							sendChunk(content, 3)
						}
					}
					textBuffer = textBuffer[thinkingEnd+len(thinkingCloseTag):]
					inThinkingBlock = false
					dropTagThinking = false
					thinkingStarted = false
				} else if forceFlush {
					// Unterminated block at flush time: emit what is left and close it.
					if textBuffer != "" {
						if !dropTagThinking {
							if !thinkingStarted {
								sendChunk(textBuffer, 1)
								sendChunk("", 3)
							} else {
								sendChunk(textBuffer, 3)
							}
						}
						textBuffer = ""
					}
					inThinkingBlock = false
					dropTagThinking = false
					thinkingStarted = false
					break
				} else {
					// Stream the block's content, keeping a tail that may hold
					// a split closing tag.
					runes := []rune(textBuffer)
					if len(runes) > 20 {
						safeLen := len(runes) - tagHoldback
						if safeLen > 0 {
							if !dropTagThinking {
								if !thinkingStarted {
									sendChunk(string(runes[:safeLen]), 1)
									thinkingStarted = true
								} else {
									sendChunk(string(runes[:safeLen]), 2)
								}
							}
							textBuffer = string(runes[safeLen:])
						}
					}
					break
				}
			}
		}
	}

	callback := &KiroStreamCallback{
		OnText: func(text string, isThinking bool) {
			if text == "" {
				return
			}
			if isThinking {
				rawReasoningBuilder.WriteString(text)
			} else {
				rawContentBuilder.WriteString(text)
			}
			processText(text, isThinking, false)
		},
		OnToolUse: func(tu KiroToolUse) {
			// Flush buffered text first so output order is preserved.
			processText("", false, true)
			args, _ := json.Marshal(tu.Input)
			// Tool name/arguments are NOT written into rawContentBuilder: they
			// are counted separately below, and adding them to the content
			// would count them twice.
			tc := ToolCall{ID: tu.ToolUseID, Type: "function"}
			tc.Function.Name = tu.Name
			tc.Function.Arguments = string(args)
			toolCalls = append(toolCalls, tc)
			writeChunk(newChunk(map[string]interface{}{
				"tool_calls": []map[string]interface{}{{
					"index": toolCallIndex,
					"id":    tu.ToolUseID,
					"type":  "function",
					"function": map[string]string{
						"name":      tu.Name,
						"arguments": string(args),
					},
				}},
			}, nil))
			toolCallIndex++
		},
		OnComplete: func(inTok, outTok int) {
			inputTokens = inTok
			outputTokens = outTok
		},
		OnError: func(err error) {
			h.pool.RecordError(account.ID, strings.Contains(err.Error(), "429"))
		},
		OnCredits: func(c float64) {
			credits = c
		},
		OnContextUsage: func(pct float64) {
			realInputTokens = int(pct * float64(getContextWindowSize(model)) / 100.0)
		},
	}

	err := CallKiroAPI(account, payload, callback)
	if err != nil {
		h.recordFailure()
		// NOTE(review): OnError may already have recorded this same failure;
		// confirm CallKiroAPI's contract before relying on the error counts.
		h.pool.RecordError(account.ID, strings.Contains(err.Error(), "429"))
		return
	}

	// Drain whatever is still buffered and close any open thinking section.
	processText("", false, true)
	if eventThinkingOpen {
		sendChunk("", 3)
		eventThinkingOpen = false
	}

	// Prefer the context-usage figure, then the upstream count, then the estimate.
	if realInputTokens > 0 {
		inputTokens = realInputTokens
	} else if inputTokens <= 0 {
		inputTokens = estimatedInputTokens
	}
	outputContent, extractedReasoning := extractThinkingFromContent(rawContentBuilder.String())
	reasoningOutput := rawReasoningBuilder.String()
	if thinking && reasoningOutput == "" && extractedReasoning != "" {
		reasoningOutput = extractedReasoning
	}
	if !thinking {
		reasoningOutput = ""
	}
	outputTokens = estimateApproxTokens(outputContent) + estimateApproxTokens(reasoningOutput)
	for _, tc := range toolCalls {
		outputTokens += estimateApproxTokens(tc.Function.Name)
		outputTokens += estimateApproxTokens(tc.Function.Arguments)
	}
	h.recordSuccess(inputTokens, outputTokens, credits)
	h.pool.RecordSuccess(account.ID)
	h.pool.UpdateStats(account.ID, inputTokens+outputTokens, credits)

	// Final chunk: finish reason plus usage, then the [DONE] sentinel.
	finishReason := "stop"
	if len(toolCalls) > 0 {
		finishReason = "tool_calls"
	}
	final := newChunk(map[string]interface{}{}, finishReason)
	final["usage"] = map[string]int{
		"prompt_tokens":     inputTokens,
		"completion_tokens": outputTokens,
		"total_tokens":      inputTokens + outputTokens,
	}
	data, _ := json.Marshal(final)
	fmt.Fprintf(w, "data: %s\n\n", string(data))
	fmt.Fprintf(w, "data: [DONE]\n\n")
	flusher.Flush()
}

// handleOpenAINonStream runs a Kiro call to completion and writes a single
// OpenAI chat.completion JSON response.
//
// <thinking> blocks are extracted from the content. Reasoning is kept only
// when thinking is true, and the tag-extracted reasoning is used only if no
// reasoning events were received. On upstream failure it records the error
// and replies 500 with the error message.
func (h *Handler) handleOpenAINonStream(w http.ResponseWriter, account *config.Account, payload *KiroPayload, model string, thinking bool, estimatedInputTokens int) {
	// Builders avoid quadratic string concatenation across many fragments.
	var contentBuf, reasoningBuf strings.Builder
	var toolUses []KiroToolUse
	var inputTokens, outputTokens int
	var credits float64
	var realInputTokens int
	callback := &KiroStreamCallback{
		OnText: func(text string, isThinking bool) {
			if isThinking {
				reasoningBuf.WriteString(text)
			} else {
				contentBuf.WriteString(text)
			}
		},
		OnToolUse: func(tu KiroToolUse) {
			toolUses = append(toolUses, tu)
		},
		OnComplete: func(inTok, outTok int) {
			inputTokens = inTok
			outputTokens = outTok
		},
		OnError: func(err error) {
			h.pool.RecordError(account.ID, strings.Contains(err.Error(), "429"))
		},
		OnCredits: func(c float64) {
			credits = c
		},
		OnContextUsage: func(pct float64) {
			realInputTokens = int(pct * float64(getContextWindowSize(model)) / 100.0)
		},
	}
	err := CallKiroAPI(account, payload, callback)
	if err != nil {
		h.recordFailure()
		h.pool.RecordError(account.ID, strings.Contains(err.Error(), "429"))
		h.sendOpenAIError(w, 500, "server_error", err.Error())
		return
	}
	// Pull <thinking> blocks out of the content.
	finalContent, extractedReasoning := extractThinkingFromContent(contentBuf.String())
	reasoningContent := reasoningBuf.String()
	if thinking && reasoningContent == "" && extractedReasoning != "" {
		reasoningContent = extractedReasoning
	} else if !thinking {
		reasoningContent = ""
	}
	if realInputTokens > 0 {
		inputTokens = realInputTokens
	} else if inputTokens <= 0 {
		inputTokens = estimatedInputTokens
	}
	outputTokens = estimateOpenAIOutputTokens(finalContent, reasoningContent, toolUses)
	h.recordSuccess(inputTokens, outputTokens, credits)
	h.pool.RecordSuccess(account.ID)
	h.pool.UpdateStats(account.ID, inputTokens+outputTokens, credits)
	thinkingFormat := config.GetThinkingConfig().OpenAIFormat
	resp := KiroToOpenAIResponseWithReasoning(finalContent, reasoningContent, toolUses, inputTokens, outputTokens, model, thinkingFormat)
	w.Header().Set("Content-Type", "application/json; charset=utf-8")
	json.NewEncoder(w).Encode(resp)
}

// sendOpenAIError writes an OpenAI-style error body
// ({"error":{"type":…,"message":…}}) with the given HTTP status.
func (h *Handler) sendOpenAIError(w http.ResponseWriter, status int, errType, message string) {
	w.Header().Set("Content-Type", "application/json; charset=utf-8")
	w.WriteHeader(status)
	json.NewEncoder(w).Encode(map[string]interface{}{
		"error": map[string]interface{}{
			"type":    errType,
			"message": message,
		},
	})
}

// ensureValidToken refreshes the account's access token when it expires
// within 300 seconds. Accounts with ExpiresAt == 0 are left untouched.
// New credentials are written to the pool, the given account, and the config.
func (h *Handler) ensureValidToken(account *config.Account) error {
	if account.ExpiresAt == 0 || time.Now().Unix() < account.ExpiresAt-300 {
		return nil
	}
	accessToken, refreshToken, expiresAt, err := auth.RefreshToken(account)
	if err != nil {
		return err
	}
	// Update the in-memory pool and the caller's copy.
	h.pool.UpdateToken(account.ID, accessToken, refreshToken, expiresAt)
	account.AccessToken = accessToken
	if refreshToken != "" {
account.RefreshToken = refreshToken
	}
	account.ExpiresAt = expiresAt
	// Persist the refreshed credentials.
	config.UpdateAccountToken(account.ID, accessToken, refreshToken, expiresAt)
	return nil
}

// ==================== Admin API ====================

// handleAdminAPI authenticates an admin request and dispatches it by path
// (relative to /admin/api) and HTTP method.
//
// The password comes from the X-Admin-Password header, falling back to the
// admin_password cookie, and must equal config.GetPassword(). Every response,
// including the 401 and 404 errors, is labeled as JSON.
func (h *Handler) handleAdminAPI(w http.ResponseWriter, r *http.Request) {
	// Set before any write so the 401 body below is also sent as JSON.
	w.Header().Set("Content-Type", "application/json; charset=utf-8")

	password := r.Header.Get("X-Admin-Password")
	if password == "" {
		if cookie, _ := r.Cookie("admin_password"); cookie != nil {
			password = cookie.Value
		}
	}
	// NOTE(review): plain != is not constant-time; consider crypto/subtle.
	if password != config.GetPassword() {
		w.WriteHeader(401)
		json.NewEncoder(w).Encode(map[string]string{"error": "Unauthorized"})
		return
	}

	path := strings.TrimPrefix(r.URL.Path, "/admin/api")
	switch {
	case path == "/accounts" && r.Method == "GET":
		h.apiGetAccounts(w, r)
	case path == "/accounts" && r.Method == "POST":
		h.apiAddAccount(w, r)
	case path == "/accounts/batch" && r.Method == "POST":
		h.apiBatchAccounts(w, r)
	case strings.HasPrefix(path, "/accounts/") && strings.HasSuffix(path, "/refresh") && r.Method == "POST":
		id := strings.TrimSuffix(strings.TrimPrefix(path, "/accounts/"), "/refresh")
		h.apiRefreshAccount(w, r, id)
	case strings.HasPrefix(path, "/accounts/") && strings.HasSuffix(path, "/models") && r.Method == "GET":
		id := strings.TrimSuffix(strings.TrimPrefix(path, "/accounts/"), "/models")
		h.apiGetAccountModels(w, r, id)
	case strings.HasPrefix(path, "/accounts/") && strings.HasSuffix(path, "/full") && r.Method == "GET":
		id := strings.TrimSuffix(strings.TrimPrefix(path, "/accounts/"), "/full")
		h.apiGetAccountFull(w, r, id)
	case strings.HasPrefix(path, "/accounts/") && r.Method == "DELETE":
		h.apiDeleteAccount(w, r, strings.TrimPrefix(path, "/accounts/"))
	case strings.HasPrefix(path, "/accounts/") && r.Method == "PUT":
		h.apiUpdateAccount(w, r, strings.TrimPrefix(path, "/accounts/"))
	case path == "/auth/iam-sso/start" && r.Method == "POST":
		h.apiStartIamSso(w, r)
	case path == "/auth/iam-sso/complete" && r.Method == "POST":
		h.apiCompleteIamSso(w, r)
	case path == "/auth/builderid/start" && r.Method == "POST":
		h.apiStartBuilderIdLogin(w, r)
	case path == "/auth/builderid/poll" && r.Method == "POST":
		h.apiPollBuilderIdAuth(w, r)
	case path == "/auth/sso-token" && r.Method == "POST":
		h.apiImportSsoToken(w, r)
	case path == "/auth/credentials" && r.Method == "POST":
		h.apiImportCredentials(w, r)
	case path == "/status" && r.Method == "GET":
		h.apiGetStatus(w, r)
	case path == "/settings" && r.Method == "GET":
		h.apiGetSettings(w, r)
	case path == "/settings" && r.Method == "POST":
		h.apiUpdateSettings(w, r)
	case path == "/stats" && r.Method == "GET":
		h.apiGetStats(w, r)
	case path == "/stats/reset" && r.Method == "POST":
		h.apiResetStats(w, r)
	case path == "/generate-machine-id" && r.Method == "GET":
		h.apiGenerateMachineId(w, r)
	case path == "/thinking" && r.Method == "GET":
		h.apiGetThinkingConfig(w, r)
	case path == "/thinking" && r.Method == "POST":
		h.apiUpdateThinkingConfig(w, r)
	case path == "/endpoint" && r.Method == "GET":
		h.apiGetEndpointConfig(w, r)
	case path == "/endpoint" && r.Method == "POST":
		h.apiUpdateEndpointConfig(w, r)
	case path == "/proxy" && r.Method == "GET":
		h.apiGetProxy(w, r)
	case path == "/proxy" && r.Method == "POST":
		h.apiUpdateProxy(w, r)
	case path == "/general" && r.Method == "GET":
		h.apiGetGeneralConfig(w, r)
	case path == "/general" && r.Method == "POST":
		h.apiUpdateGeneralConfig(w, r)
	case path == "/version" && r.Method == "GET":
		h.apiGetVersion(w, r)
	case path == "/export" && r.Method == "POST":
		h.apiExportAccounts(w, r)
	default:
		w.WriteHeader(404)
		json.NewEncoder(w).Encode(map[string]string{"error": "Not Found"})
	}
}

// apiGetAccounts lists all configured accounts merged with the pool's runtime
// counters. Access/refresh tokens and client secrets are not included; only a
// hasToken flag is reported.
func (h *Handler) apiGetAccounts(w http.ResponseWriter, r *http.Request) {
	accounts := config.GetAccounts()
	poolAccounts := h.pool.GetAllAccounts()
	// Index runtime counters by account ID.
	statsMap := make(map[string]config.Account)
	for _, a := range poolAccounts {
		statsMap[a.ID] = a
	}
	result := make([]map[string]interface{}, len(accounts))
	for i, a := range accounts {
		stats := statsMap[a.ID]
		result[i] = map[string]interface{}{
			"id":                a.ID,
			"email":             a.Email,
			"userId":            a.UserId,
			"nickname":          a.Nickname,
			"authMethod":        a.AuthMethod,
			"provider":          a.Provider,
			"region":            a.Region,
			"enabled":           a.Enabled,
			"banStatus":         a.BanStatus,
			"banReason":         a.BanReason,
			"banTime":           a.BanTime,
			"expiresAt":         a.ExpiresAt,
			"hasToken":          a.AccessToken != "",
			"machineId":         a.MachineId,
			"weight":            a.Weight,
			"subscriptionType":  a.SubscriptionType,
			"subscriptionTitle": a.SubscriptionTitle,
			"daysRemaining":     a.DaysRemaining,
			"usageCurrent":      a.UsageCurrent,
			"usageLimit":        a.UsageLimit,
			"usagePercent":      a.UsagePercent,
			"nextResetDate":     a.NextResetDate,
			"lastRefresh":       a.LastRefresh,
			"trialUsageCurrent": a.TrialUsageCurrent,
			"trialUsageLimit":   a.TrialUsageLimit,
			"trialUsagePercent": a.TrialUsagePercent,
			"trialStatus":       a.TrialStatus,
			"trialExpiresAt":    a.TrialExpiresAt,
			"requestCount":      stats.RequestCount,
			"errorCount":        stats.ErrorCount,
			"totalTokens":       stats.TotalTokens,
			"totalCredits":      stats.TotalCredits,
			"lastUsed":          stats.LastUsed,
		}
	}
	json.NewEncoder(w).Encode(result)
}

// apiAddAccount adds an account from a JSON config.Account body. It generates
// an ID when none is given, defaults Region to us-east-1, and reloads the pool.
func (h *Handler) apiAddAccount(w http.ResponseWriter, r *http.Request) {
	var account config.Account
	if err := json.NewDecoder(r.Body).Decode(&account); err != nil {
		w.WriteHeader(400)
		json.NewEncoder(w).Encode(map[string]string{"error": "Invalid JSON"})
		return
	}
	if account.ID == "" {
		account.ID = auth.GenerateAccountID()
	}
	if account.Region == "" {
		account.Region = "us-east-1"
	}
	if err := config.AddAccount(account); err != nil {
		w.WriteHeader(500)
		json.NewEncoder(w).Encode(map[string]string{"error": err.Error()})
		return
	}
	h.pool.Reload()
	json.NewEncoder(w).Encode(map[string]interface{}{"success": true, "id": account.ID})
}

// apiDeleteAccount removes the account with the given ID and reloads the pool.
func (h *Handler) apiDeleteAccount(w http.ResponseWriter, r *http.Request, id string) {
	if err := config.DeleteAccount(id); err != nil {
		w.WriteHeader(500)
		json.NewEncoder(w).Encode(map[string]string{"error": err.Error()})
		return
	}
	h.pool.Reload()
	json.NewEncoder(w).Encode(map[string]bool{"success": true})
}

// apiUpdateAccount applies a partial update to one account. Only the
// "enabled", "nickname", "machineId" and "weight" fields are honored, and only
// when present with the expected JSON type.
func (h *Handler) apiUpdateAccount(w http.ResponseWriter, r *http.Request, id string) {
	var updates map[string]interface{}
	if err := json.NewDecoder(r.Body).Decode(&updates); err != nil {
		w.WriteHeader(400)
		json.NewEncoder(w).Encode(map[string]string{"error": "Invalid JSON"})
		return
	}
	// Find the existing account.
	accounts := config.GetAccounts()
	var existing *config.Account
	for i := range accounts {
		if accounts[i].ID == id {
			existing = &accounts[i]
			break
		}
	}
	if existing == nil {
		w.WriteHeader(404)
		json.NewEncoder(w).Encode(map[string]string{"error": "Account not found"})
		return
	}
	// Only update the fields that were sent.
	if v, ok := updates["enabled"].(bool); ok {
		existing.Enabled = v
	}
	if v, ok := updates["nickname"].(string); ok {
		existing.Nickname = v
	}
	if v, ok := updates["machineId"].(string); ok {
		existing.MachineId = v
	}
	// JSON numbers decode into interface{} as float64.
	if v, ok := updates["weight"].(float64); ok {
		existing.Weight = int(v)
	}
	if err := config.UpdateAccount(id, *existing); err != nil {
		w.WriteHeader(500)
		json.NewEncoder(w).Encode(map[string]string{"error": err.Error()})
		return
	}
	h.pool.Reload()
	json.NewEncoder(w).Encode(map[string]bool{"success": true})
}

// apiBatchAccounts applies one action to several accounts.
//
// Body: {"ids": [...], "action": "enable" | "disable" | "refresh"}.
//   - "enable" also resets any non-ACTIVE ban status to ACTIVE.
//   - "enable"/"disable" report how many accounts were actually updated.
//   - "refresh" refreshes each account's token when it has a refresh token.
//     A failed token refresh is ignored. It then refreshes the account info
//     and reports refreshed/failed counts.
func (h *Handler) apiBatchAccounts(w http.ResponseWriter, r *http.Request) {
	var req struct {
		IDs    []string `json:"ids"`
		Action string   `json:"action"` // "enable", "disable", "refresh"
	}
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		w.WriteHeader(400)
		json.NewEncoder(w).Encode(map[string]string{"error": "Invalid JSON"})
		return
	}
	if len(req.IDs) == 0 {
		w.WriteHeader(400)
		json.NewEncoder(w).Encode(map[string]string{"error": "No account IDs provided"})
		return
	}
	switch req.Action {
	case "enable", "disable":
		enabled := req.Action == "enable"
		idSet := make(map[string]bool, len(req.IDs))
		for _, id := range req.IDs {
			idSet[id] = true
		}
		updated := 0
		for _, a := range config.GetAccounts() {
			if !idSet[a.ID] {
				continue
			}
			a.Enabled = enabled
			if enabled && a.BanStatus != "" && a.BanStatus != "ACTIVE" {
				a.BanStatus = "ACTIVE"
				a.BanReason = ""
				a.BanTime = 0
			}
			if err := config.UpdateAccount(a.ID, a); err != nil {
				continue
			}
			updated++
		}
		h.pool.Reload()
		// Report accounts actually updated, not the number of IDs requested.
		json.NewEncoder(w).Encode(map[string]interface{}{"success": true, "count": updated})
	case "refresh":
		// Fetch the account list once and index it (first match wins).
		accounts := config.GetAccounts()
		indexByID := make(map[string]int, len(accounts))
		for i := range accounts {
			if _, seen := indexByID[accounts[i].ID]; !seen {
				indexByID[accounts[i].ID] = i
			}
		}
		successCount := 0
		failCount := 0
		for _, id := range req.IDs {
			idx, found := indexByID[id]
			if !found {
				failCount++
				continue
			}
			account := &accounts[idx]
			// Refresh the token first; a failure here is ignored and the info
			// refresh below decides the outcome.
			if account.RefreshToken != "" {
				if newAccess, newRefresh, newExpires, err := auth.RefreshToken(account); err == nil {
					account.AccessToken = newAccess
					if newRefresh != "" {
						account.RefreshToken = newRefresh
					}
					account.ExpiresAt = newExpires
					config.UpdateAccountToken(id, newAccess, newRefresh, newExpires)
					h.pool.UpdateToken(id, newAccess, newRefresh, newExpires)
				}
			}
			// Refresh account info (usage, subscription, …).
			info, err := RefreshAccountInfo(account)
			if err != nil {
				failCount++
				continue
			}
			if err := config.UpdateAccountInfo(id, *info); err != nil {
				failCount++
				continue
			}
			successCount++
		}
		h.pool.Reload()
		json.NewEncoder(w).Encode(map[string]interface{}{
			"success":   true,
			"refreshed": successCount,
			"failed":    failCount,
		})
	default:
		w.WriteHeader(400)
		json.NewEncoder(w).Encode(map[string]string{"error": "Invalid action: " + req.Action})
	}
}

// apiStartIamSso starts an IAM SSO login for the given startUrl/region and
// returns the session ID, the authorize URL and its expiry.
func (h *Handler) apiStartIamSso(w http.ResponseWriter, r *http.Request) {
	var req struct {
		StartUrl string `json:"startUrl"`
		Region   string `json:"region"`
	}
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		w.WriteHeader(400)
		json.NewEncoder(w).Encode(map[string]string{"error": "Invalid JSON"})
		return
	}
	if req.StartUrl == "" {
		w.WriteHeader(400)
		json.NewEncoder(w).Encode(map[string]string{"error": "startUrl is required"})
		return
	}
	sessionID, authorizeUrl, expiresIn, err := auth.StartIamSsoLogin(req.StartUrl, req.Region)
	if err != nil {
		w.WriteHeader(500)
		json.NewEncoder(w).Encode(map[string]string{"error": err.Error()})
		return
	}
	json.NewEncoder(w).Encode(map[string]interface{}{
		"sessionId":    sessionID,
		"authorizeUrl": authorizeUrl,
		"expiresIn":    expiresIn,
	})
}

// apiCompleteIamSso completes an IAM SSO login from the callback URL and
// stores the resulting account (AuthMethod "idc").
func (h *Handler) apiCompleteIamSso(w http.ResponseWriter, r *http.Request) {
	var req struct {
		SessionID   string `json:"sessionId"`
		CallbackUrl string `json:"callbackUrl"`
	}
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		w.WriteHeader(400)
		json.NewEncoder(w).Encode(map[string]string{"error": "Invalid JSON"})
		return
	}
	accessToken, refreshToken, clientID, clientSecret, region, expiresIn, err := auth.CompleteIamSsoLogin(req.SessionID, req.CallbackUrl)
	if err != nil {
		w.WriteHeader(400)
		json.NewEncoder(w).Encode(map[string]string{"error": err.Error()})
		return
	}
	// Look up the user's email; a lookup failure is ignored.
	email, _, _ := auth.GetUserInfo(accessToken)
	account := config.Account{
		ID:           auth.GenerateAccountID(),
		Email:        email,
		AccessToken:  accessToken,
		RefreshToken: refreshToken,
		ClientID:     clientID,
		ClientSecret: clientSecret,
		AuthMethod:   "idc",
		Region:       region,
		ExpiresAt:    time.Now().Unix() + int64(expiresIn),
		Enabled:      true,
		MachineId:    config.GenerateMachineId(),
	}
	if err := config.AddAccount(account); err != nil {
		w.WriteHeader(500)
		json.NewEncoder(w).Encode(map[string]string{"error": err.Error()})
		return
	}
	h.pool.Reload()
	json.NewEncoder(w).Encode(map[string]interface{}{
		"success": true,
		"account": map[string]interface{}{
			"id":    account.ID,
			"email": account.Email,
		},
	})
}

// apiStartBuilderIdLogin starts a Builder ID login. The body is optional:
// decode errors are ignored and an empty region is passed through.
func (h *Handler) apiStartBuilderIdLogin(w http.ResponseWriter, r *http.Request) {
	var req struct {
		Region string `json:"region"`
	}
	json.NewDecoder(r.Body).Decode(&req)
	session, err := auth.StartBuilderIdLogin(req.Region)
	if err != nil {
		w.WriteHeader(500)
		json.NewEncoder(w).Encode(map[string]string{"error": err.Error()})
		return
	}
	json.NewEncoder(w).Encode(map[string]interface{}{
		"sessionId":       session.ID,
		"userCode":        session.UserCode,
		"verificationUri": session.VerificationUri,
		"interval":        session.Interval,
	})
}

// apiPollBuilderIdAuth polls a pending Builder ID authorization.
//
// While the status is "pending" or "slow_down", it reports the session's
// current poll interval (default 5). Once complete, it stores the new account
// with Provider "BuilderId".
func (h *Handler) apiPollBuilderIdAuth(w http.ResponseWriter, r *http.Request) {
	var req struct {
		SessionID string `json:"sessionId"`
	}
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		w.WriteHeader(400)
json.NewEncoder(w).Encode(map[string]string{"error": "Invalid JSON"})
		return
	}
	accessToken, refreshToken, clientID, clientSecret, region, expiresIn, status, err := auth.PollBuilderIdAuth(req.SessionID)
	if err != nil {
		w.WriteHeader(400)
		json.NewEncoder(w).Encode(map[string]interface{}{
			"success": false,
			"error":   err.Error(),
		})
		return
	}
	if status == "pending" || status == "slow_down" {
		// Report the session's current poll interval.
		interval := 5
		if session := auth.GetBuilderIdSession(req.SessionID); session != nil {
			interval = session.Interval
		}
		json.NewEncoder(w).Encode(map[string]interface{}{
			"success":   true,
			"completed": false,
			"status":    status,
			"interval":  interval,
		})
		return
	}
	// Authorization complete: look up the user (failure ignored) and create the account.
	email, _, _ := auth.GetUserInfo(accessToken)
	account := config.Account{
		ID:           auth.GenerateAccountID(),
		Email:        email,
		AccessToken:  accessToken,
		RefreshToken: refreshToken,
		ClientID:     clientID,
		ClientSecret: clientSecret,
		AuthMethod:   "idc",
		Provider:     "BuilderId",
		Region:       region,
		ExpiresAt:    time.Now().Unix() + int64(expiresIn),
		Enabled:      true,
		MachineId:    config.GenerateMachineId(),
	}
	if err := config.AddAccount(account); err != nil {
		w.WriteHeader(500)
		json.NewEncoder(w).Encode(map[string]string{"error": err.Error()})
		return
	}
	h.pool.Reload()
	json.NewEncoder(w).Encode(map[string]interface{}{
		"success":   true,
		"completed": true,
		"account": map[string]interface{}{
			"id":    account.ID,
			"email": account.Email,
		},
	})
}

// apiImportSsoToken imports one or more accounts from bearer tokens, one per
// line. It replies 500 only when nothing was imported and at least one token
// failed. Otherwise it returns the imported accounts plus any per-token errors.
func (h *Handler) apiImportSsoToken(w http.ResponseWriter, r *http.Request) {
	var req struct {
		BearerToken string `json:"bearerToken"`
		Region      string `json:"region"`
	}
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		w.WriteHeader(400)
		json.NewEncoder(w).Encode(map[string]string{"error": "Invalid JSON"})
		return
	}
	if req.BearerToken == "" {
		w.WriteHeader(400)
		json.NewEncoder(w).Encode(map[string]string{"error": "bearerToken is required"})
		return
	}
	// Batch import: one token per line (TrimSpace below also strips "\r").
	tokens := strings.Split(strings.TrimSpace(req.BearerToken), "\n")
	var imported []map[string]interface{}
	var importErrs []string
	for _, token := range tokens {
		token = strings.TrimSpace(token)
		if token == "" {
			continue
		}
		accessToken, refreshToken, clientID, clientSecret, expiresIn, err := auth.ImportFromSsoToken(token, req.Region)
		if err != nil {
			importErrs = append(importErrs, err.Error())
			continue
		}
		email, _, _ := auth.GetUserInfo(accessToken)
		account := config.Account{
			ID:           auth.GenerateAccountID(),
			Email:        email,
			AccessToken:  accessToken,
			RefreshToken: refreshToken,
			ClientID:     clientID,
			ClientSecret: clientSecret,
			AuthMethod:   "idc",
			Region:       req.Region,
			ExpiresAt:    time.Now().Unix() + int64(expiresIn),
			Enabled:      true,
			MachineId:    config.GenerateMachineId(),
		}
		if err := config.AddAccount(account); err != nil {
			importErrs = append(importErrs, err.Error())
			continue
		}
		imported = append(imported, map[string]interface{}{
			"id":    account.ID,
			"email": account.Email,
		})
	}
	h.pool.Reload()
	if len(imported) == 0 && len(importErrs) > 0 {
		w.WriteHeader(500)
		json.NewEncoder(w).Encode(map[string]interface{}{
			"success": false,
			"error":   strings.Join(importErrs, "; "),
		})
		return
	}
	json.NewEncoder(w).Encode(map[string]interface{}{
		"success":  true,
		"accounts": imported,
		"errors":   importErrs,
	})
}

// apiImportCredentials imports an account from a refresh token.
//
// Region defaults to us-east-1. authMethod is normalized to "idc" or
// "social": aliases are mapped, and otherwise it is inferred from whether
// client credentials were given. The refresh token is always exchanged first.
// If that fails and an accessToken was supplied, that token is stored with a
// 300-second expiry; otherwise the import is rejected with 400.
func (h *Handler) apiImportCredentials(w http.ResponseWriter, r *http.Request) {
	var req struct {
		AccessToken  string `json:"accessToken"`
		RefreshToken string `json:"refreshToken"`
		ClientID     string `json:"clientId"`
		ClientSecret string `json:"clientSecret"`
		AuthMethod   string `json:"authMethod"`
		Provider     string `json:"provider"`
		Region       string `json:"region"`
	}
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		w.WriteHeader(400)
		json.NewEncoder(w).Encode(map[string]string{"error": "Invalid JSON"})
		return
	}
	if req.RefreshToken == "" {
		w.WriteHeader(400)
		json.NewEncoder(w).Encode(map[string]string{"error": "refreshToken is required"})
		return
	}
	// Defaults.
	if req.Region == "" {
		req.Region = "us-east-1"
	}
	if req.AuthMethod == "" {
		if req.ClientID != "" {
			req.AuthMethod = "idc"
		} else {
			req.AuthMethod = "social"
		}
	}
	// Normalize authMethod.
	switch strings.ToLower(req.AuthMethod) {
	case "idc", "builderid", "enterprise":
		req.AuthMethod = "idc"
	case "social", "google", "github":
		req.AuthMethod = "social"
	default:
		if req.ClientID != "" && req.ClientSecret != "" {
			req.AuthMethod = "idc"
		} else {
			req.AuthMethod = "social"
		}
	}
	// Always try to obtain a fresh access token from the refresh token.
	var accessToken string
	var expiresAt int64
	tempAccount := &config.Account{
		RefreshToken: req.RefreshToken,
		ClientID:     req.ClientID,
		ClientSecret: req.ClientSecret,
		AuthMethod:   req.AuthMethod,
		Region:       req.Region,
	}
	newAccessToken, newRefreshToken, newExpiresAt, err := auth.RefreshToken(tempAccount)
	if err != nil {
		// Refresh failed: fall back to the supplied access token, if any.
		if req.AccessToken != "" {
			accessToken = req.AccessToken
			expiresAt = time.Now().Unix() + 300 // may already be expired; keep it short
		} else {
			w.WriteHeader(400)
			json.NewEncoder(w).Encode(map[string]string{"error": "Token refresh failed: " + err.Error()})
			return
		}
	} else {
		accessToken = newAccessToken
		if newRefreshToken != "" {
			req.RefreshToken = newRefreshToken
		}
		expiresAt = newExpiresAt
	}
	email, _, _ := auth.GetUserInfo(accessToken)
	account := config.Account{
		ID:           auth.GenerateAccountID(),
		Email:        email,
		AccessToken:  accessToken,
		RefreshToken: req.RefreshToken,
		ClientID:     req.ClientID,
		ClientSecret: req.ClientSecret,
		AuthMethod:   req.AuthMethod,
		Provider:     req.Provider,
		Region:       req.Region,
		ExpiresAt:    expiresAt,
		Enabled:      true,
		MachineId:    config.GenerateMachineId(),
	}
	if err := config.AddAccount(account); err != nil {
		w.WriteHeader(500)
		json.NewEncoder(w).Encode(map[string]string{"error": err.Error()})
		return
	}
	h.pool.Reload()
	json.NewEncoder(w).Encode(map[string]interface{}{
		"success": true,
		"account": map[string]interface{}{
			"id":    account.ID,
			"email": account.Email,
		},
	})
}

// apiGetStatus reports pool size, available accounts and runtime counters.
func (h *Handler) apiGetStatus(w http.ResponseWriter, r *http.Request) {
	// The counters are written concurrently by request handlers. Read them the
	// same way apiGetStats does: atomics for the int64 counters, and creditsMu
	// (via getCredits) for the float64 credits. Plain field reads would race.
	json.NewEncoder(w).Encode(map[string]interface{}{
		"accounts":        h.pool.Count(),
		"available":       h.pool.AvailableCount(),
		"totalRequests":   atomic.LoadInt64(&h.totalRequests),
		"successRequests": atomic.LoadInt64(&h.successRequests),
		"failedRequests":  atomic.LoadInt64(&h.failedRequests),
		"totalTokens":     atomic.LoadInt64(&h.totalTokens),
		"totalCredits":    h.getCredits(),
		"uptime":          time.Now().Unix() - h.startTime,
	})
}

// apiGetSettings returns the API key settings and the listen host/port.
func (h *Handler) apiGetSettings(w http.ResponseWriter, r *http.Request) {
	json.NewEncoder(w).Encode(map[string]interface{}{
		"apiKey":        config.GetApiKey(),
		"requireApiKey": config.IsApiKeyRequired(),
		"port":          config.GetPort(),
		"host":          config.GetHost(),
	})
}

// apiUpdateSettings updates the API key, the API-key requirement and the
// admin password via config.UpdateSettings.
func (h *Handler) apiUpdateSettings(w http.ResponseWriter, r *http.Request) {
	var req struct {
		ApiKey        string `json:"apiKey"`
		RequireApiKey bool   `json:"requireApiKey"`
		Password      string `json:"password"`
	}
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		w.WriteHeader(400)
		json.NewEncoder(w).Encode(map[string]string{"error": "Invalid JSON"})
		return
	}
	if err := config.UpdateSettings(req.ApiKey, req.RequireApiKey, req.Password); err != nil {
		w.WriteHeader(500)
		json.NewEncoder(w).Encode(map[string]string{"error": err.Error()})
		return
	}
	json.NewEncoder(w).Encode(map[string]bool{"success": true})
}

// apiGetStats returns the runtime request/token/credit counters and uptime.
func (h *Handler) apiGetStats(w http.ResponseWriter, r *http.Request) {
	json.NewEncoder(w).Encode(map[string]interface{}{
		"totalRequests":   atomic.LoadInt64(&h.totalRequests),
		"successRequests": atomic.LoadInt64(&h.successRequests),
		"failedRequests":  atomic.LoadInt64(&h.failedRequests),
		"totalTokens":     atomic.LoadInt64(&h.totalTokens),
		"totalCredits":    h.getCredits(),
		"uptime":          time.Now().Unix() - h.startTime,
	})
}

// apiResetStats zeroes the runtime counters and also resets the stored
// stats via config.UpdateStats.
func (h *Handler) apiResetStats(w http.ResponseWriter, r *http.Request) {
	atomic.StoreInt64(&h.totalRequests, 0)
	atomic.StoreInt64(&h.successRequests, 0)
	atomic.StoreInt64(&h.failedRequests, 0)
	atomic.StoreInt64(&h.totalTokens, 0)
	h.creditsMu.Lock()
	h.totalCredits = 0
	h.creditsMu.Unlock()
	config.UpdateStats(0, 0, 0, 0, 0)
	json.NewEncoder(w).Encode(map[string]bool{"success": true})
}

// apiGenerateMachineId returns a newly generated machine ID.
func (h *Handler) apiGenerateMachineId(w http.ResponseWriter, r *http.Request) {
	machineId := config.GenerateMachineId()
json.NewEncoder(w).Encode(map[string]string{"machineId": machineId})
}

// apiRefreshAccount refreshes one account's info (usage, subscription, etc.)
// and saves it with config.UpdateAccountInfo.
//
// The token is refreshed up front only when it expires within 300 seconds.
// If the info fetch fails with an error mentioning "403", "401", "invalid"
// or "expired", the token is refreshed and the fetch is retried once.
// Errors mentioning "TEMPORARILY_SUSPENDED" or "Account suspended" are
// answered with success and "Account status updated".
func (h *Handler) apiRefreshAccount(w http.ResponseWriter, r *http.Request, id string) {
	accounts := config.GetAccounts()
	var account *config.Account
	for i := range accounts {
		if accounts[i].ID == id {
			account = &accounts[i]
			break
		}
	}
	if account == nil {
		w.WriteHeader(404)
		json.NewEncoder(w).Encode(map[string]string{"error": "Account not found"})
		return
	}
	// refreshTokenIfNeeded refreshes the token unconditionally whenever the
	// account has a refresh token (it is a no-op otherwise). It then updates
	// the local copy, the config and the pool.
	refreshTokenIfNeeded := func() error {
		if account.RefreshToken == "" {
			return nil
		}
		newAccessToken, newRefreshToken, newExpiresAt, err := auth.RefreshToken(account)
		if err != nil {
			return err
		}
		account.AccessToken = newAccessToken
		if newRefreshToken != "" {
			account.RefreshToken = newRefreshToken
		}
		account.ExpiresAt = newExpiresAt
		config.UpdateAccountToken(id, newAccessToken, newRefreshToken, newExpiresAt)
		h.pool.UpdateToken(id, newAccessToken, newRefreshToken, newExpiresAt)
		return nil
	}
	// Refresh first if the token expires within 300 seconds.
	if account.ExpiresAt > 0 && time.Now().Unix() > account.ExpiresAt-300 {
		if err := refreshTokenIfNeeded(); err != nil {
			w.WriteHeader(500)
			json.NewEncoder(w).Encode(map[string]string{"error": "Token refresh failed: " + err.Error()})
			return
		}
	}
	// Fetch account info.
	info, err := RefreshAccountInfo(account)
	if err != nil {
		errMsg := err.Error()
		// Suspension errors: per the original note, RefreshAccountInfo already
		// records the ban status (not visible here), so report success.
		if strings.Contains(errMsg, "TEMPORARILY_SUSPENDED") || strings.Contains(errMsg, "Account suspended") {
			json.NewEncoder(w).Encode(map[string]interface{}{
				"success": true,
				"message": "Account status updated",
			})
			return
		}
		// Errors that look like an invalid/expired token: refresh and retry once.
		if strings.Contains(errMsg, "403") || strings.Contains(errMsg, "401") || strings.Contains(errMsg, "invalid") || strings.Contains(errMsg, "expired") {
			if refreshErr := refreshTokenIfNeeded(); refreshErr == nil {
				// Retry.
				info, err = RefreshAccountInfo(account)
				if err != nil {
					// Still failing: check for suspension again.
					if strings.Contains(err.Error(), "TEMPORARILY_SUSPENDED") || strings.Contains(err.Error(), "Account suspended") {
						json.NewEncoder(w).Encode(map[string]interface{}{
							"success": true,
							"message": "Account status updated",
						})
						return
					}
				}
			}
		}
		// Any error still remaining is reported to the caller.
		if err != nil {
			w.WriteHeader(500)
			json.NewEncoder(w).Encode(map[string]string{"error": err.Error()})
			return
		}
	}
	// Save to config.
	if err := config.UpdateAccountInfo(id, *info); err != nil {
		w.WriteHeader(500)
		json.NewEncoder(w).Encode(map[string]string{"error": err.Error()})
		return
	}
	json.NewEncoder(w).Encode(map[string]interface{}{
		"success": true,
		"info":    info,
	})
}

// apiGetAccountFull returns a single account with all fields, including
// sensitive ones (access/refresh tokens, client ID and secret), plus the
// pool's runtime counters for it.
func (h *Handler) apiGetAccountFull(w http.ResponseWriter, r *http.Request, id string) {
	accounts := config.GetAccounts()
	poolAccounts := h.pool.GetAllAccounts()
	// Find the requested account.
	var account *config.Account
	for i := range accounts {
		if accounts[i].ID == id {
			account = &accounts[i]
			break
		}
	}
	if account == nil {
		w.WriteHeader(404)
		json.NewEncoder(w).Encode(map[string]string{"error": "Account not found"})
		return
	}
	// Runtime counters from the pool (zero values if the pool lacks the account).
	var stats config.Account
	for _, a := range poolAccounts {
		if a.ID == id {
			stats = a
			break
		}
	}
	// Full account info, including sensitive fields.
	result := map[string]interface{}{
		"id":                account.ID,
		"email":             account.Email,
		"userId":            account.UserId,
		"nickname":          account.Nickname,
		"accessToken":       account.AccessToken,
		"refreshToken":      account.RefreshToken,
		"clientId":          account.ClientID,
		"clientSecret":      account.ClientSecret,
		"authMethod":        account.AuthMethod,
		"provider":          account.Provider,
		"region":            account.Region,
		"expiresAt":         account.ExpiresAt,
		"machineId":         account.MachineId,
		"enabled":           account.Enabled,
		"banStatus":         account.BanStatus,
		"banReason":         account.BanReason,
		"banTime":           account.BanTime,
		"subscriptionType":  account.SubscriptionType,
		"subscriptionTitle": account.SubscriptionTitle,
		"daysRemaining":     account.DaysRemaining,
		"usageCurrent":      account.UsageCurrent,
		"usageLimit":        account.UsageLimit,
		"usagePercent":      account.UsagePercent,
		"nextResetDate":     account.NextResetDate,
		"lastRefresh":       account.LastRefresh,
		"trialUsageCurrent": account.TrialUsageCurrent,
		"trialUsageLimit":   account.TrialUsageLimit,
		"trialUsagePercent": account.TrialUsagePercent,
		"trialStatus":       account.TrialStatus,
		"trialExpiresAt":    account.TrialExpiresAt,
		"requestCount":      stats.RequestCount,
		"errorCount":        stats.ErrorCount,
		"totalTokens":       stats.TotalTokens,
		"totalCredits":      stats.TotalCredits,
		"lastUsed":          stats.LastUsed,
	}
	json.NewEncoder(w).Encode(result)
}

// apiGetAccountModels lists the models available to one account via
// ListAvailableModels.
func (h *Handler) apiGetAccountModels(w http.ResponseWriter, r *http.Request, id string) {
	accounts := config.GetAccounts()
	var account *config.Account
	for i := range accounts {
		if accounts[i].ID == id {
			account = &accounts[i]
			break
		}
	}
	if account == nil {
		w.WriteHeader(404)
		json.NewEncoder(w).Encode(map[string]string{"error": "Account not found"})
		return
	}
	models, err := ListAvailableModels(account)
	if err != nil {
		w.WriteHeader(500)
		json.NewEncoder(w).Encode(map[string]string{"error": err.Error()})
		return
	}
	json.NewEncoder(w).Encode(map[string]interface{}{
		"success": true,
		"models":  models,
	})
}

// ==================== Static files ====================

// serveAdminPage serves the admin UI entry page, web/index.html.
func (h *Handler) serveAdminPage(w http.ResponseWriter, r *http.Request) {
	http.ServeFile(w, r, "web/index.html")
}

// serveStaticFile maps /admin/<path> to web/<path>. http.ServeFile rejects
// requests whose URL path contains a ".." element.
func (h *Handler) serveStaticFile(w http.ResponseWriter, r *http.Request) {
	path := strings.TrimPrefix(r.URL.Path, "/admin/")
	http.ServeFile(w, r, "web/"+path)
}

// apiGetThinkingConfig returns the thinking configuration: the model suffix
// and the OpenAI/Claude output formats.
func (h *Handler) apiGetThinkingConfig(w http.ResponseWriter, r *http.Request) {
	cfg := config.GetThinkingConfig()
	json.NewEncoder(w).Encode(map[string]interface{}{
		"suffix":       cfg.Suffix,
		"openaiFormat": cfg.OpenAIFormat,
		"claudeFormat": cfg.ClaudeFormat,
	})
}

// apiUpdateThinkingConfig updates the thinking configuration.
func (h *Handler) apiUpdateThinkingConfig(w http.ResponseWriter, r *http.Request) {
	var req struct {
		Suffix       string `json:"suffix"`
		OpenAIFormat string `json:"openaiFormat"`
		ClaudeFormat string `json:"claudeFormat"`
	}
	if err :=
json.NewDecoder(r.Body).Decode(&req); err != nil { w.WriteHeader(400) json.NewEncoder(w).Encode(map[string]string{"error": "Invalid JSON"}) return } // 验证格式 validFormats := map[string]bool{"reasoning_content": true, "thinking": true, "think": true} if req.OpenAIFormat != "" && !validFormats[req.OpenAIFormat] { w.WriteHeader(400) json.NewEncoder(w).Encode(map[string]string{"error": "Invalid openaiFormat, must be: reasoning_content, thinking, or think"}) return } if req.ClaudeFormat != "" && !validFormats[req.ClaudeFormat] { w.WriteHeader(400) json.NewEncoder(w).Encode(map[string]string{"error": "Invalid claudeFormat, must be: reasoning_content, thinking, or think"}) return } if err := config.UpdateThinkingConfig(req.Suffix, req.OpenAIFormat, req.ClaudeFormat); err != nil { w.WriteHeader(500) json.NewEncoder(w).Encode(map[string]string{"error": err.Error()}) return } json.NewEncoder(w).Encode(map[string]bool{"success": true}) } // apiGetEndpointConfig 获取端点配置 func (h *Handler) apiGetEndpointConfig(w http.ResponseWriter, r *http.Request) { json.NewEncoder(w).Encode(map[string]string{ "preferredEndpoint": config.GetPreferredEndpoint(), }) } // apiUpdateEndpointConfig 更新端点配置 func (h *Handler) apiUpdateEndpointConfig(w http.ResponseWriter, r *http.Request) { var req struct { PreferredEndpoint string `json:"preferredEndpoint"` } if err := json.NewDecoder(r.Body).Decode(&req); err != nil { w.WriteHeader(400) json.NewEncoder(w).Encode(map[string]string{"error": "Invalid JSON"}) return } valid := map[string]bool{"auto": true, "codewhisperer": true, "amazonq": true} if !valid[req.PreferredEndpoint] { w.WriteHeader(400) json.NewEncoder(w).Encode(map[string]string{"error": "Invalid endpoint, must be: auto, codewhisperer, or amazonq"}) return } if err := config.UpdatePreferredEndpoint(req.PreferredEndpoint); err != nil { w.WriteHeader(500) json.NewEncoder(w).Encode(map[string]string{"error": err.Error()}) return } json.NewEncoder(w).Encode(map[string]bool{"success": true}) } // 
applyProxyConfig 将代理配置应用到所有出站 HTTP 客户端(Kiro API + auth 模块) func applyProxyConfig(proxyURL string) { InitKiroHttpClient(proxyURL) auth.InitHttpClient(proxyURL) } // apiGetProxy 获取当前代理配置 func (h *Handler) apiGetProxy(w http.ResponseWriter, r *http.Request) { json.NewEncoder(w).Encode(map[string]string{ "proxyURL": config.GetProxyURL(), }) } // apiUpdateProxy 更新代理配置并立即生效 func (h *Handler) apiUpdateProxy(w http.ResponseWriter, r *http.Request) { var req struct { ProxyURL string `json:"proxyURL"` } if err := json.NewDecoder(r.Body).Decode(&req); err != nil { w.WriteHeader(400) json.NewEncoder(w).Encode(map[string]string{"error": "Invalid JSON"}) return } // 验证代理 URL 格式(非空时) if req.ProxyURL != "" { if !strings.HasPrefix(req.ProxyURL, "http://") && !strings.HasPrefix(req.ProxyURL, "https://") && !strings.HasPrefix(req.ProxyURL, "socks5://") && !strings.HasPrefix(req.ProxyURL, "socks5h://") { w.WriteHeader(400) json.NewEncoder(w).Encode(map[string]string{"error": "proxyURL must start with http://, https://, socks5://, or socks5h://"}) return } } if err := config.UpdateProxySettings(req.ProxyURL); err != nil { w.WriteHeader(500) json.NewEncoder(w).Encode(map[string]string{"error": err.Error()}) return } // 立即应用新的代理配置 applyProxyConfig(req.ProxyURL) json.NewEncoder(w).Encode(map[string]bool{"success": true}) } // apiGetGeneralConfig 获取通用设置 func (h *Handler) apiGetGeneralConfig(w http.ResponseWriter, r *http.Request) { json.NewEncoder(w).Encode(map[string]interface{}{ "invalidModelRetries": config.GetInvalidModelRetries(), "firstByteTimeoutSec": config.GetFirstByteTimeoutSec(), "firstByteRetries": config.GetFirstByteRetries(), }) } // apiUpdateGeneralConfig 更新通用设置 func (h *Handler) apiUpdateGeneralConfig(w http.ResponseWriter, r *http.Request) { var req struct { InvalidModelRetries *int `json:"invalidModelRetries"` FirstByteTimeoutSec *int `json:"firstByteTimeoutSec"` FirstByteRetries *int `json:"firstByteRetries"` } if err := json.NewDecoder(r.Body).Decode(&req); err != nil { 
w.WriteHeader(400) json.NewEncoder(w).Encode(map[string]string{"error": "Invalid JSON"}) return } if req.InvalidModelRetries != nil { n := *req.InvalidModelRetries if n < 0 || n > 20 { w.WriteHeader(400) json.NewEncoder(w).Encode(map[string]string{"error": "invalidModelRetries must be 0-20"}) return } if err := config.UpdateInvalidModelRetries(n); err != nil { w.WriteHeader(500) json.NewEncoder(w).Encode(map[string]string{"error": err.Error()}) return } } if req.FirstByteTimeoutSec != nil { n := *req.FirstByteTimeoutSec if n < 0 || n > 300 { w.WriteHeader(400) json.NewEncoder(w).Encode(map[string]string{"error": "firstByteTimeoutSec must be 0-300"}) return } if err := config.UpdateFirstByteTimeoutSec(n); err != nil { w.WriteHeader(500) json.NewEncoder(w).Encode(map[string]string{"error": err.Error()}) return } } if req.FirstByteRetries != nil { n := *req.FirstByteRetries if n < 0 || n > 10 { w.WriteHeader(400) json.NewEncoder(w).Encode(map[string]string{"error": "firstByteRetries must be 0-10"}) return } if err := config.UpdateFirstByteRetries(n); err != nil { w.WriteHeader(500) json.NewEncoder(w).Encode(map[string]string{"error": err.Error()}) return } } json.NewEncoder(w).Encode(map[string]bool{"success": true}) } // apiGetVersion 获取版本信息 func (h *Handler) apiGetVersion(w http.ResponseWriter, r *http.Request) { json.NewEncoder(w).Encode(map[string]string{ "version": config.Version, }) } // apiExportAccounts 导出账号凭证 func (h *Handler) apiExportAccounts(w http.ResponseWriter, r *http.Request) { var req struct { IDs []string `json:"ids"` // 为空则导出全部 } if err := json.NewDecoder(r.Body).Decode(&req); err != nil { // 如果 body 为空或解析失败,导出全部 req.IDs = nil } accounts := config.GetAccounts() // 如果指定了 ID,只导出指定的 if len(req.IDs) > 0 { idSet := make(map[string]bool) for _, id := range req.IDs { idSet[id] = true } var filtered []config.Account for _, a := range accounts { if idSet[a.ID] { filtered = append(filtered, a) } } accounts = filtered } // 构建兼容 Kiro Account Manager 的导出格式 type 
ExportCredentials struct { AccessToken string `json:"accessToken"` CsrfToken string `json:"csrfToken"` RefreshToken string `json:"refreshToken"` ClientID string `json:"clientId,omitempty"` ClientSecret string `json:"clientSecret,omitempty"` Region string `json:"region,omitempty"` ExpiresAt int64 `json:"expiresAt"` AuthMethod string `json:"authMethod,omitempty"` Provider string `json:"provider,omitempty"` } type ExportSubscription struct { Type string `json:"type"` Title string `json:"title,omitempty"` } type ExportUsage struct { Current float64 `json:"current"` Limit float64 `json:"limit"` PercentUsed float64 `json:"percentUsed"` LastUpdated int64 `json:"lastUpdated"` } type ExportAccount struct { ID string `json:"id"` Email string `json:"email"` Nickname string `json:"nickname,omitempty"` Idp string `json:"idp"` UserId string `json:"userId,omitempty"` MachineId string `json:"machineId,omitempty"` Credentials ExportCredentials `json:"credentials"` Subscription ExportSubscription `json:"subscription"` Usage ExportUsage `json:"usage"` Tags []string `json:"tags"` Status string `json:"status"` CreatedAt int64 `json:"createdAt"` LastUsedAt int64 `json:"lastUsedAt"` } type ExportData struct { Version string `json:"version"` ExportedAt int64 `json:"exportedAt"` Accounts []ExportAccount `json:"accounts"` Groups []interface{} `json:"groups"` Tags []interface{} `json:"tags"` } exportAccounts := make([]ExportAccount, 0, len(accounts)) for _, a := range accounts { // 映射 provider 到 idp idp := a.Provider if idp == "" { if a.AuthMethod == "social" { idp = "Google" } else { idp = "BuilderId" } } // 映射 authMethod authMethod := a.AuthMethod if authMethod == "idc" { authMethod = "IdC" } // 映射订阅类型 subType := "Free" rawType := strings.ToUpper(a.SubscriptionType) if strings.Contains(rawType, "PRO_PLUS") || strings.Contains(rawType, "PROPLUS") { subType = "Pro_Plus" } else if strings.Contains(rawType, "PRO") { subType = "Pro" } else if strings.Contains(rawType, "POWER") { subType = 
"Pro_Plus" } exportAccounts = append(exportAccounts, ExportAccount{ ID: a.ID, Email: a.Email, Nickname: a.Nickname, Idp: idp, UserId: a.UserId, MachineId: a.MachineId, Credentials: ExportCredentials{ AccessToken: a.AccessToken, CsrfToken: "", RefreshToken: a.RefreshToken, ClientID: a.ClientID, ClientSecret: a.ClientSecret, Region: a.Region, ExpiresAt: a.ExpiresAt * 1000, // 转为毫秒时间戳 AuthMethod: authMethod, Provider: a.Provider, }, Subscription: ExportSubscription{ Type: subType, Title: a.SubscriptionTitle, }, Usage: ExportUsage{ Current: a.UsageCurrent, Limit: a.UsageLimit, PercentUsed: a.UsagePercent, LastUpdated: time.Now().UnixMilli(), }, Tags: []string{}, Status: "active", CreatedAt: time.Now().UnixMilli(), LastUsedAt: time.Now().UnixMilli(), }) } data := ExportData{ Version: config.Version, ExportedAt: time.Now().UnixMilli(), Accounts: exportAccounts, Groups: []interface{}{}, Tags: []interface{}{}, } json.NewEncoder(w).Encode(data) }