Files
kirogo/proxy/handler.go
huangzhenpc e8ab5b11e7
Some checks failed
Build Docker Image / build (push) Has been cancelled
fix: stop deducting simulated cache tokens from input_tokens
Kiro backend does not support Anthropic prompt cache protocol.
The local cache tracker simulates cache hits/creation for Claude Code
compatibility, but subtracting those values from input_tokens caused
the reported input_tokens to drop to single digits.

input_tokens now reflects the real value; cache_creation_input_tokens
and cache_read_input_tokens are still reported for protocol compliance.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-12 11:14:51 +08:00

3158 lines
89 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package proxy
import (
	"crypto/subtle"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"github.com/google/uuid"

	"kiro-go/auth"
	"kiro-go/config"
	"kiro-go/pool"
)
// Handler is the HTTP handler of the proxy: it routes API, admin and health
// requests and keeps runtime statistics.
type Handler struct {
	// Runtime counters updated with sync/atomic. They MUST stay at the very
	// start of the struct: on 386, ARM and 32-bit MIPS, 64-bit atomic
	// operations require 64-bit alignment, and only the first word of an
	// allocated struct is guaranteed to be 64-bit aligned there.
	totalRequests   int64
	successRequests int64
	failedRequests  int64
	totalTokens     int64

	pool *pool.AccountPool

	// totalCredits is a float64, so it cannot use atomic adds; guarded by creditsMu.
	totalCredits float64
	creditsMu    sync.RWMutex

	startTime      int64 // Unix seconds at construction, used to report uptime
	stopRefresh    chan struct{}
	stopStatsSaver chan struct{}

	// Cached upstream model list, guarded by modelsCacheMu.
	cachedModels    []ModelInfo
	modelsCacheMu   sync.RWMutex
	modelsCacheTime int64

	promptCache *promptCacheTracker
}
// thinkingStreamSource identifies which upstream channel supplies thinking
// content for one streamed response. The first channel that delivers thinking
// wins; the other channel is ignored for the rest of the stream.
type thinkingStreamSource int

const (
	// No thinking content has been seen yet.
	thinkingSourceUnknown thinkingStreamSource = 0
	// Thinking arrives as upstream reasoning events.
	thinkingSourceReasoningEvent thinkingStreamSource = 1
	// Thinking arrives inline as <thinking>...</thinking> tags in the text.
	thinkingSourceTagBlock thinkingStreamSource = 2
)

// allowReasoningSource reports whether reasoning-event thinking may be
// forwarded. It refuses once the stream has committed to tag-based thinking;
// in every other case it commits the stream to reasoning events and accepts.
func allowReasoningSource(source *thinkingStreamSource) bool {
	accept := *source != thinkingSourceTagBlock
	if accept {
		*source = thinkingSourceReasoningEvent
	}
	return accept
}

// allowTagSource reports whether inline <thinking> tag content may be
// forwarded. It refuses once reasoning events have been chosen, claims the
// stream for tags when nothing has been chosen yet, and otherwise accepts only
// if tags were already chosen.
func allowTagSource(source *thinkingStreamSource) bool {
	switch *source {
	case thinkingSourceReasoningEvent:
		return false
	case thinkingSourceUnknown:
		*source = thinkingSourceTagBlock
		return true
	default:
		return *source == thinkingSourceTagBlock
	}
}
func validateClaudeRequestShape(req *ClaudeRequest) string {
if len(req.Messages) == 0 {
return "messages must not be empty"
}
if msg := validateClaudeThinkingConfig(req.Thinking, req.MaxTokens); msg != "" {
return msg
}
hasUserContext := false
lastRole := ""
for _, msg := range req.Messages {
role := strings.TrimSpace(msg.Role)
if role == "" {
continue
}
lastRole = role
if role != "user" {
continue
}
text, images, toolResults := extractClaudeUserContent(msg.Content)
if normalizeUserContent(text, len(images) > 0) != "" || len(toolResults) > 0 {
hasUserContext = true
}
}
if lastRole == "assistant" {
return "assistant-prefill final message is not supported; last message must be user"
}
if !hasUserContext {
return "at least one non-empty user message is required"
}
return ""
}
// validateClaudeThinkingConfig checks the "thinking" request field against
// max_tokens. It returns an error message, or "" when thinking is nil or valid.
//
// thinking.type (case-insensitive, trimmed) must be enabled, adaptive or
// disabled. For "enabled", max_tokens must not be 0, and budget_tokens must be
// at least 1024 and below max_tokens when max_tokens is positive. For
// "adaptive" and "disabled", budget_tokens must be unset (0). thinking.display,
// if present, must be summarized or omitted, and is rejected for "disabled".
func validateClaudeThinkingConfig(thinking *ClaudeThinkingConfig, maxTokens int) string {
	if thinking == nil {
		return ""
	}
	mode := strings.ToLower(strings.TrimSpace(thinking.Type))
	budget := thinking.BudgetTokens

	if mode == "enabled" {
		switch {
		case maxTokens == 0:
			return "thinking.type enabled cannot be used with max_tokens=0"
		case budget <= 0:
			return "thinking.budget_tokens is required when thinking.type is enabled"
		case budget < 1024:
			return "thinking.budget_tokens must be at least 1024"
		case maxTokens > 0 && budget >= maxTokens:
			return "thinking.budget_tokens must be less than max_tokens"
		}
	} else if mode == "adaptive" {
		if budget != 0 {
			return "thinking.budget_tokens is not supported when thinking.type is adaptive"
		}
	} else if mode == "disabled" {
		if budget != 0 {
			return "thinking.budget_tokens is not supported when thinking.type is disabled"
		}
	} else {
		return "thinking.type must be one of: enabled, adaptive, disabled"
	}

	display := strings.ToLower(strings.TrimSpace(thinking.Display))
	switch display {
	case "", "summarized", "omitted":
		// accepted values
	default:
		return "thinking.display must be one of: summarized, omitted"
	}
	if display != "" && mode == "disabled" {
		return "thinking.display is not supported when thinking.type is disabled"
	}
	return ""
}
// claudeThinkingResponseOptions controls how thinking content is rendered in
// Claude-format responses.
type claudeThinkingResponseOptions struct {
	// Format selects the thinking output format (e.g. "thinking", "think",
	// "reasoning_content").
	Format string
	// OmitDisplay emits thinking blocks without their text.
	OmitDisplay bool
}
// resolveClaudeThinkingResponseOptions derives the response rendering options.
// The configured defaultFormat is used, or "thinking" when it is empty. If the
// request sets thinking.display to "summarized" or "omitted" (case-insensitive),
// the format is forced to "thinking"; "omitted" also sets OmitDisplay.
func resolveClaudeThinkingResponseOptions(thinking *ClaudeThinkingConfig, defaultFormat string) claudeThinkingResponseOptions {
	format := defaultFormat
	if format == "" {
		format = "thinking"
	}
	if thinking == nil {
		return claudeThinkingResponseOptions{Format: format}
	}

	omit := false
	switch strings.ToLower(strings.TrimSpace(thinking.Display)) {
	case "summarized":
		format = "thinking"
	case "omitted":
		format = "thinking"
		omit = true
	}
	return claudeThinkingResponseOptions{Format: format, OmitDisplay: omit}
}
func validateOpenAIRequestShape(req *OpenAIRequest) string {
if len(req.Messages) == 0 {
return "messages must not be empty"
}
hasNonSystem := false
hasUserContext := false
lastRole := ""
for _, msg := range req.Messages {
role := strings.TrimSpace(msg.Role)
if role == "" {
continue
}
if role != "system" {
hasNonSystem = true
lastRole = role
}
if role != "user" {
continue
}
text, images := extractOpenAIUserContent(msg.Content)
if normalizeUserContent(text, len(images) > 0) != "" {
hasUserContext = true
}
}
if !hasNonSystem {
return "at least one non-system message is required"
}
if lastRole == "assistant" {
return "assistant-prefill final message is not supported; last message must be user or tool"
}
if !hasUserContext {
return "at least one non-empty user message is required"
}
return ""
}
// NewHandler creates the proxy Handler.
//
// It applies the configured outbound proxy, restores the persisted request,
// token and credit counters from config, and starts two background
// goroutines: periodic account/model refresh and periodic statistics saving.
// They stop when stopRefresh / stopStatsSaver are signalled.
func NewHandler() *Handler {
	applyProxyConfig(config.GetProxyURL())

	reqCount, okCount, failCount, tokenCount, creditSum := config.GetStats()
	handler := &Handler{
		pool:            pool.GetPool(),
		totalRequests:   int64(reqCount),
		successRequests: int64(okCount),
		failedRequests:  int64(failCount),
		totalTokens:     int64(tokenCount),
		totalCredits:    creditSum,
		startTime:       time.Now().Unix(),
		stopRefresh:     make(chan struct{}),
		stopStatsSaver:  make(chan struct{}),
		promptCache:     newPromptCacheTracker(defaultPromptCacheTTL),
	}

	go handler.backgroundRefresh()
	go handler.backgroundStatsSaver()
	return handler
}
// backgroundRefresh periodically refreshes the model cache and all accounts.
// It waits 10 seconds after start-up before the first pass, then repeats every
// 30 minutes, and returns when a value is received from (or close of)
// h.stopRefresh.
func (h *Handler) backgroundRefresh() {
	ticker := time.NewTicker(30 * time.Minute)
	defer ticker.Stop()

	// The start-up delay is interruptible. A stop request issued shortly after
	// construction is honoured immediately instead of first running a full
	// refresh pass (the previous time.Sleep ignored stopRefresh).
	startup := time.NewTimer(10 * time.Second)
	select {
	case <-startup.C:
	case <-h.stopRefresh:
		startup.Stop()
		return
	}

	h.refreshModelsCache()
	h.refreshAllAccounts()
	for {
		select {
		case <-ticker.C:
			h.refreshModelsCache()
			h.refreshAllAccounts()
		case <-h.stopRefresh:
			return
		}
	}
}
// refreshAllAccounts refreshes the token (when it expires within 5 minutes)
// and the subscription/usage info of every enabled account that has an
// access token, writes the results back to config and the pool, and finally
// reloads the pool. Per-account failures are logged and that account is
// skipped.
func (h *Handler) refreshAllAccounts() {
	accounts := config.GetAccounts()
	for idx := range accounts {
		acct := &accounts[idx]
		if !acct.Enabled || acct.AccessToken == "" {
			continue
		}

		// Refresh the token if it expires within the next 300 seconds.
		needsRefresh := acct.ExpiresAt > 0 && time.Now().Unix() > acct.ExpiresAt-300
		if needsRefresh {
			access, refresh, expires, err := auth.RefreshToken(acct)
			if err != nil {
				fmt.Printf("[BackgroundRefresh] Token refresh failed for %s: %v\n", acct.Email, err)
				continue
			}
			acct.AccessToken = access
			if refresh != "" {
				acct.RefreshToken = refresh
			}
			acct.ExpiresAt = expires
			config.UpdateAccountToken(acct.ID, access, refresh, expires)
			h.pool.UpdateToken(acct.ID, access, refresh, expires)
		}

		info, err := RefreshAccountInfo(acct)
		if err != nil {
			fmt.Printf("[BackgroundRefresh] Failed to refresh %s: %v\n", acct.Email, err)
			continue
		}
		config.UpdateAccountInfo(acct.ID, *info)
		fmt.Printf("[BackgroundRefresh] Refreshed %s: %s %.1f/%.1f\n", acct.Email, info.SubscriptionType, info.UsageCurrent, info.UsageLimit)
	}
	h.pool.Reload()
}
// validateApiKey reports whether the request carries the configured API key.
//
// It always returns true when the config does not require a key or no key is
// configured. Otherwise the key is taken from an "Authorization: Bearer <key>"
// header, falling back to the X-Api-Key header.
//
// The comparison is constant-time with respect to the key contents, so
// response timing does not reveal how much of a guessed key was correct. The
// key length may still be observable.
func (h *Handler) validateApiKey(r *http.Request) bool {
	if !config.IsApiKeyRequired() {
		return true
	}
	expectedKey := config.GetApiKey()
	if expectedKey == "" {
		return true
	}

	var providedKey string
	if authHeader := r.Header.Get("Authorization"); strings.HasPrefix(authHeader, "Bearer ") {
		providedKey = strings.TrimPrefix(authHeader, "Bearer ")
	} else {
		providedKey = r.Header.Get("X-Api-Key")
	}

	// ConstantTimeCompare returns 1 only when both slices have equal length and
	// contents, so an empty or missing key is rejected.
	return subtle.ConstantTimeCompare([]byte(providedKey), []byte(expectedKey)) == 1
}
// ServeHTTP implements http.Handler and is the single routing entry point.
//
// Every response gets permissive CORS headers. OPTIONS preflight requests get
// 204 and no body. Routing is by URL path:
//   - Claude Messages, count_tokens and OpenAI chat completions require a
//     valid API key and reply 401 in the matching error format otherwise;
//   - /v1/stats requires the API key and replies with a JSON 401 otherwise;
//   - /v1/models, the Claude Code telemetry sink, the admin page/API/static
//     files and the health check are not API-key checked here;
//   - any other path gets 404.
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Access-Control-Allow-Origin", "*")
	w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
	w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-Api-Key, anthropic-version, anthropic-beta, x-api-key, x-stainless-os, x-stainless-lang, x-stainless-package-version, x-stainless-runtime, x-stainless-runtime-version, x-stainless-arch")
	w.Header().Set("Access-Control-Expose-Headers", "x-request-id, x-ratelimit-limit-requests, x-ratelimit-limit-tokens, x-ratelimit-remaining-requests, x-ratelimit-remaining-tokens, x-ratelimit-reset-requests, x-ratelimit-reset-tokens")
	if r.Method == "OPTIONS" {
		w.WriteHeader(204)
		return
	}

	// rejectClaude / rejectOpenAI check the API key. On failure they write a
	// 401 in the corresponding API's error format and return true.
	rejectClaude := func() bool {
		if h.validateApiKey(r) {
			return false
		}
		h.sendClaudeError(w, 401, "authentication_error", "Invalid or missing API key")
		return true
	}
	rejectOpenAI := func() bool {
		if h.validateApiKey(r) {
			return false
		}
		h.sendOpenAIError(w, 401, "authentication_error", "Invalid or missing API key")
		return true
	}

	switch path := r.URL.Path; {
	case path == "/v1/messages" || path == "/messages" || path == "/anthropic/v1/messages":
		if !rejectClaude() {
			h.handleClaudeMessages(w, r)
		}
	case path == "/v1/messages/count_tokens" || path == "/messages/count_tokens":
		if !rejectClaude() {
			h.handleCountTokens(w, r)
		}
	case path == "/v1/chat/completions" || path == "/chat/completions":
		if !rejectOpenAI() {
			h.handleOpenAIChat(w, r)
		}
	case path == "/v1/models" || path == "/models":
		h.handleModels(w, r)
	case path == "/api/event_logging/batch":
		// Claude Code telemetry endpoint: acknowledge with 200 OK and discard.
		w.Header().Set("Content-Type", "application/json; charset=utf-8")
		w.Write([]byte(`{"status":"ok"}`))
	case path == "/admin" || path == "/admin/":
		h.serveAdminPage(w, r)
	case strings.HasPrefix(path, "/admin/api/"):
		h.handleAdminAPI(w, r)
	case strings.HasPrefix(path, "/admin/"):
		h.serveStaticFile(w, r)
	case path == "/health" || path == "/":
		h.handleHealth(w, r)
	case path == "/v1/stats":
		if h.validateApiKey(r) {
			h.handleStats(w, r)
		} else {
			w.Header().Set("Content-Type", "application/json; charset=utf-8")
			w.WriteHeader(401)
			json.NewEncoder(w).Encode(map[string]string{"error": "Invalid or missing API key"})
		}
	default:
		http.Error(w, "Not Found", 404)
	}
}
// handleHealth answers the health check with status, version and uptime in
// seconds. It deliberately exposes no usage statistics.
func (h *Handler) handleHealth(w http.ResponseWriter, r *http.Request) {
	uptime := time.Now().Unix() - h.startTime
	body := map[string]interface{}{
		"status":  "ok",
		"version": config.Version,
		"uptime":  uptime,
	}
	w.Header().Set("Content-Type", "application/json; charset=utf-8")
	json.NewEncoder(w).Encode(body)
}
// handleStats reports pool and request statistics as JSON. The caller
// (ServeHTTP) is responsible for API-key authentication.
func (h *Handler) handleStats(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json; charset=utf-8")
	stats := map[string]interface{}{
		"status":          "ok",
		"version":         config.Version,
		"accounts":        h.pool.Count(),
		"available":       h.pool.AvailableCount(),
		"totalRequests":   atomic.LoadInt64(&h.totalRequests),
		"successRequests": atomic.LoadInt64(&h.successRequests),
		"failedRequests":  atomic.LoadInt64(&h.failedRequests),
		"totalTokens":     atomic.LoadInt64(&h.totalTokens),
		"totalCredits":    h.getCredits(),
		"uptime":          time.Now().Unix() - h.startTime,
	}
	json.NewEncoder(w).Encode(stats)
}
// handleModels serves the model list. It uses the cached upstream models, and
// fills the cache synchronously when it is empty. It falls back to a static
// list when nothing is available, and always appends the alias models "auto",
// "gpt-4o" and "gpt-4".
func (h *Handler) handleModels(w http.ResponseWriter, r *http.Request) {
	snapshot := func() []ModelInfo {
		h.modelsCacheMu.RLock()
		defer h.modelsCacheMu.RUnlock()
		return h.cachedModels
	}

	cached := snapshot()
	if len(cached) == 0 {
		h.refreshModelsCache()
		cached = snapshot()
	}

	suffix := config.GetThinkingConfig().Suffix
	models := buildAnthropicModelsResponse(cached, suffix)
	if len(models) == 0 {
		models = fallbackAnthropicModels(suffix)
	}

	// Alias models.
	models = append(models,
		buildModelInfo("auto", "kiro-proxy", true),
		buildModelInfo("gpt-4o", "kiro-proxy", true),
		buildModelInfo("gpt-4", "kiro-proxy", true),
	)

	w.Header().Set("Content-Type", "application/json; charset=utf-8")
	json.NewEncoder(w).Encode(map[string]interface{}{
		"object": "list",
		"data":   models,
	})
}
// buildAnthropicModelsResponse turns the cached upstream models into model
// list entries owned by "anthropic". Each model is listed under its own ID,
// followed by a thinking variant (ID + thinkingSuffix). The variant is skipped
// when thinkingSuffix is empty, because it would duplicate the base ID.
// Returns nil when cached is empty.
func buildAnthropicModelsResponse(cached []ModelInfo, thinkingSuffix string) []map[string]interface{} {
	if len(cached) == 0 {
		return nil
	}
	perModel := 1
	if thinkingSuffix != "" {
		perModel = 2
	}
	models := make([]map[string]interface{}, 0, len(cached)*perModel)
	for _, m := range cached {
		supportsImage := modelSupportsImage(m.InputTypes)
		models = append(models, buildModelInfo(m.ModelId, "anthropic", supportsImage))
		if thinkingSuffix != "" {
			models = append(models, buildModelInfo(m.ModelId+thinkingSuffix, "anthropic", supportsImage))
		}
	}
	return models
}
// fallbackAnthropicModels is the static model list used when no upstream
// model information is available. Every entry is owned by "anthropic",
// advertises image support, and is followed by its thinking variant
// (ID + thinkingSuffix). The variant is skipped when thinkingSuffix is empty,
// because it would duplicate the base ID.
func fallbackAnthropicModels(thinkingSuffix string) []map[string]interface{} {
	baseIDs := []string{
		"claude-sonnet-4.6",
		"claude-opus-4.6",
		"claude-opus-4.7",
		"claude-sonnet-4.5",
		"claude-sonnet-4",
		"claude-haiku-4.5",
		"claude-opus-4.5",
	}
	models := make([]map[string]interface{}, 0, 2*len(baseIDs))
	for _, id := range baseIDs {
		models = append(models, buildModelInfo(id, "anthropic", true))
		if thinkingSuffix != "" {
			models = append(models, buildModelInfo(id+thinkingSuffix, "anthropic", true))
		}
	}
	return models
}
// modelSupportsImage reports whether any of the model's input types mentions
// "image" or "vision" (case-insensitive substring match).
func modelSupportsImage(inputTypes []string) bool {
	for i := range inputTypes {
		lower := strings.ToLower(inputTypes[i])
		switch {
		case strings.Contains(lower, "image"), strings.Contains(lower, "vision"):
			return true
		}
	}
	return false
}
// buildModelInfo builds one model list entry. Image support is advertised
// redundantly (supports_image, input_modalities, modalities, capabilities and
// info.meta.capabilities) so that clients reading any of these fields see it.
func buildModelInfo(id, ownedBy string, supportsImage bool) map[string]interface{} {
	inputs := []string{"text"}
	if supportsImage {
		inputs = append(inputs, "image")
	}

	topCaps := map[string]bool{
		"vision":       supportsImage,
		"image":        supportsImage,
		"image_vision": supportsImage,
	}
	metaCaps := map[string]bool{
		"vision":       supportsImage,
		"image_vision": supportsImage,
	}

	entry := make(map[string]interface{}, 8)
	entry["id"] = id
	entry["object"] = "model"
	entry["owned_by"] = ownedBy
	entry["supports_image"] = supportsImage
	entry["input_modalities"] = inputs
	entry["modalities"] = map[string][]string{
		"input":  inputs,
		"output": []string{"text"},
	}
	entry["capabilities"] = topCaps
	entry["info"] = map[string]interface{}{
		"meta": map[string]interface{}{
			"capabilities": metaCaps,
		},
	}
	return entry
}
// refreshModelsCache fetches the available models for every enabled account
// from the Kiro API and stores the de-duplicated union in the model cache.
// Accounts whose token refresh or model listing fails are logged and skipped.
// The cache is left untouched when no models were obtained.
func (h *Handler) refreshModelsCache() {
	enabled := config.GetEnabledAccounts()
	if len(enabled) == 0 {
		return
	}

	merged := make([]ModelInfo, 0)
	for i := range enabled {
		acct := &enabled[i]
		if err := h.ensureValidToken(acct); err != nil {
			fmt.Printf("[ModelsCache] Skip %s token refresh failed: %v\n", acct.Email, err)
			continue
		}
		list, err := ListAvailableModels(acct)
		if err != nil {
			fmt.Printf("[ModelsCache] Failed to refresh for %s: %v\n", acct.Email, err)
			continue
		}
		merged = mergeUniqueModels(merged, list)
	}

	if len(merged) == 0 {
		return
	}
	h.modelsCacheMu.Lock()
	h.cachedModels = merged
	h.modelsCacheTime = time.Now().Unix()
	h.modelsCacheMu.Unlock()
	fmt.Printf("[ModelsCache] Cached %d models\n", len(merged))
}
// mergeUniqueModels returns existing extended with incoming, de-duplicated by
// model ID (trimmed, case-insensitive). A model already present is combined
// with the incoming entry via mergeModelInfo. Incoming entries with an empty
// ID are ignored. The input slices are not modified; existing is returned
// as-is when incoming is empty.
func mergeUniqueModels(existing []ModelInfo, incoming []ModelInfo) []ModelInfo {
	if len(incoming) == 0 {
		return existing
	}
	keyOf := func(m ModelInfo) string {
		return strings.ToLower(strings.TrimSpace(m.ModelId))
	}

	result := make([]ModelInfo, len(existing))
	copy(result, existing)
	positions := make(map[string]int, len(existing))
	for i := range result {
		positions[keyOf(result[i])] = i
	}

	for _, candidate := range incoming {
		k := keyOf(candidate)
		if k == "" {
			continue
		}
		if at, found := positions[k]; found {
			result[at] = mergeModelInfo(result[at], candidate)
		} else {
			positions[k] = len(result)
			result = append(result, candidate)
		}
	}
	return result
}
// mergeModelInfo fills empty/zero fields of base (ModelName, Description,
// RateMultiplier, TokenLimits) from extra, and unions their input types.
// Values already set on base win.
func mergeModelInfo(base ModelInfo, extra ModelInfo) ModelInfo {
	merged := base
	if merged.ModelName == "" {
		merged.ModelName = extra.ModelName
	}
	if merged.Description == "" {
		merged.Description = extra.Description
	}
	if merged.RateMultiplier == 0 {
		merged.RateMultiplier = extra.RateMultiplier
	}
	if merged.TokenLimits == nil {
		merged.TokenLimits = extra.TokenLimits
	}
	merged.InputTypes = mergeStringLists(merged.InputTypes, extra.InputTypes)
	return merged
}
// mergeStringLists returns the union of base and extra, in order, dropping
// blank items and case-insensitive duplicates. Items keep their original
// spelling. When extra is empty, base is returned unchanged, without
// de-duplication.
func mergeStringLists(base []string, extra []string) []string {
	if len(extra) == 0 {
		return base
	}
	seen := make(map[string]bool, len(base)+len(extra))
	out := make([]string, 0, len(base)+len(extra))
	for _, list := range [][]string{base, extra} {
		for _, item := range list {
			norm := strings.ToLower(strings.TrimSpace(item))
			if norm != "" && !seen[norm] {
				seen[norm] = true
				out = append(out, item)
			}
		}
	}
	return out
}
// handleCountTokens implements the count_tokens endpoint, which Claude Code
// calls. It accepts only POST and validates the thinking configuration. It
// resolves the model and thinking mode with the same helpers as the Messages
// handler, then replies {"input_tokens": N}, where N is a local estimate of at
// least 1.
func (h *Handler) handleCountTokens(w http.ResponseWriter, r *http.Request) {
	if r.Method != "POST" {
		http.Error(w, "Method Not Allowed", 405)
		return
	}

	raw, readErr := io.ReadAll(r.Body)
	if readErr != nil {
		h.sendClaudeError(w, 400, "invalid_request_error", "Failed to read request body")
		return
	}
	var req ClaudeRequest
	if json.Unmarshal(raw, &req) != nil {
		h.sendClaudeError(w, 400, "invalid_request_error", "Invalid JSON")
		return
	}
	if problem := validateClaudeThinkingConfig(req.Thinking, req.MaxTokens); problem != "" {
		h.sendClaudeError(w, 400, "invalid_request_error", problem)
		return
	}

	suffix := config.GetThinkingConfig().Suffix
	resolvedModel, thinkingOn := resolveClaudeThinkingMode(req.Model, req.Thinking, suffix)
	req.Model = resolvedModel
	count := max(estimateClaudeRequestInputTokens(cloneClaudeRequestForThinking(&req, thinkingOn)), 1)

	w.Header().Set("Content-Type", "application/json; charset=utf-8")
	json.NewEncoder(w).Encode(map[string]int{"input_tokens": count})
}
// handleClaudeMessages serves the Anthropic Messages endpoint. It is a thin
// entry point; all work is done by handleClaudeMessagesInternal.
func (h *Handler) handleClaudeMessages(w http.ResponseWriter, r *http.Request) {
	h.handleClaudeMessagesInternal(w, r)
}
// handleClaudeMessagesInternal handles an Anthropic Messages request: it
// parses and validates the body, picks an account from the pool, and makes
// sure its token is valid. It then resolves the model and thinking mode,
// estimates input tokens, computes the simulated prompt-cache usage, converts
// the request to a Kiro payload, and dispatches to the streaming or
// non-streaming responder.
func (h *Handler) handleClaudeMessagesInternal(w http.ResponseWriter, r *http.Request) {
	if r.Method != "POST" {
		http.Error(w, "Method Not Allowed", 405)
		return
	}

	raw, readErr := io.ReadAll(r.Body)
	if readErr != nil {
		h.sendClaudeError(w, 400, "invalid_request_error", "Failed to read request body")
		return
	}
	var req ClaudeRequest
	if decodeErr := json.Unmarshal(raw, &req); decodeErr != nil {
		h.sendClaudeError(w, 400, "invalid_request_error", "Invalid JSON: "+decodeErr.Error())
		return
	}
	if problem := validateClaudeRequestShape(&req); problem != "" {
		h.sendClaudeError(w, 400, "invalid_request_error", problem)
		return
	}

	account := h.pool.GetNext()
	if account == nil {
		h.sendClaudeError(w, 503, "api_error", "No available accounts")
		return
	}
	if tokenErr := h.ensureValidToken(account); tokenErr != nil {
		h.sendClaudeError(w, 503, "api_error", "Token refresh failed: "+tokenErr.Error())
		return
	}

	cfg := config.GetThinkingConfig()
	resolvedModel, thinking := resolveClaudeThinkingMode(req.Model, req.Thinking, cfg.Suffix)
	req.Model = resolvedModel
	effective := cloneClaudeRequestForThinking(&req, thinking)
	respOpts := resolveClaudeThinkingResponseOptions(req.Thinking, cfg.ClaudeFormat)
	estimated := estimateClaudeRequestInputTokens(effective)
	profile := h.promptCache.BuildClaudeProfile(effective, estimated)
	usage := h.promptCache.Compute(account.ID, profile)

	payload := ClaudeToKiro(&req, thinking)

	// Both responders share one signature, so pick one and call it once.
	respond := h.handleClaudeNonStream
	if req.Stream {
		respond = h.handleClaudeStream
	}
	respond(w, account, payload, req.Model, thinking, respOpts, estimated, usage, profile)
}
// handleClaudeStream serves a streaming Anthropic Messages request as
// Server-Sent Events.
//
// It sends message_start, then turns Kiro callbacks into content_block_*
// events:
//   - plain text becomes text blocks;
//   - reasoning events and inline <thinking>...</thinking> tags become
//     thinking output, rendered per thinkingOpts.Format / OmitDisplay;
//   - each tool call becomes one complete tool_use block.
//
// When the upstream call finishes it sends message_delta, with stop_reason
// "tool_use" if any tool was called and "end_turn" otherwise, plus the final
// usage, followed by message_stop. If the upstream call fails, an SSE "error"
// event is sent and the request is recorded as failed.
//
// model is the resolved upstream model name. thinking enables thinking
// output. estimatedInputTokens is the local input estimate, used until a
// better figure is known. cacheUsage/cacheProfile feed the simulated
// prompt-cache usage fields.
func (h *Handler) handleClaudeStream(w http.ResponseWriter, account *config.Account, payload *KiroPayload, model string, thinking bool, thinkingOpts claudeThinkingResponseOptions, estimatedInputTokens int, cacheUsage promptCacheUsage, cacheProfile *promptCacheProfile) {
	w.Header().Set("Content-Type", "text/event-stream; charset=utf-8")
	w.Header().Set("Cache-Control", "no-cache")
	w.Header().Set("Connection", "keep-alive")
	flusher, ok := w.(http.Flusher)
	if !ok {
		h.sendClaudeError(w, 500, "api_error", "Streaming not supported")
		return
	}
	// Output format for thinking content: "think", "reasoning_content", or
	// native thinking blocks for any other value.
	thinkingFormat := thinkingOpts.Format
	msgID := "msg_" + uuid.New().String()
	var inputTokens, outputTokens int
	var credits float64
	var realInputTokens int
	var toolUses []KiroToolUse
	var nextContentIndex int
	// Raw upstream text, kept only for the final output-token estimate.
	var rawContentBuilder strings.Builder
	var rawThinkingBuilder strings.Builder
	// Index and type of the currently open content block; -1 / "" when none is open.
	activeBlockIndex := -1
	activeBlockType := ""
	startInputTokens := estimatedInputTokens
	// closeActiveBlock sends content_block_stop for the open block, if any.
	closeActiveBlock := func() {
		if activeBlockIndex < 0 {
			return
		}
		h.sendSSE(w, flusher, "content_block_stop", map[string]interface{}{
			"type":  "content_block_stop",
			"index": activeBlockIndex,
		})
		activeBlockIndex = -1
		activeBlockType = ""
	}
	// startContentBlock makes sure a block of blockType ("thinking", or text
	// for anything else) is open. It closes an open block of another type and
	// assigns the next content index.
	startContentBlock := func(blockType string) {
		if activeBlockType == blockType {
			return
		}
		closeActiveBlock()
		idx := nextContentIndex
		nextContentIndex++
		if blockType == "thinking" {
			h.sendSSE(w, flusher, "content_block_start", map[string]interface{}{
				"type":  "content_block_start",
				"index": idx,
				"content_block": map[string]string{
					"type":     "thinking",
					"thinking": "",
				},
			})
		} else {
			h.sendSSE(w, flusher, "content_block_start", map[string]interface{}{
				"type":  "content_block_start",
				"index": idx,
				"content_block": map[string]string{
					"type": "text",
					"text": "",
				},
			})
		}
		activeBlockIndex = idx
		activeBlockType = blockType
	}
	// State for parsing inline <thinking> tags out of the text stream.
	var textBuffer string      // text not yet emitted; may hold a partial tag
	var inThinkingBlock bool   // currently between <thinking> and </thinking>
	var dropTagThinking bool   // discard this tag block's content (reasoning events won)
	var thinkingSource thinkingStreamSource
	// sendText emits text according to its thinking state.
	// thinkingState: 0 = normal content, 1 = thinking start, 2 = thinking middle, 3 = thinking end.
	sendText := func(text string, thinkingState int) {
		if thinkingState == 0 {
			// Normal content.
			if text == "" {
				return
			}
			startContentBlock("text")
			h.sendSSE(w, flusher, "content_block_delta", map[string]interface{}{
				"type":  "content_block_delta",
				"index": activeBlockIndex,
				"delta": map[string]string{"type": "text_delta", "text": text},
			})
			return
		}
		// Thinking output is suppressed entirely when thinking is disabled.
		if !thinking {
			return
		}
		switch thinkingFormat {
		case "think":
			// Inline the thinking into the text block, wrapped in <think>...</think>.
			var outputText string
			switch thinkingState {
			case 1:
				outputText = "<think>" + text
			case 2:
				outputText = text
			case 3:
				outputText = text + "</think>"
			}
			if outputText == "" {
				return
			}
			startContentBlock("text")
			h.sendSSE(w, flusher, "content_block_delta", map[string]interface{}{
				"type":  "content_block_delta",
				"index": activeBlockIndex,
				"delta": map[string]string{"type": "text_delta", "text": outputText},
			})
		case "reasoning_content":
			// Emit the thinking as plain text deltas, without any markers.
			if text == "" {
				return
			}
			startContentBlock("text")
			h.sendSSE(w, flusher, "content_block_delta", map[string]interface{}{
				"type":  "content_block_delta",
				"index": activeBlockIndex,
				"delta": map[string]string{"type": "text_delta", "text": text},
			})
		default:
			// Native thinking content blocks.
			if thinkingOpts.OmitDisplay {
				// Display omitted: open and close a thinking block without any deltas.
				if thinkingState == 1 {
					startContentBlock("thinking")
					return
				}
				if thinkingState == 3 {
					if activeBlockType != "thinking" {
						startContentBlock("thinking")
					}
					closeActiveBlock()
				}
				return
			}
			if thinkingState == 3 && text == "" {
				if activeBlockType == "thinking" {
					closeActiveBlock()
				}
				return
			}
			if text != "" {
				startContentBlock("thinking")
				h.sendSSE(w, flusher, "content_block_delta", map[string]interface{}{
					"type":  "content_block_delta",
					"index": activeBlockIndex,
					"delta": map[string]string{"type": "thinking_delta", "thinking": text},
				})
			}
			if thinkingState == 3 && activeBlockType == "thinking" {
				closeActiveBlock()
			}
		}
	}
	// processClaudeText routes upstream text and parses <thinking> tags.
	// Reasoning-event text (isThinking) goes straight to sendText as thinking.
	// Other text is scanned for inline <thinking>...</thinking> tags; a tail
	// of it stays in textBuffer so that tags split across chunks are still
	// detected. forceFlush emits everything that is buffered.
	var thinkingStarted bool
	var eventThinkingOpen bool
	processClaudeText := func(text string, isThinking bool, forceFlush bool) {
		if isThinking && !thinking {
			return
		}
		// Reasoning event content: forward directly.
		if isThinking {
			if !allowReasoningSource(&thinkingSource) {
				return
			}
			if !thinkingStarted {
				sendText(text, 1)
				thinkingStarted = true
				eventThinkingOpen = true
			} else {
				sendText(text, 2)
			}
			return
		}
		// The first regular text after reasoning events closes the thinking section.
		if eventThinkingOpen {
			sendText("", 3)
			eventThinkingOpen = false
			thinkingStarted = false
		}
		textBuffer += text
		for {
			if !inThinkingBlock {
				thinkingStart := strings.Index(textBuffer, "<thinking>")
				if thinkingStart != -1 {
					if thinkingStart > 0 {
						sendText(textBuffer[:thinkingStart], 0)
					}
					textBuffer = textBuffer[thinkingStart+10:] // 10 == len("<thinking>")
					inThinkingBlock = true
					dropTagThinking = !allowTagSource(&thinkingSource)
					thinkingStarted = false
				} else if forceFlush || len([]rune(textBuffer)) > 50 {
					// Slice by runes so multi-byte Unicode characters are never split.
					runes := []rune(textBuffer)
					safeLen := len(runes)
					if !forceFlush {
						// Hold back the last 15 runes in case they begin a "<thinking>" tag.
						safeLen = max(0, len(runes)-15)
					}
					if safeLen > 0 {
						sendText(string(runes[:safeLen]), 0)
						textBuffer = string(runes[safeLen:])
					}
					break
				} else {
					break
				}
			} else {
				thinkingEnd := strings.Index(textBuffer, "</thinking>")
				if thinkingEnd != -1 {
					content := textBuffer[:thinkingEnd]
					if !dropTagThinking {
						if !thinkingStarted {
							sendText(content, 1)
							sendText("", 3)
						} else {
							sendText(content, 3)
						}
					}
					textBuffer = textBuffer[thinkingEnd+11:] // 11 == len("</thinking>")
					inThinkingBlock = false
					dropTagThinking = false
					thinkingStarted = false
				} else if forceFlush {
					// Unterminated thinking block at a flush point: emit what we have and close it.
					if textBuffer != "" {
						if !dropTagThinking {
							if !thinkingStarted {
								sendText(textBuffer, 1)
								sendText("", 3)
							} else {
								sendText(textBuffer, 3)
							}
						}
						textBuffer = ""
					}
					inThinkingBlock = false
					dropTagThinking = false
					thinkingStarted = false
					break
				} else {
					// Stream the thinking-block content, holding back the last 15
					// runes in case they contain the start of "</thinking>".
					runes := []rune(textBuffer)
					if len(runes) > 20 {
						safeLen := len(runes) - 15
						if safeLen > 0 {
							if !dropTagThinking {
								if !thinkingStarted {
									sendText(string(runes[:safeLen]), 1)
									thinkingStarted = true
								} else {
									sendText(string(runes[:safeLen]), 2)
								}
							}
							textBuffer = string(runes[safeLen:])
						}
					}
					break
				}
			}
		}
	}
	// Send message_start; usage here is based on the local input estimate.
	h.sendSSE(w, flusher, "message_start", map[string]interface{}{
		"type": "message_start",
		"message": map[string]interface{}{
			"id":            msgID,
			"type":          "message",
			"role":          "assistant",
			"content":       []interface{}{},
			"model":         model,
			"stop_reason":   nil,
			"stop_sequence": nil,
			"usage":         buildClaudeUsageMap(startInputTokens, 0, cacheUsage, cacheProfile != nil),
		},
	})
	callback := &KiroStreamCallback{
		OnText: func(text string, isThinking bool) {
			if text == "" {
				return
			}
			if isThinking {
				rawThinkingBuilder.WriteString(text)
			} else {
				rawContentBuilder.WriteString(text)
			}
			processClaudeText(text, isThinking, false)
		},
		OnToolUse: func(tu KiroToolUse) {
			// Flush buffered text first so it is emitted before the tool_use block.
			processClaudeText("", false, true)
			// NOTE(review): the tool name and input are added to the raw content
			// used for output-token estimation, and toolUses is also passed to
			// estimateClaudeOutputTokens below. If that helper also counts tool
			// input, tool tokens are counted twice here (the non-stream handler
			// does not do this). Confirm against estimateClaudeOutputTokens.
			rawContentBuilder.WriteString(tu.Name)
			if b, err := json.Marshal(tu.Input); err == nil {
				rawContentBuilder.Write(b)
			}
			toolUses = append(toolUses, tu)
			closeActiveBlock()
			idx := nextContentIndex
			nextContentIndex++
			// The whole tool call is sent at once: start, a single input_json_delta
			// carrying the full JSON input, then stop.
			h.sendSSE(w, flusher, "content_block_start", map[string]interface{}{
				"type":  "content_block_start",
				"index": idx,
				"content_block": map[string]interface{}{
					"type":  "tool_use",
					"id":    tu.ToolUseID,
					"name":  tu.Name,
					"input": map[string]interface{}{},
				},
			})
			inputJSON, _ := json.Marshal(tu.Input)
			h.sendSSE(w, flusher, "content_block_delta", map[string]interface{}{
				"type":  "content_block_delta",
				"index": idx,
				"delta": map[string]interface{}{
					"type":         "input_json_delta",
					"partial_json": string(inputJSON),
				},
			})
			h.sendSSE(w, flusher, "content_block_stop", map[string]interface{}{
				"type":  "content_block_stop",
				"index": idx,
			})
		},
		OnComplete: func(inTok, outTok int) {
			inputTokens = inTok
			outputTokens = outTok
		},
		OnError: func(err error) {
			h.pool.RecordError(account.ID, strings.Contains(err.Error(), "429") || strings.Contains(err.Error(), "quota"))
		},
		OnCredits: func(c float64) {
			credits = c
		},
		OnContextUsage: func(pct float64) {
			// Convert the reported context-window usage percentage into tokens.
			realInputTokens = int(pct * float64(getContextWindowSize(model)) / 100.0)
		},
	}
	err := CallKiroAPI(account, payload, callback)
	if err != nil {
		h.recordFailure()
		h.pool.RecordError(account.ID, strings.Contains(err.Error(), "429") || strings.Contains(err.Error(), "quota"))
		// message_start has already been flushed, so the failure can only be
		// reported as an SSE error event.
		h.sendSSE(w, flusher, "error", map[string]interface{}{
			"type":  "error",
			"error": map[string]string{"type": "api_error", "message": err.Error()},
		})
		return
	}
	// Flush any text still buffered.
	processClaudeText("", false, true)
	if eventThinkingOpen {
		sendText("", 3)
		eventThinkingOpen = false
	}
	closeActiveBlock()
	// Input tokens: prefer the context-usage figure, then the upstream count,
	// then the local estimate.
	if realInputTokens > 0 {
		inputTokens = realInputTokens
	} else if inputTokens <= 0 {
		inputTokens = estimatedInputTokens
	}
	outputContent, extractedReasoning := extractThinkingFromContent(rawContentBuilder.String())
	thinkingOutput := rawThinkingBuilder.String()
	if thinking && thinkingOutput == "" && extractedReasoning != "" {
		thinkingOutput = extractedReasoning
	}
	if !thinking {
		thinkingOutput = ""
	}
	// Output tokens are always re-estimated locally, replacing the upstream count.
	outputTokens = estimateClaudeOutputTokens(outputContent, thinkingOutput, toolUses)
	h.recordSuccess(inputTokens, outputTokens, credits)
	h.pool.RecordSuccess(account.ID)
	h.pool.UpdateStats(account.ID, inputTokens+outputTokens, credits)
	h.promptCache.Update(account.ID, cacheProfile)
	// Send message_delta with the stop reason and final usage.
	stopReason := "end_turn"
	if len(toolUses) > 0 {
		stopReason = "tool_use"
	}
	h.sendSSE(w, flusher, "message_delta", map[string]interface{}{
		"type": "message_delta",
		"delta": map[string]interface{}{
			"stop_reason": stopReason,
		},
		"usage": buildClaudeUsageMap(inputTokens, outputTokens, cacheUsage, cacheProfile != nil),
	})
	h.sendSSE(w, flusher, "message_stop", map[string]interface{}{
		"type": "message_stop",
	})
}
// sendSSE writes one Server-Sent Event, with the given event name and data
// encoded as JSON, and flushes it to the client immediately. A marshalling
// error is ignored, in which case the data line is empty.
func (h *Handler) sendSSE(w http.ResponseWriter, flusher http.Flusher, event string, data interface{}) {
	encoded, _ := json.Marshal(data)
	fmt.Fprintf(w, "event: %s\ndata: %s\n\n", event, encoded)
	flusher.Flush()
}
// backgroundStatsSaver persists the statistics every 30 seconds. When
// h.stopStatsSaver is signalled it saves one last time and returns.
func (h *Handler) backgroundStatsSaver() {
	ticker := time.NewTicker(30 * time.Second)
	defer ticker.Stop()
	for {
		select {
		case <-h.stopStatsSaver:
			h.saveStats() // final save before exiting
			return
		case <-ticker.C:
			h.saveStats()
		}
	}
}
// saveStats writes a snapshot of the runtime counters and credits to config.
func (h *Handler) saveStats() {
	total := int(atomic.LoadInt64(&h.totalRequests))
	succeeded := int(atomic.LoadInt64(&h.successRequests))
	failed := int(atomic.LoadInt64(&h.failedRequests))
	tokens := int(atomic.LoadInt64(&h.totalTokens))
	config.UpdateStats(total, succeeded, failed, tokens, h.getCredits())
}
// getCredits returns the accumulated credits under the read lock.
func (h *Handler) getCredits() float64 {
	h.creditsMu.RLock()
	value := h.totalCredits
	h.creditsMu.RUnlock()
	return value
}
// addCredits adds credits to the running total under the write lock.
func (h *Handler) addCredits(credits float64) {
	h.creditsMu.Lock()
	defer h.creditsMu.Unlock()
	h.totalCredits += credits
}
// recordSuccess counts one successful request, its input+output tokens and
// its credits. The counters are updated atomically.
func (h *Handler) recordSuccess(inputTokens, outputTokens int, credits float64) {
	tokens := int64(inputTokens + outputTokens)
	atomic.AddInt64(&h.totalRequests, 1)
	atomic.AddInt64(&h.successRequests, 1)
	atomic.AddInt64(&h.totalTokens, tokens)
	h.addCredits(credits)
}
// recordFailure atomically counts one failed request.
func (h *Handler) recordFailure() {
	for _, counter := range []*int64{&h.totalRequests, &h.failedRequests} {
		atomic.AddInt64(counter, 1)
	}
}
// handleClaudeNonStream serves a non-streaming Anthropic Messages request.
//
// It calls the Kiro API with payload, collecting answer text, reasoning text
// and tool calls through callbacks, then writes one Claude-format JSON
// response.
//
// Token accounting: input tokens come from the reported context-usage
// percentage when available, otherwise from the upstream count, otherwise from
// estimatedInputTokens. Output tokens are always re-estimated locally from the
// returned content. cacheUsage is reported as-is in the usage block and does
// not reduce input_tokens.
//
// If the upstream call fails, the error is recorded on the account (the
// second RecordError argument is true when the message mentions "429" or
// "quota", the same classification the streaming handler uses) and a 500
// api_error is returned.
func (h *Handler) handleClaudeNonStream(w http.ResponseWriter, account *config.Account, payload *KiroPayload, model string, thinking bool, thinkingOpts claudeThinkingResponseOptions, estimatedInputTokens int, cacheUsage promptCacheUsage, cacheProfile *promptCacheProfile) {
	// Builders rather than string += so long responses accumulate in amortised
	// linear time instead of quadratic.
	var contentBuilder strings.Builder
	var thinkingBuilder strings.Builder
	var toolUses []KiroToolUse
	var inputTokens, outputTokens int
	var credits float64
	var realInputTokens int

	// isQuotaError uses the same classification as the streaming handler.
	isQuotaError := func(err error) bool {
		msg := err.Error()
		return strings.Contains(msg, "429") || strings.Contains(msg, "quota")
	}

	callback := &KiroStreamCallback{
		OnText: func(text string, isThinking bool) {
			if isThinking {
				thinkingBuilder.WriteString(text)
			} else {
				contentBuilder.WriteString(text)
			}
		},
		OnToolUse: func(tu KiroToolUse) {
			toolUses = append(toolUses, tu)
		},
		OnComplete: func(inTok, outTok int) {
			inputTokens = inTok
			outputTokens = outTok
		},
		OnError: func(err error) {
			h.pool.RecordError(account.ID, isQuotaError(err))
		},
		OnCredits: func(c float64) {
			credits = c
		},
		OnContextUsage: func(pct float64) {
			// Convert the reported context-window usage percentage into tokens.
			realInputTokens = int(pct * float64(getContextWindowSize(model)) / 100.0)
		},
	}
	err := CallKiroAPI(account, payload, callback)
	if err != nil {
		h.recordFailure()
		h.pool.RecordError(account.ID, isQuotaError(err))
		h.sendClaudeError(w, 500, "api_error", err.Error())
		return
	}

	// Split the answer text from any extracted reasoning. Reasoning-event
	// content takes precedence over reasoning extracted from the text.
	thinkingFormat := thinkingOpts.Format
	finalContent, extractedReasoning := extractThinkingFromContent(contentBuilder.String())
	rawThinkingContent := thinkingBuilder.String()
	if thinking && rawThinkingContent == "" && extractedReasoning != "" {
		rawThinkingContent = extractedReasoning
	}
	if !thinking {
		rawThinkingContent = ""
	}

	// Input tokens: prefer the context-usage figure, then the upstream count,
	// then the local estimate.
	if realInputTokens > 0 {
		inputTokens = realInputTokens
	} else if inputTokens <= 0 {
		inputTokens = estimatedInputTokens
	}
	outputTokens = estimateClaudeOutputTokens(finalContent, rawThinkingContent, toolUses)
	h.recordSuccess(inputTokens, outputTokens, credits)
	h.pool.RecordSuccess(account.ID)
	h.pool.UpdateStats(account.ID, inputTokens+outputTokens, credits)
	h.promptCache.Update(account.ID, cacheProfile)

	// With display omitted, an empty thinking block is still included whenever
	// thinking content existed.
	responseThinkingContent := rawThinkingContent
	includeEmptyThinkingBlock := thinking && thinkingOpts.OmitDisplay && rawThinkingContent != ""
	if includeEmptyThinkingBlock {
		responseThinkingContent = ""
	}
	if thinking && responseThinkingContent != "" {
		switch thinkingFormat {
		case "think":
			finalContent = "<think>" + responseThinkingContent + "</think>" + finalContent
			responseThinkingContent = ""
		case "reasoning_content":
			// The Claude format has no reasoning_content field, so prepend it to the text.
			finalContent = responseThinkingContent + finalContent
			responseThinkingContent = ""
		default:
			// Native thinking block: keep responseThinkingContent as is.
		}
	}
	resp := KiroToClaudeResponse(finalContent, responseThinkingContent, includeEmptyThinkingBlock, toolUses, inputTokens, outputTokens, model)
	resp.Usage.InputTokens = inputTokens
	resp.Usage.CacheCreationInputTokens = cacheUsage.CacheCreationInputTokens
	resp.Usage.CacheReadInputTokens = cacheUsage.CacheReadInputTokens
	if cacheProfile != nil {
		resp.Usage.CacheCreation = &ClaudeCacheCreationUsage{
			Ephemeral5mInputTokens: cacheUsage.CacheCreation5mInputTokens,
			Ephemeral1hInputTokens: cacheUsage.CacheCreation1hInputTokens,
		}
	}
	w.Header().Set("Content-Type", "application/json; charset=utf-8")
	json.NewEncoder(w).Encode(resp)
}
// sendClaudeError writes a Claude-style error envelope
// ({"type":"error","error":{"type":...,"message":...}}) with the given status.
func (h *Handler) sendClaudeError(w http.ResponseWriter, status int, errType, message string) {
	envelope := map[string]interface{}{
		"type":  "error",
		"error": map[string]string{"type": errType, "message": message},
	}
	header := w.Header()
	header.Set("Content-Type", "application/json; charset=utf-8")
	w.WriteHeader(status)
	json.NewEncoder(w).Encode(envelope)
}
// handleOpenAIChat serves the OpenAI-compatible chat completions endpoint.
// It accepts POST only, decodes and validates the request, and takes the next
// account from the pool, refreshing its token if needed. It then converts the
// request to a Kiro payload and hands off to the streaming or non-streaming
// responder depending on req.Stream.
func (h *Handler) handleOpenAIChat(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
		return
	}

	raw, readErr := io.ReadAll(r.Body)
	if readErr != nil {
		h.sendOpenAIError(w, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
		return
	}
	var req OpenAIRequest
	if json.Unmarshal(raw, &req) != nil {
		h.sendOpenAIError(w, http.StatusBadRequest, "invalid_request_error", "Invalid JSON")
		return
	}
	if problem := validateOpenAIRequestShape(&req); problem != "" {
		h.sendOpenAIError(w, http.StatusBadRequest, "invalid_request_error", problem)
		return
	}

	account := h.pool.GetNext()
	if account == nil {
		h.sendOpenAIError(w, http.StatusServiceUnavailable, "server_error", "No available accounts")
		return
	}
	if h.ensureValidToken(account) != nil {
		h.sendOpenAIError(w, http.StatusServiceUnavailable, "server_error", "Token refresh failed")
		return
	}

	// Strip the configured thinking suffix (if any) off the model name.
	baseModel, thinking := ParseModelAndThinking(req.Model, config.GetThinkingConfig().Suffix)
	req.Model = baseModel
	estimated := estimateOpenAIRequestInputTokens(&req)
	kiroPayload := OpenAIToKiro(&req, thinking)

	respond := h.handleOpenAINonStream
	if req.Stream {
		respond = h.handleOpenAIStream
	}
	respond(w, account, kiroPayload, req.Model, thinking, estimated)
}
// handleOpenAIStream writes an OpenAI chat.completion.chunk SSE stream for a
// Kiro call.
//
// Reasoning events from Kiro are forwarded directly; plain text is scanned for
// <thinking>...</thinking> tags. Only one of these two thinking sources is
// honoured per stream (allowReasoningSource / allowTagSource decide). Thinking
// text is rendered per the configured OpenAI thinking format:
//   - "thinking" and "think" wrap it in the matching tags inside the content
//     delta;
//   - any other value uses the reasoning_content delta field.
//
// Thinking output is suppressed entirely when thinking is false. Tool uses
// become tool_calls deltas, and the final chunk carries finish_reason plus a
// usage object, followed by "data: [DONE]".
//
// If the upstream call fails before any event was written, a regular JSON
// error with status 500 is sent. Once the stream has started, the failure is
// reported as an in-band SSE error event instead of silently ending the stream.
func (h *Handler) handleOpenAIStream(w http.ResponseWriter, account *config.Account, payload *KiroPayload, model string, thinking bool, estimatedInputTokens int) {
	w.Header().Set("Content-Type", "text/event-stream; charset=utf-8")
	w.Header().Set("Cache-Control", "no-cache")
	w.Header().Set("Connection", "keep-alive")
	flusher, ok := w.(http.Flusher)
	if !ok {
		h.sendOpenAIError(w, 500, "server_error", "Streaming not supported")
		return
	}

	// Output format for thinking content.
	thinkingFormat := config.GetThinkingConfig().OpenAIFormat
	chatID := "chatcmpl-" + uuid.New().String()
	var toolCalls []ToolCall
	var toolCallIndex int
	var inputTokens int
	var credits float64
	var realInputTokens int
	// Raw upstream text, kept for output-token estimation.
	var rawContentBuilder strings.Builder
	var rawReasoningBuilder strings.Builder
	// streamed is set once any SSE event has been written, i.e. once the
	// response status and headers are committed.
	var streamed bool

	// <thinking> tag parser state.
	var textBuffer string
	var inThinkingBlock bool
	var dropTagThinking bool
	var thinkingSource thinkingStreamSource

	// writeEvent marshals one SSE data event and flushes it to the client.
	// The marshal error is ignored: events are built only from maps of
	// strings, numbers, nil and nested maps/slices of those.
	writeEvent := func(event map[string]interface{}) {
		data, _ := json.Marshal(event)
		fmt.Fprintf(w, "data: %s\n\n", string(data))
		flusher.Flush()
		streamed = true
	}
	// newChunk builds a chat.completion.chunk with a single choice.
	newChunk := func(delta interface{}, finishReason interface{}) map[string]interface{} {
		return map[string]interface{}{
			"id":      chatID,
			"object":  "chat.completion.chunk",
			"created": time.Now().Unix(),
			"model":   model,
			"choices": []map[string]interface{}{{
				"index":         0,
				"delta":         delta,
				"finish_reason": finishReason,
			}},
		}
	}

	// sendChunk emits one content or thinking delta.
	// thinkingState: 0 = regular content, 1 = thinking start,
	// 2 = thinking middle, 3 = thinking end. Empty deltas are skipped.
	sendChunk := func(content string, thinkingState int) {
		if content == "" && thinkingState == 2 {
			return
		}
		deltaKey, text := "content", content
		if thinkingState > 0 {
			if !thinking {
				return
			}
			var openTag, closeTag string
			switch thinkingFormat {
			case "thinking":
				openTag, closeTag = "<thinking>", "</thinking>"
			case "think":
				openTag, closeTag = "<think>", "</think>"
			default: // "reasoning_content"
				deltaKey = "reasoning_content"
			}
			if openTag != "" {
				switch thinkingState {
				case 1:
					text = openTag + content
				case 2:
					// Middle chunk: content is sent as-is.
				case 3:
					text = content + closeTag
				default:
					return
				}
			}
		}
		if text == "" {
			return
		}
		writeEvent(newChunk(map[string]string{deltaKey: text}, nil))
	}

	// thinkingStarted tracks whether the opening of the current thinking
	// section has already been sent.
	var thinkingStarted bool
	// eventThinkingOpen is true while a section opened by reasoning events
	// still needs its closing chunk.
	var eventThinkingOpen bool
	processText := func(text string, isThinking bool, forceFlush bool) {
		if isThinking && !thinking {
			return
		}
		// Reasoning events are forwarded directly, without tag parsing.
		if isThinking {
			if !allowReasoningSource(&thinkingSource) {
				return
			}
			if !thinkingStarted {
				sendChunk(text, 1) // start
				thinkingStarted = true
				eventThinkingOpen = true
			} else {
				sendChunk(text, 2) // middle
			}
			return
		}
		// Plain text after reasoning events closes the event-driven section.
		if eventThinkingOpen {
			sendChunk("", 3)
			eventThinkingOpen = false
			thinkingStarted = false
		}

		textBuffer += text
		for {
			if !inThinkingBlock {
				// Look for an opening <thinking> tag.
				thinkingStart := strings.Index(textBuffer, "<thinking>")
				if thinkingStart != -1 {
					// Emit the content that precedes the tag.
					if thinkingStart > 0 {
						sendChunk(textBuffer[:thinkingStart], 0)
					}
					textBuffer = textBuffer[thinkingStart+len("<thinking>"):]
					inThinkingBlock = true
					dropTagThinking = !allowTagSource(&thinkingSource)
					thinkingStarted = false // the next thinking chunk opens a new section
				} else if forceFlush || len([]rune(textBuffer)) > 50 {
					// No tag found: emit what is safe, holding back a tail
					// that could be the start of a split "<thinking>" tag.
					runes := []rune(textBuffer)
					safeLen := len(runes)
					if !forceFlush {
						safeLen = max(0, len(runes)-15)
					}
					if safeLen > 0 {
						sendChunk(string(runes[:safeLen]), 0)
						textBuffer = string(runes[safeLen:])
					}
					break
				} else {
					break
				}
			} else {
				// Inside a thinking block: look for the closing tag.
				thinkingEnd := strings.Index(textBuffer, "</thinking>")
				if thinkingEnd != -1 {
					content := textBuffer[:thinkingEnd]
					if !dropTagThinking {
						if !thinkingStarted {
							// Whole block at once: opening + content, then closing.
							sendChunk(content, 1)
							sendChunk("", 3)
						} else {
							// Already opened: remaining content plus closing.
							sendChunk(content, 3)
						}
					}
					textBuffer = textBuffer[thinkingEnd+len("</thinking>"):]
					inThinkingBlock = false
					dropTagThinking = false
					thinkingStarted = false
				} else if forceFlush {
					// Forced flush: emit whatever is left and close the section.
					if textBuffer != "" {
						if !dropTagThinking {
							if !thinkingStarted {
								sendChunk(textBuffer, 1)
								sendChunk("", 3)
							} else {
								sendChunk(textBuffer, 3)
							}
						}
						textBuffer = ""
					}
					inThinkingBlock = false
					dropTagThinking = false
					thinkingStarted = false
					break
				} else {
					// Stream the block's content, holding back a tail that
					// could be a split "</thinking>" tag.
					runes := []rune(textBuffer)
					if len(runes) > 20 {
						safeLen := len(runes) - 15
						if safeLen > 0 {
							if !dropTagThinking {
								if !thinkingStarted {
									sendChunk(string(runes[:safeLen]), 1)
									thinkingStarted = true
								} else {
									sendChunk(string(runes[:safeLen]), 2)
								}
							}
							textBuffer = string(runes[safeLen:])
						}
					}
					break
				}
			}
		}
	}

	callback := &KiroStreamCallback{
		OnText: func(text string, isThinking bool) {
			if text == "" {
				return
			}
			if isThinking {
				rawReasoningBuilder.WriteString(text)
			} else {
				rawContentBuilder.WriteString(text)
			}
			processText(text, isThinking, false)
		},
		OnToolUse: func(tu KiroToolUse) {
			// Flush buffered text first so it precedes the tool call.
			processText("", false, true)
			args, _ := json.Marshal(tu.Input)
			rawContentBuilder.WriteString(tu.Name)
			rawContentBuilder.Write(args)
			tc := ToolCall{ID: tu.ToolUseID, Type: "function"}
			tc.Function.Name = tu.Name
			tc.Function.Arguments = string(args)
			toolCalls = append(toolCalls, tc)
			writeEvent(newChunk(map[string]interface{}{
				"tool_calls": []map[string]interface{}{{
					"index": toolCallIndex,
					"id":    tu.ToolUseID,
					"type":  "function",
					"function": map[string]string{
						"name":      tu.Name,
						"arguments": string(args),
					},
				}},
			}, nil))
			toolCallIndex++
		},
		// The upstream output count is not used: output tokens are
		// re-estimated from the streamed text below.
		OnComplete: func(inTok, _ int) {
			inputTokens = inTok
		},
		OnError: func(err error) {
			h.pool.RecordError(account.ID, strings.Contains(err.Error(), "429"))
		},
		OnCredits: func(c float64) {
			credits = c
		},
		OnContextUsage: func(pct float64) {
			// Assumes pct is a percentage (0-100) of the model's context window.
			realInputTokens = int(pct * float64(getContextWindowSize(model)) / 100.0)
		},
	}

	if err := CallKiroAPI(account, payload, callback); err != nil {
		h.recordFailure()
		h.pool.RecordError(account.ID, strings.Contains(err.Error(), "429"))
		if !streamed {
			// Nothing has reached the client yet, so a proper HTTP error
			// (matching handleOpenAINonStream) is still possible.
			h.sendOpenAIError(w, 500, "server_error", err.Error())
			return
		}
		// Headers are already committed: report the failure in-band so the
		// client does not mistake the truncated stream for a normal end.
		writeEvent(map[string]interface{}{
			"error": map[string]interface{}{
				"type":    "server_error",
				"message": err.Error(),
			},
		})
		return
	}

	// Flush whatever is still buffered and close an open reasoning section.
	processText("", false, true)
	if eventThinkingOpen {
		sendChunk("", 3)
		eventThinkingOpen = false
	}

	if realInputTokens > 0 {
		inputTokens = realInputTokens
	} else if inputTokens <= 0 {
		inputTokens = estimatedInputTokens
	}
	outputContent, extractedReasoning := extractThinkingFromContent(rawContentBuilder.String())
	reasoningOutput := rawReasoningBuilder.String()
	if thinking && reasoningOutput == "" && extractedReasoning != "" {
		reasoningOutput = extractedReasoning
	}
	if !thinking {
		reasoningOutput = ""
	}
	outputTokens := estimateApproxTokens(outputContent) + estimateApproxTokens(reasoningOutput)
	for _, tc := range toolCalls {
		outputTokens += estimateApproxTokens(tc.Function.Name)
		outputTokens += estimateApproxTokens(tc.Function.Arguments)
	}
	h.recordSuccess(inputTokens, outputTokens, credits)
	h.pool.RecordSuccess(account.ID)
	h.pool.UpdateStats(account.ID, inputTokens+outputTokens, credits)

	// Final chunk: finish_reason plus usage, then the [DONE] sentinel.
	finishReason := "stop"
	if len(toolCalls) > 0 {
		finishReason = "tool_calls"
	}
	final := newChunk(map[string]interface{}{}, finishReason)
	final["usage"] = map[string]int{
		"prompt_tokens":     inputTokens,
		"completion_tokens": outputTokens,
		"total_tokens":      inputTokens + outputTokens,
	}
	writeEvent(final)
	fmt.Fprintf(w, "data: [DONE]\n\n")
	flusher.Flush()
}
// handleOpenAINonStream performs a blocking Kiro call and writes one OpenAI
// chat.completion JSON response.
//
// Reasoning text comes from reasoning events; if none arrived, text extracted
// from <thinking> tags in the content is used. It is cleared when thinking is
// disabled. Input tokens prefer the context-usage-derived value, then the
// upstream count, then estimatedInputTokens; output tokens are estimated
// locally. On upstream failure a 500 "server_error" is sent.
func (h *Handler) handleOpenAINonStream(w http.ResponseWriter, account *config.Account, payload *KiroPayload, model string, thinking bool, estimatedInputTokens int) {
	// Builders instead of string += : upstream delivers many small fragments
	// and repeated concatenation would copy the accumulated text each time.
	var contentBuf, reasoningBuf strings.Builder
	var toolUses []KiroToolUse
	var inputTokens int
	var credits float64
	var realInputTokens int
	callback := &KiroStreamCallback{
		OnText: func(text string, isThinking bool) {
			if isThinking {
				reasoningBuf.WriteString(text)
			} else {
				contentBuf.WriteString(text)
			}
		},
		OnToolUse: func(tu KiroToolUse) { toolUses = append(toolUses, tu) },
		// The upstream output count is not used; see the estimate below.
		OnComplete: func(inTok, _ int) { inputTokens = inTok },
		OnError:    func(err error) { h.pool.RecordError(account.ID, strings.Contains(err.Error(), "429")) },
		OnCredits:  func(c float64) { credits = c },
		OnContextUsage: func(pct float64) {
			// Assumes pct is a percentage (0-100) of the model's context window.
			realInputTokens = int(pct * float64(getContextWindowSize(model)) / 100.0)
		},
	}
	if err := CallKiroAPI(account, payload, callback); err != nil {
		h.recordFailure()
		h.pool.RecordError(account.ID, strings.Contains(err.Error(), "429"))
		h.sendOpenAIError(w, 500, "server_error", err.Error())
		return
	}

	// Pull <thinking> tags out of the content.
	finalContent, extractedReasoning := extractThinkingFromContent(contentBuf.String())
	reasoningContent := reasoningBuf.String()
	if thinking && reasoningContent == "" && extractedReasoning != "" {
		reasoningContent = extractedReasoning
	} else if !thinking {
		reasoningContent = ""
	}

	if realInputTokens > 0 {
		inputTokens = realInputTokens
	} else if inputTokens <= 0 {
		inputTokens = estimatedInputTokens
	}
	outputTokens := estimateOpenAIOutputTokens(finalContent, reasoningContent, toolUses)

	h.recordSuccess(inputTokens, outputTokens, credits)
	h.pool.RecordSuccess(account.ID)
	h.pool.UpdateStats(account.ID, inputTokens+outputTokens, credits)

	thinkingFormat := config.GetThinkingConfig().OpenAIFormat
	resp := KiroToOpenAIResponseWithReasoning(finalContent, reasoningContent, toolUses, inputTokens, outputTokens, model, thinkingFormat)
	w.Header().Set("Content-Type", "application/json; charset=utf-8")
	json.NewEncoder(w).Encode(resp)
}
// sendOpenAIError writes an OpenAI-style error envelope
// ({"error":{"type":...,"message":...}}) with the given status.
func (h *Handler) sendOpenAIError(w http.ResponseWriter, status int, errType, message string) {
	detail := map[string]interface{}{
		"type":    errType,
		"message": message,
	}
	header := w.Header()
	header.Set("Content-Type", "application/json; charset=utf-8")
	w.WriteHeader(status)
	json.NewEncoder(w).Encode(map[string]interface{}{"error": detail})
}
// ensureValidToken refreshes account's access token when it expires within
// the next 5 minutes. An ExpiresAt of 0 is treated as "no known expiry" and
// left alone. On refresh, the new tokens are pushed to the pool, written onto
// *account, and persisted through config.UpdateAccountToken. An empty new
// refresh token keeps the existing one on *account. The persistence call's
// result is not checked.
func (h *Handler) ensureValidToken(account *config.Account) error {
	const refreshMarginSeconds = 300
	if account.ExpiresAt == 0 {
		return nil
	}
	if time.Now().Unix() < account.ExpiresAt-refreshMarginSeconds {
		return nil
	}

	newAccess, newRefresh, newExpiry, err := auth.RefreshToken(account)
	if err != nil {
		return err
	}

	// In-memory state first: the pool, then the caller's copy.
	h.pool.UpdateToken(account.ID, newAccess, newRefresh, newExpiry)
	account.AccessToken = newAccess
	if newRefresh != "" {
		account.RefreshToken = newRefresh
	}
	account.ExpiresAt = newExpiry

	// Then persist.
	config.UpdateAccountToken(account.ID, newAccess, newRefresh, newExpiry)
	return nil
}
// ==================== Admin API ====================

// handleAdminAPI authenticates and dispatches requests under /admin/api.
//
// The admin password is read from the X-Admin-Password header, falling back
// to the admin_password cookie, and must equal config.GetPassword(); otherwise
// a 401 JSON error is returned. Routes are matched on the path suffix after
// /admin/api plus the HTTP method; unmatched requests get a 404 JSON error.
func (h *Handler) handleAdminAPI(w http.ResponseWriter, r *http.Request) {
	// Every response from this endpoint is JSON, including the 401 below, so
	// set the header before any WriteHeader call; headers set afterwards are
	// ignored by net/http.
	w.Header().Set("Content-Type", "application/json; charset=utf-8")

	password := r.Header.Get("X-Admin-Password")
	if password == "" {
		if cookie, err := r.Cookie("admin_password"); err == nil {
			password = cookie.Value
		}
	}
	// NOTE(review): consider crypto/subtle.ConstantTimeCompare here to avoid
	// a timing side channel on the admin password.
	if password != config.GetPassword() {
		w.WriteHeader(http.StatusUnauthorized)
		json.NewEncoder(w).Encode(map[string]string{"error": "Unauthorized"})
		return
	}

	path := strings.TrimPrefix(r.URL.Path, "/admin/api")
	switch {
	case path == "/accounts" && r.Method == "GET":
		h.apiGetAccounts(w, r)
	case path == "/accounts" && r.Method == "POST":
		h.apiAddAccount(w, r)
	case path == "/accounts/batch" && r.Method == "POST":
		h.apiBatchAccounts(w, r)
	case strings.HasPrefix(path, "/accounts/") && strings.HasSuffix(path, "/refresh") && r.Method == "POST":
		id := strings.TrimSuffix(strings.TrimPrefix(path, "/accounts/"), "/refresh")
		h.apiRefreshAccount(w, r, id)
	case strings.HasPrefix(path, "/accounts/") && strings.HasSuffix(path, "/models") && r.Method == "GET":
		id := strings.TrimSuffix(strings.TrimPrefix(path, "/accounts/"), "/models")
		h.apiGetAccountModels(w, r, id)
	case strings.HasPrefix(path, "/accounts/") && strings.HasSuffix(path, "/full") && r.Method == "GET":
		id := strings.TrimSuffix(strings.TrimPrefix(path, "/accounts/"), "/full")
		h.apiGetAccountFull(w, r, id)
	case strings.HasPrefix(path, "/accounts/") && r.Method == "DELETE":
		h.apiDeleteAccount(w, r, strings.TrimPrefix(path, "/accounts/"))
	case strings.HasPrefix(path, "/accounts/") && r.Method == "PUT":
		h.apiUpdateAccount(w, r, strings.TrimPrefix(path, "/accounts/"))
	case path == "/auth/iam-sso/start" && r.Method == "POST":
		h.apiStartIamSso(w, r)
	case path == "/auth/iam-sso/complete" && r.Method == "POST":
		h.apiCompleteIamSso(w, r)
	case path == "/auth/builderid/start" && r.Method == "POST":
		h.apiStartBuilderIdLogin(w, r)
	case path == "/auth/builderid/poll" && r.Method == "POST":
		h.apiPollBuilderIdAuth(w, r)
	case path == "/auth/sso-token" && r.Method == "POST":
		h.apiImportSsoToken(w, r)
	case path == "/auth/credentials" && r.Method == "POST":
		h.apiImportCredentials(w, r)
	case path == "/status" && r.Method == "GET":
		h.apiGetStatus(w, r)
	case path == "/settings" && r.Method == "GET":
		h.apiGetSettings(w, r)
	case path == "/settings" && r.Method == "POST":
		h.apiUpdateSettings(w, r)
	case path == "/stats" && r.Method == "GET":
		h.apiGetStats(w, r)
	case path == "/stats/reset" && r.Method == "POST":
		h.apiResetStats(w, r)
	case path == "/generate-machine-id" && r.Method == "GET":
		h.apiGenerateMachineId(w, r)
	case path == "/thinking" && r.Method == "GET":
		h.apiGetThinkingConfig(w, r)
	case path == "/thinking" && r.Method == "POST":
		h.apiUpdateThinkingConfig(w, r)
	case path == "/endpoint" && r.Method == "GET":
		h.apiGetEndpointConfig(w, r)
	case path == "/endpoint" && r.Method == "POST":
		h.apiUpdateEndpointConfig(w, r)
	case path == "/proxy" && r.Method == "GET":
		h.apiGetProxy(w, r)
	case path == "/proxy" && r.Method == "POST":
		h.apiUpdateProxy(w, r)
	case path == "/general" && r.Method == "GET":
		h.apiGetGeneralConfig(w, r)
	case path == "/general" && r.Method == "POST":
		h.apiUpdateGeneralConfig(w, r)
	case path == "/version" && r.Method == "GET":
		h.apiGetVersion(w, r)
	case path == "/export" && r.Method == "POST":
		h.apiExportAccounts(w, r)
	default:
		w.WriteHeader(http.StatusNotFound)
		json.NewEncoder(w).Encode(map[string]string{"error": "Not Found"})
	}
}
// apiGetAccounts lists all configured accounts as JSON, merged with the
// per-account runtime counters from the pool. Access and refresh tokens are
// never emitted; only a hasToken flag shows whether an access token is set.
// Accounts unknown to the pool report zero-valued counters.
func (h *Handler) apiGetAccounts(w http.ResponseWriter, r *http.Request) {
	configured := config.GetAccounts()
	live := h.pool.GetAllAccounts()

	runtimeByID := make(map[string]config.Account, len(live))
	for _, acc := range live {
		runtimeByID[acc.ID] = acc
	}

	out := make([]map[string]interface{}, 0, len(configured))
	for _, a := range configured {
		rt := runtimeByID[a.ID]
		out = append(out, map[string]interface{}{
			"id":                a.ID,
			"email":             a.Email,
			"userId":            a.UserId,
			"nickname":          a.Nickname,
			"authMethod":        a.AuthMethod,
			"provider":          a.Provider,
			"region":            a.Region,
			"enabled":           a.Enabled,
			"banStatus":         a.BanStatus,
			"banReason":         a.BanReason,
			"banTime":           a.BanTime,
			"expiresAt":         a.ExpiresAt,
			"hasToken":          a.AccessToken != "",
			"machineId":         a.MachineId,
			"weight":            a.Weight,
			"subscriptionType":  a.SubscriptionType,
			"subscriptionTitle": a.SubscriptionTitle,
			"daysRemaining":     a.DaysRemaining,
			"usageCurrent":      a.UsageCurrent,
			"usageLimit":        a.UsageLimit,
			"usagePercent":      a.UsagePercent,
			"nextResetDate":     a.NextResetDate,
			"lastRefresh":       a.LastRefresh,
			"trialUsageCurrent": a.TrialUsageCurrent,
			"trialUsageLimit":   a.TrialUsageLimit,
			"trialUsagePercent": a.TrialUsagePercent,
			"trialStatus":       a.TrialStatus,
			"trialExpiresAt":    a.TrialExpiresAt,
			"requestCount":      rt.RequestCount,
			"errorCount":        rt.ErrorCount,
			"totalTokens":       rt.TotalTokens,
			"totalCredits":      rt.TotalCredits,
			"lastUsed":          rt.LastUsed,
		})
	}
	json.NewEncoder(w).Encode(out)
}
// apiAddAccount creates an account from a JSON-encoded config.Account.
// A missing ID is generated and a missing region defaults to "us-east-1".
// The pool is reloaded after the account is stored.
func (h *Handler) apiAddAccount(w http.ResponseWriter, r *http.Request) {
	fail := func(status int, msg string) {
		w.WriteHeader(status)
		json.NewEncoder(w).Encode(map[string]string{"error": msg})
	}

	var acc config.Account
	if json.NewDecoder(r.Body).Decode(&acc) != nil {
		fail(http.StatusBadRequest, "Invalid JSON")
		return
	}
	if acc.ID == "" {
		acc.ID = auth.GenerateAccountID()
	}
	if acc.Region == "" {
		acc.Region = "us-east-1"
	}
	if err := config.AddAccount(acc); err != nil {
		fail(http.StatusInternalServerError, err.Error())
		return
	}
	h.pool.Reload()
	json.NewEncoder(w).Encode(map[string]interface{}{"success": true, "id": acc.ID})
}
// apiDeleteAccount removes the account with the given id from the config and
// reloads the pool. A config error is reported as a 500.
func (h *Handler) apiDeleteAccount(w http.ResponseWriter, r *http.Request, id string) {
	err := config.DeleteAccount(id)
	if err == nil {
		h.pool.Reload()
		json.NewEncoder(w).Encode(map[string]bool{"success": true})
		return
	}
	w.WriteHeader(http.StatusInternalServerError)
	json.NewEncoder(w).Encode(map[string]string{"error": err.Error()})
}
// apiUpdateAccount applies a partial update to the account with the given id.
// Only these JSON fields are honoured, and only when they have the expected
// JSON type: enabled (bool), nickname (string), machineId (string) and
// weight (number, truncated to int). Unknown ids yield a 404.
func (h *Handler) apiUpdateAccount(w http.ResponseWriter, r *http.Request, id string) {
	var patch map[string]interface{}
	if err := json.NewDecoder(r.Body).Decode(&patch); err != nil {
		w.WriteHeader(http.StatusBadRequest)
		json.NewEncoder(w).Encode(map[string]string{"error": "Invalid JSON"})
		return
	}

	all := config.GetAccounts()
	var target *config.Account
	for i := 0; i < len(all) && target == nil; i++ {
		if all[i].ID == id {
			target = &all[i]
		}
	}
	if target == nil {
		w.WriteHeader(http.StatusNotFound)
		json.NewEncoder(w).Encode(map[string]string{"error": "Account not found"})
		return
	}

	if enabled, ok := patch["enabled"].(bool); ok {
		target.Enabled = enabled
	}
	if nickname, ok := patch["nickname"].(string); ok {
		target.Nickname = nickname
	}
	if machineID, ok := patch["machineId"].(string); ok {
		target.MachineId = machineID
	}
	if weight, ok := patch["weight"].(float64); ok {
		target.Weight = int(weight)
	}

	if err := config.UpdateAccount(id, *target); err != nil {
		w.WriteHeader(http.StatusInternalServerError)
		json.NewEncoder(w).Encode(map[string]string{"error": err.Error()})
		return
	}
	h.pool.Reload()
	json.NewEncoder(w).Encode(map[string]bool{"success": true})
}
// apiBatchAccounts applies one action to several accounts at once.
//
// Request body: {"ids": [...], "action": "enable" | "disable" | "refresh"}.
//   - enable/disable set Enabled on each matching account. Enabling also
//     clears a non-ACTIVE ban status. The response "count" is the number of
//     accounts actually updated; unknown ids and failed updates are not
//     counted.
//   - refresh refreshes each account's token (when it has a refresh token)
//     and then its account info. The response reports refreshed/failed
//     counts.
//
// The pool is reloaded after either action.
func (h *Handler) apiBatchAccounts(w http.ResponseWriter, r *http.Request) {
	var req struct {
		IDs    []string `json:"ids"`
		Action string   `json:"action"` // "enable", "disable", "refresh"
	}
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		w.WriteHeader(400)
		json.NewEncoder(w).Encode(map[string]string{"error": "Invalid JSON"})
		return
	}
	if len(req.IDs) == 0 {
		w.WriteHeader(400)
		json.NewEncoder(w).Encode(map[string]string{"error": "No account IDs provided"})
		return
	}
	switch req.Action {
	case "enable", "disable":
		enabled := req.Action == "enable"
		idSet := make(map[string]bool, len(req.IDs))
		for _, id := range req.IDs {
			idSet[id] = true
		}
		updated := 0
		for _, a := range config.GetAccounts() {
			if !idSet[a.ID] {
				continue
			}
			a.Enabled = enabled
			if enabled && a.BanStatus != "" && a.BanStatus != "ACTIVE" {
				a.BanStatus = "ACTIVE"
				a.BanReason = ""
				a.BanTime = 0
			}
			// Count only accounts that were really persisted, so the
			// response does not overstate the result.
			if err := config.UpdateAccount(a.ID, a); err == nil {
				updated++
			}
		}
		h.pool.Reload()
		json.NewEncoder(w).Encode(map[string]interface{}{"success": true, "count": updated})
	case "refresh":
		successCount := 0
		failCount := 0
		for _, id := range req.IDs {
			// Re-read the config each time so a token rotated by an earlier
			// iteration (for example, a duplicated id) is picked up.
			accounts := config.GetAccounts()
			var account *config.Account
			for i := range accounts {
				if accounts[i].ID == id {
					account = &accounts[i]
					break
				}
			}
			if account == nil {
				failCount++
				continue
			}
			// Refresh the token. This is best-effort; a failure surfaces
			// below if the info refresh then fails.
			if account.RefreshToken != "" {
				if newAccess, newRefresh, newExpires, err := auth.RefreshToken(account); err == nil {
					account.AccessToken = newAccess
					if newRefresh != "" {
						account.RefreshToken = newRefresh
					}
					account.ExpiresAt = newExpires
					config.UpdateAccountToken(id, newAccess, newRefresh, newExpires)
					h.pool.UpdateToken(id, newAccess, newRefresh, newExpires)
				}
			}
			// Refresh account info.
			info, err := RefreshAccountInfo(account)
			if err != nil {
				failCount++
				continue
			}
			config.UpdateAccountInfo(id, *info)
			successCount++
		}
		h.pool.Reload()
		json.NewEncoder(w).Encode(map[string]interface{}{
			"success":   true,
			"refreshed": successCount,
			"failed":    failCount,
		})
	default:
		w.WriteHeader(400)
		json.NewEncoder(w).Encode(map[string]string{"error": "Invalid action: " + req.Action})
	}
}
// apiStartIamSso begins an IAM Identity Center SSO login for the given
// startUrl (required) and region. It returns the session id, the URL the user
// must visit to authorize, and the session lifetime reported by the auth
// package.
func (h *Handler) apiStartIamSso(w http.ResponseWriter, r *http.Request) {
	reject := func(status int, msg string) {
		w.WriteHeader(status)
		json.NewEncoder(w).Encode(map[string]string{"error": msg})
	}

	var body struct {
		StartUrl string `json:"startUrl"`
		Region   string `json:"region"`
	}
	if json.NewDecoder(r.Body).Decode(&body) != nil {
		reject(http.StatusBadRequest, "Invalid JSON")
		return
	}
	if body.StartUrl == "" {
		reject(http.StatusBadRequest, "startUrl is required")
		return
	}

	session, authorizeURL, ttl, err := auth.StartIamSsoLogin(body.StartUrl, body.Region)
	if err != nil {
		reject(http.StatusInternalServerError, err.Error())
		return
	}
	json.NewEncoder(w).Encode(map[string]interface{}{
		"sessionId":    session,
		"authorizeUrl": authorizeURL,
		"expiresIn":    ttl,
	})
}
// apiCompleteIamSso finishes an IAM Identity Center SSO login using the
// session id and the callback URL the browser landed on. It then stores a new
// enabled "idc" account with a freshly generated machine id, reloads the pool,
// and returns the new account's id and email.
func (h *Handler) apiCompleteIamSso(w http.ResponseWriter, r *http.Request) {
	var body struct {
		SessionID   string `json:"sessionId"`
		CallbackUrl string `json:"callbackUrl"`
	}
	if json.NewDecoder(r.Body).Decode(&body) != nil {
		w.WriteHeader(http.StatusBadRequest)
		json.NewEncoder(w).Encode(map[string]string{"error": "Invalid JSON"})
		return
	}

	access, refresh, clientID, clientSecret, region, ttl, err := auth.CompleteIamSsoLogin(body.SessionID, body.CallbackUrl)
	if err != nil {
		w.WriteHeader(http.StatusBadRequest)
		json.NewEncoder(w).Encode(map[string]string{"error": err.Error()})
		return
	}

	// Errors from the profile lookup are ignored; only the email is used.
	email, _, _ := auth.GetUserInfo(access)

	created := config.Account{
		ID:           auth.GenerateAccountID(),
		Email:        email,
		AccessToken:  access,
		RefreshToken: refresh,
		ClientID:     clientID,
		ClientSecret: clientSecret,
		AuthMethod:   "idc",
		Region:       region,
		ExpiresAt:    time.Now().Unix() + int64(ttl),
		Enabled:      true,
		MachineId:    config.GenerateMachineId(),
	}
	if err := config.AddAccount(created); err != nil {
		w.WriteHeader(http.StatusInternalServerError)
		json.NewEncoder(w).Encode(map[string]string{"error": err.Error()})
		return
	}
	h.pool.Reload()

	summary := map[string]interface{}{"id": created.ID, "email": created.Email}
	json.NewEncoder(w).Encode(map[string]interface{}{
		"success": true,
		"account": summary,
	})
}
// apiStartBuilderIdLogin begins an AWS Builder ID device-code login and
// returns the session id, user code, verification URI and polling interval.
func (h *Handler) apiStartBuilderIdLogin(w http.ResponseWriter, r *http.Request) {
	var body struct {
		Region string `json:"region"`
	}
	// The body is optional: if it is missing or malformed, the decode error is
	// ignored and whatever region was decoded (possibly "") is passed on.
	_ = json.NewDecoder(r.Body).Decode(&body)

	session, err := auth.StartBuilderIdLogin(body.Region)
	if err != nil {
		w.WriteHeader(http.StatusInternalServerError)
		json.NewEncoder(w).Encode(map[string]string{"error": err.Error()})
		return
	}
	reply := map[string]interface{}{
		"sessionId":       session.ID,
		"userCode":        session.UserCode,
		"verificationUri": session.VerificationUri,
		"interval":        session.Interval,
	}
	json.NewEncoder(w).Encode(reply)
}
// apiPollBuilderIdAuth polls a pending Builder ID device-code login.
//
// For "pending" or "slow_down" it reports completed=false with the session's
// current polling interval (5 if the session is not found). Any other status
// without an error is treated as authorized: a new enabled "idc" account with
// provider "BuilderId" is stored, the pool is reloaded, and the account's id
// and email are returned.
func (h *Handler) apiPollBuilderIdAuth(w http.ResponseWriter, r *http.Request) {
	var body struct {
		SessionID string `json:"sessionId"`
	}
	if json.NewDecoder(r.Body).Decode(&body) != nil {
		w.WriteHeader(http.StatusBadRequest)
		json.NewEncoder(w).Encode(map[string]string{"error": "Invalid JSON"})
		return
	}

	access, refresh, clientID, clientSecret, region, ttl, status, err := auth.PollBuilderIdAuth(body.SessionID)
	if err != nil {
		w.WriteHeader(http.StatusBadRequest)
		json.NewEncoder(w).Encode(map[string]interface{}{
			"success": false,
			"error":   err.Error(),
		})
		return
	}

	switch status {
	case "pending", "slow_down":
		// Echo the session's current interval so the client backs off correctly.
		interval := 5
		if session := auth.GetBuilderIdSession(body.SessionID); session != nil {
			interval = session.Interval
		}
		json.NewEncoder(w).Encode(map[string]interface{}{
			"success":   true,
			"completed": false,
			"status":    status,
			"interval":  interval,
		})
		return
	}

	// Authorized. Errors from the profile lookup are ignored.
	email, _, _ := auth.GetUserInfo(access)
	created := config.Account{
		ID:           auth.GenerateAccountID(),
		Email:        email,
		AccessToken:  access,
		RefreshToken: refresh,
		ClientID:     clientID,
		ClientSecret: clientSecret,
		AuthMethod:   "idc",
		Provider:     "BuilderId",
		Region:       region,
		ExpiresAt:    time.Now().Unix() + int64(ttl),
		Enabled:      true,
		MachineId:    config.GenerateMachineId(),
	}
	if err := config.AddAccount(created); err != nil {
		w.WriteHeader(http.StatusInternalServerError)
		json.NewEncoder(w).Encode(map[string]string{"error": err.Error()})
		return
	}
	h.pool.Reload()
	json.NewEncoder(w).Encode(map[string]interface{}{
		"success":   true,
		"completed": true,
		"account":   map[string]interface{}{"id": created.ID, "email": created.Email},
	})
}
// apiImportSsoToken imports one or more accounts from SSO bearer tokens.
//
// bearerToken may contain several tokens, one per line; blank lines are
// skipped. Each token is exchanged via auth.ImportFromSsoToken and stored as an
// enabled "idc" account. An empty region defaults to "us-east-1", matching
// apiAddAccount and apiImportCredentials, so no account is stored without a
// region.
//
// If nothing was imported and at least one token failed, a 500 with the
// joined errors is returned. Otherwise the response lists the imported
// accounts and any per-token errors; both are always JSON arrays, never null.
func (h *Handler) apiImportSsoToken(w http.ResponseWriter, r *http.Request) {
	var req struct {
		BearerToken string `json:"bearerToken"`
		Region      string `json:"region"`
	}
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		w.WriteHeader(400)
		json.NewEncoder(w).Encode(map[string]string{"error": "Invalid JSON"})
		return
	}
	if req.BearerToken == "" {
		w.WriteHeader(400)
		json.NewEncoder(w).Encode(map[string]string{"error": "bearerToken is required"})
		return
	}
	if req.Region == "" {
		req.Region = "us-east-1"
	}

	// Batch import: one token per line.
	tokens := strings.Split(strings.TrimSpace(req.BearerToken), "\n")
	// Non-nil so the JSON response always carries arrays.
	imported := []map[string]interface{}{}
	errs := []string{}
	for _, token := range tokens {
		token = strings.TrimSpace(token)
		if token == "" {
			continue
		}
		accessToken, refreshToken, clientID, clientSecret, expiresIn, err := auth.ImportFromSsoToken(token, req.Region)
		if err != nil {
			errs = append(errs, err.Error())
			continue
		}
		// Errors from the profile lookup are ignored; only the email is used.
		email, _, _ := auth.GetUserInfo(accessToken)
		account := config.Account{
			ID:           auth.GenerateAccountID(),
			Email:        email,
			AccessToken:  accessToken,
			RefreshToken: refreshToken,
			ClientID:     clientID,
			ClientSecret: clientSecret,
			AuthMethod:   "idc",
			Region:       req.Region,
			ExpiresAt:    time.Now().Unix() + int64(expiresIn),
			Enabled:      true,
			MachineId:    config.GenerateMachineId(),
		}
		if err := config.AddAccount(account); err != nil {
			errs = append(errs, err.Error())
			continue
		}
		imported = append(imported, map[string]interface{}{
			"id":    account.ID,
			"email": account.Email,
		})
	}
	h.pool.Reload()

	if len(imported) == 0 && len(errs) > 0 {
		w.WriteHeader(500)
		json.NewEncoder(w).Encode(map[string]interface{}{
			"success": false,
			"error":   strings.Join(errs, "; "),
		})
		return
	}
	json.NewEncoder(w).Encode(map[string]interface{}{
		"success":  true,
		"accounts": imported,
		"errors":   errs,
	})
}
// apiImportCredentials imports an account from raw OAuth credentials.
//
// refreshToken is required. The region defaults to "us-east-1".
//
// authMethod normalization:
//   - if empty, it becomes "idc" when a clientId is given, else "social";
//   - the result is then mapped: idc/builderid/enterprise -> "idc" and
//     social/google/github -> "social";
//   - any other value becomes "idc" only when both clientId and clientSecret
//     are present.
//
// The refresh token is always exchanged for a fresh access token. If that
// fails, a caller-supplied accessToken is used with a 5-minute expiry;
// without one, a 400 is returned.
func (h *Handler) apiImportCredentials(w http.ResponseWriter, r *http.Request) {
	var req struct {
		AccessToken  string `json:"accessToken"`
		RefreshToken string `json:"refreshToken"`
		ClientID     string `json:"clientId"`
		ClientSecret string `json:"clientSecret"`
		AuthMethod   string `json:"authMethod"`
		Provider     string `json:"provider"`
		Region       string `json:"region"`
	}
	if json.NewDecoder(r.Body).Decode(&req) != nil {
		w.WriteHeader(http.StatusBadRequest)
		json.NewEncoder(w).Encode(map[string]string{"error": "Invalid JSON"})
		return
	}
	if req.RefreshToken == "" {
		w.WriteHeader(http.StatusBadRequest)
		json.NewEncoder(w).Encode(map[string]string{"error": "refreshToken is required"})
		return
	}

	if req.Region == "" {
		req.Region = "us-east-1"
	}

	hasClientID := req.ClientID != ""
	method := strings.ToLower(req.AuthMethod)
	if method == "" {
		method = "social"
		if hasClientID {
			method = "idc"
		}
	}
	switch method {
	case "idc", "builderid", "enterprise":
		req.AuthMethod = "idc"
	case "social", "google", "github":
		req.AuthMethod = "social"
	default:
		if hasClientID && req.ClientSecret != "" {
			req.AuthMethod = "idc"
		} else {
			req.AuthMethod = "social"
		}
	}

	var accessToken string
	var expiresAt int64
	freshAccess, freshRefresh, freshExpiry, refreshErr := auth.RefreshToken(&config.Account{
		RefreshToken: req.RefreshToken,
		ClientID:     req.ClientID,
		ClientSecret: req.ClientSecret,
		AuthMethod:   req.AuthMethod,
		Region:       req.Region,
	})
	switch {
	case refreshErr == nil:
		accessToken, expiresAt = freshAccess, freshExpiry
		if freshRefresh != "" {
			req.RefreshToken = freshRefresh
		}
	case req.AccessToken != "":
		// Fall back to the supplied access token; it may already be stale,
		// so it only gets a short lifetime.
		accessToken = req.AccessToken
		expiresAt = time.Now().Unix() + 300
	default:
		w.WriteHeader(http.StatusBadRequest)
		json.NewEncoder(w).Encode(map[string]string{"error": "Token refresh failed: " + refreshErr.Error()})
		return
	}

	// Errors from the profile lookup are ignored; only the email is used.
	email, _, _ := auth.GetUserInfo(accessToken)
	created := config.Account{
		ID:           auth.GenerateAccountID(),
		Email:        email,
		AccessToken:  accessToken,
		RefreshToken: req.RefreshToken,
		ClientID:     req.ClientID,
		ClientSecret: req.ClientSecret,
		AuthMethod:   req.AuthMethod,
		Provider:     req.Provider,
		Region:       req.Region,
		ExpiresAt:    expiresAt,
		Enabled:      true,
		MachineId:    config.GenerateMachineId(),
	}
	if err := config.AddAccount(created); err != nil {
		w.WriteHeader(http.StatusInternalServerError)
		json.NewEncoder(w).Encode(map[string]string{"error": err.Error()})
		return
	}
	h.pool.Reload()
	json.NewEncoder(w).Encode(map[string]interface{}{
		"success": true,
		"account": map[string]interface{}{"id": created.ID, "email": created.Email},
	})
}
// apiGetStatus reports pool size together with the runtime request, token and
// credit counters and the process uptime in seconds.
//
// The counters are mutated concurrently by request handlers through sync/atomic.
// totalCredits is a float64 guarded by creditsMu. They must therefore be read
// the same way here. Plain field reads would be a data race. This mirrors
// apiGetStats.
func (h *Handler) apiGetStatus(w http.ResponseWriter, r *http.Request) {
	json.NewEncoder(w).Encode(map[string]interface{}{
		"accounts":        h.pool.Count(),
		"available":       h.pool.AvailableCount(),
		"totalRequests":   atomic.LoadInt64(&h.totalRequests),
		"successRequests": atomic.LoadInt64(&h.successRequests),
		"failedRequests":  atomic.LoadInt64(&h.failedRequests),
		"totalTokens":     atomic.LoadInt64(&h.totalTokens),
		"totalCredits":    h.getCredits(),
		"uptime":          time.Now().Unix() - h.startTime,
	})
}
// apiGetSettings returns the API-key requirement settings and the configured
// listen host/port.
func (h *Handler) apiGetSettings(w http.ResponseWriter, r *http.Request) {
	settings := make(map[string]interface{}, 4)
	settings["apiKey"] = config.GetApiKey()
	settings["requireApiKey"] = config.IsApiKeyRequired()
	settings["port"] = config.GetPort()
	settings["host"] = config.GetHost()
	json.NewEncoder(w).Encode(settings)
}
// apiUpdateSettings persists the API key, whether it is required, and the
// admin password, all taken from the JSON body.
// It responds 400 on malformed JSON and 500 if the config update fails.
func (h *Handler) apiUpdateSettings(w http.ResponseWriter, r *http.Request) {
	var payload struct {
		ApiKey        string `json:"apiKey"`
		RequireApiKey bool   `json:"requireApiKey"`
		Password      string `json:"password"`
	}
	out := json.NewEncoder(w)

	if decodeErr := json.NewDecoder(r.Body).Decode(&payload); decodeErr != nil {
		w.WriteHeader(400)
		out.Encode(map[string]string{"error": "Invalid JSON"})
		return
	}

	saveErr := config.UpdateSettings(payload.ApiKey, payload.RequireApiKey, payload.Password)
	if saveErr != nil {
		w.WriteHeader(500)
		out.Encode(map[string]string{"error": saveErr.Error()})
		return
	}
	out.Encode(map[string]bool{"success": true})
}
// apiGetStats returns a snapshot of the runtime counters plus uptime in seconds.
// Integer counters are read atomically; credits go through getCredits.
func (h *Handler) apiGetStats(w http.ResponseWriter, r *http.Request) {
	total := atomic.LoadInt64(&h.totalRequests)
	succeeded := atomic.LoadInt64(&h.successRequests)
	failed := atomic.LoadInt64(&h.failedRequests)
	tokens := atomic.LoadInt64(&h.totalTokens)
	credits := h.getCredits()
	uptime := time.Now().Unix() - h.startTime

	json.NewEncoder(w).Encode(map[string]interface{}{
		"totalRequests":   total,
		"successRequests": succeeded,
		"failedRequests":  failed,
		"totalTokens":     tokens,
		"totalCredits":    credits,
		"uptime":          uptime,
	})
}
// apiResetStats zeroes every in-memory runtime counter and the persisted stats.
func (h *Handler) apiResetStats(w http.ResponseWriter, r *http.Request) {
	counters := []*int64{
		&h.totalRequests,
		&h.successRequests,
		&h.failedRequests,
		&h.totalTokens,
	}
	for _, counter := range counters {
		atomic.StoreInt64(counter, 0)
	}

	// totalCredits is a float64 and is guarded by creditsMu rather than atomics.
	h.creditsMu.Lock()
	h.totalCredits = 0
	h.creditsMu.Unlock()

	config.UpdateStats(0, 0, 0, 0, 0)
	json.NewEncoder(w).Encode(map[string]bool{"success": true})
}
// apiGenerateMachineId generates a fresh machine ID and returns it as
// {"machineId": ...}. Nothing is persisted.
func (h *Handler) apiGenerateMachineId(w http.ResponseWriter, r *http.Request) {
	json.NewEncoder(w).Encode(map[string]string{
		"machineId": config.GenerateMachineId(),
	})
}
// apiRefreshAccount refreshes an account's remote info (usage, subscription,
// etc.) and saves it with config.UpdateAccountInfo.
//
// Flow:
//  1. Look up the account by id in config.GetAccounts(); respond 404 if absent.
//  2. If ExpiresAt is set and the token expires within 300 seconds, refresh the
//     token first; a refresh failure responds 500.
//  3. Call RefreshAccountInfo. An error whose text contains
//     "TEMPORARILY_SUSPENDED" or "Account suspended" is answered with
//     {"success": true, "message": "Account status updated"}. An error whose
//     text contains "403", "401", "invalid" or "expired" triggers one token
//     refresh followed by one retry.
//  4. Any remaining error responds 500; otherwise the info is saved and
//     returned as {"success": true, "info": ...}.
func (h *Handler) apiRefreshAccount(w http.ResponseWriter, r *http.Request, id string) {
	accounts := config.GetAccounts()
	var account *config.Account
	for i := range accounts {
		if accounts[i].ID == id {
			account = &accounts[i]
			break
		}
	}
	if account == nil {
		w.WriteHeader(404)
		json.NewEncoder(w).Encode(map[string]string{"error": "Account not found"})
		return
	}
	// refreshTokenIfNeeded refreshes the token regardless of its expiry, to make
	// sure it is valid. Despite the name it always refreshes when a refresh token
	// exists; it returns nil without doing anything when there is none.
	// The new tokens are written to the local copy, the config store and the pool.
	refreshTokenIfNeeded := func() error {
		if account.RefreshToken == "" {
			return nil
		}
		newAccessToken, newRefreshToken, newExpiresAt, err := auth.RefreshToken(account)
		if err != nil {
			return err
		}
		account.AccessToken = newAccessToken
		// Keep the old refresh token if the server did not rotate it.
		if newRefreshToken != "" {
			account.RefreshToken = newRefreshToken
		}
		account.ExpiresAt = newExpiresAt
		config.UpdateAccountToken(id, newAccessToken, newRefreshToken, newExpiresAt)
		h.pool.UpdateToken(id, newAccessToken, newRefreshToken, newExpiresAt)
		return nil
	}
	// If the token is about to expire (within 300s), refresh it first.
	if account.ExpiresAt > 0 && time.Now().Unix() > account.ExpiresAt-300 {
		if err := refreshTokenIfNeeded(); err != nil {
			w.WriteHeader(500)
			json.NewEncoder(w).Encode(map[string]string{"error": "Token refresh failed: " + err.Error()})
			return
		}
	}
	// Fetch the account info.
	info, err := RefreshAccountInfo(account)
	if err != nil {
		// Check whether this is a suspension-related error.
		errMsg := err.Error()
		if strings.Contains(errMsg, "TEMPORARILY_SUSPENDED") || strings.Contains(errMsg, "Account suspended") {
			// Per the original author, the suspension state is already recorded
			// inside RefreshAccountInfo, so report success without an error.
			json.NewEncoder(w).Encode(map[string]interface{}{
				"success": true,
				"message": "Account status updated",
			})
			return
		}
		// 403/401-like errors suggest an invalid token: refresh it and retry once.
		// This is substring matching on the error text, so "invalid"/"expired"
		// anywhere in the message also qualify.
		if strings.Contains(errMsg, "403") || strings.Contains(errMsg, "401") || strings.Contains(errMsg, "invalid") || strings.Contains(errMsg, "expired") {
			// If the refresh itself fails, refreshErr is discarded and the
			// original RefreshAccountInfo error is reported below.
			if refreshErr := refreshTokenIfNeeded(); refreshErr == nil {
				// Retry.
				info, err = RefreshAccountInfo(account)
				if err != nil {
					// Still failing after the retry: check for a suspension again.
					if strings.Contains(err.Error(), "TEMPORARILY_SUSPENDED") || strings.Contains(err.Error(), "Account suspended") {
						json.NewEncoder(w).Encode(map[string]interface{}{
							"success": true,
							"message": "Account status updated",
						})
						return
					}
				}
			}
		}
		// Only the remaining errors are surfaced to the caller.
		if err != nil {
			w.WriteHeader(500)
			json.NewEncoder(w).Encode(map[string]string{"error": err.Error()})
			return
		}
	}
	// Save to the config store.
	if err := config.UpdateAccountInfo(id, *info); err != nil {
		w.WriteHeader(500)
		json.NewEncoder(w).Encode(map[string]string{"error": err.Error()})
		return
	}
	json.NewEncoder(w).Encode(map[string]interface{}{
		"success": true,
		"info":    info,
	})
}
// apiGetAccountFull returns every stored field of one account, including
// sensitive credentials (access/refresh tokens, client secret).
// Runtime usage stats (requests, errors, tokens, credits, last use) are taken
// from the pool's copy of the account. They are zero if the pool does not hold it.
// Responds 404 when the id is unknown.
func (h *Handler) apiGetAccountFull(w http.ResponseWriter, r *http.Request, id string) {
	stored := config.GetAccounts()
	live := h.pool.GetAllAccounts()

	match := -1
	for i := range stored {
		if stored[i].ID == id {
			match = i
			break
		}
	}
	if match < 0 {
		w.WriteHeader(404)
		json.NewEncoder(w).Encode(map[string]string{"error": "Account not found"})
		return
	}
	acc := &stored[match]

	var runtime config.Account
	for _, candidate := range live {
		if candidate.ID != id {
			continue
		}
		runtime = candidate
		break
	}

	// Identity, credentials, status and subscription/usage fields from config.
	payload := map[string]interface{}{
		"id":                acc.ID,
		"email":             acc.Email,
		"userId":            acc.UserId,
		"nickname":          acc.Nickname,
		"accessToken":       acc.AccessToken,
		"refreshToken":      acc.RefreshToken,
		"clientId":          acc.ClientID,
		"clientSecret":      acc.ClientSecret,
		"authMethod":        acc.AuthMethod,
		"provider":          acc.Provider,
		"region":            acc.Region,
		"expiresAt":         acc.ExpiresAt,
		"machineId":         acc.MachineId,
		"enabled":           acc.Enabled,
		"banStatus":         acc.BanStatus,
		"banReason":         acc.BanReason,
		"banTime":           acc.BanTime,
		"subscriptionType":  acc.SubscriptionType,
		"subscriptionTitle": acc.SubscriptionTitle,
		"daysRemaining":     acc.DaysRemaining,
		"usageCurrent":      acc.UsageCurrent,
		"usageLimit":        acc.UsageLimit,
		"usagePercent":      acc.UsagePercent,
		"nextResetDate":     acc.NextResetDate,
		"lastRefresh":       acc.LastRefresh,
		"trialUsageCurrent": acc.TrialUsageCurrent,
		"trialUsageLimit":   acc.TrialUsageLimit,
		"trialUsagePercent": acc.TrialUsagePercent,
		"trialStatus":       acc.TrialStatus,
		"trialExpiresAt":    acc.TrialExpiresAt,
	}
	// Runtime counters come from the pool's view of the account.
	payload["requestCount"] = runtime.RequestCount
	payload["errorCount"] = runtime.ErrorCount
	payload["totalTokens"] = runtime.TotalTokens
	payload["totalCredits"] = runtime.TotalCredits
	payload["lastUsed"] = runtime.LastUsed

	json.NewEncoder(w).Encode(payload)
}
// apiGetAccountModels lists the models available to the account with the given
// id. It responds 404 for an unknown id and 500 if the model listing fails.
func (h *Handler) apiGetAccountModels(w http.ResponseWriter, r *http.Request, id string) {
	var target *config.Account
	all := config.GetAccounts()
	for i := range all {
		if all[i].ID != id {
			continue
		}
		target = &all[i]
		break
	}

	out := json.NewEncoder(w)
	if target == nil {
		w.WriteHeader(404)
		out.Encode(map[string]string{"error": "Account not found"})
		return
	}

	models, listErr := ListAvailableModels(target)
	if listErr != nil {
		w.WriteHeader(500)
		out.Encode(map[string]string{"error": listErr.Error()})
		return
	}
	out.Encode(map[string]interface{}{
		"success": true,
		"models":  models,
	})
}
// ==================== Static file serving ====================

// serveAdminPage serves the admin UI entry page from web/index.html
// (relative to the process working directory).
func (h *Handler) serveAdminPage(w http.ResponseWriter, r *http.Request) {
	const indexPath = "web/index.html"
	http.ServeFile(w, r, indexPath)
}
// serveStaticFile maps /admin/<path> to web/<path>.
// http.ServeFile itself rejects request paths containing a ".." element, which
// guards against directory traversal.
func (h *Handler) serveStaticFile(w http.ResponseWriter, r *http.Request) {
	http.ServeFile(w, r, "web/"+strings.TrimPrefix(r.URL.Path, "/admin/"))
}
// apiGetThinkingConfig returns the thinking configuration: the model-name
// suffix and the output format used for OpenAI- and Claude-style responses.
func (h *Handler) apiGetThinkingConfig(w http.ResponseWriter, r *http.Request) {
	current := config.GetThinkingConfig()
	view := map[string]interface{}{}
	view["suffix"] = current.Suffix
	view["openaiFormat"] = current.OpenAIFormat
	view["claudeFormat"] = current.ClaudeFormat
	json.NewEncoder(w).Encode(view)
}
// apiUpdateThinkingConfig updates the thinking configuration.
// An empty format field is passed through to config unchanged; a non-empty one
// must be "reasoning_content", "thinking" or "think", otherwise the request
// responds 400.
func (h *Handler) apiUpdateThinkingConfig(w http.ResponseWriter, r *http.Request) {
	var body struct {
		Suffix       string `json:"suffix"`
		OpenAIFormat string `json:"openaiFormat"`
		ClaudeFormat string `json:"claudeFormat"`
	}
	if decodeErr := json.NewDecoder(r.Body).Decode(&body); decodeErr != nil {
		w.WriteHeader(400)
		json.NewEncoder(w).Encode(map[string]string{"error": "Invalid JSON"})
		return
	}

	// acceptable reports whether f is empty or one of the supported formats.
	acceptable := func(f string) bool {
		switch f {
		case "", "reasoning_content", "thinking", "think":
			return true
		}
		return false
	}
	if !acceptable(body.OpenAIFormat) {
		w.WriteHeader(400)
		json.NewEncoder(w).Encode(map[string]string{"error": "Invalid openaiFormat, must be: reasoning_content, thinking, or think"})
		return
	}
	if !acceptable(body.ClaudeFormat) {
		w.WriteHeader(400)
		json.NewEncoder(w).Encode(map[string]string{"error": "Invalid claudeFormat, must be: reasoning_content, thinking, or think"})
		return
	}

	if saveErr := config.UpdateThinkingConfig(body.Suffix, body.OpenAIFormat, body.ClaudeFormat); saveErr != nil {
		w.WriteHeader(500)
		json.NewEncoder(w).Encode(map[string]string{"error": saveErr.Error()})
		return
	}
	json.NewEncoder(w).Encode(map[string]bool{"success": true})
}
// apiGetEndpointConfig returns the preferred upstream endpoint setting.
func (h *Handler) apiGetEndpointConfig(w http.ResponseWriter, r *http.Request) {
	preferred := config.GetPreferredEndpoint()
	json.NewEncoder(w).Encode(map[string]string{"preferredEndpoint": preferred})
}
// apiUpdateEndpointConfig sets the preferred upstream endpoint.
// Accepted values are "auto", "codewhisperer" and "amazonq"; anything else
// (including an empty value) responds 400.
func (h *Handler) apiUpdateEndpointConfig(w http.ResponseWriter, r *http.Request) {
	var body struct {
		PreferredEndpoint string `json:"preferredEndpoint"`
	}
	if decodeErr := json.NewDecoder(r.Body).Decode(&body); decodeErr != nil {
		w.WriteHeader(400)
		json.NewEncoder(w).Encode(map[string]string{"error": "Invalid JSON"})
		return
	}

	switch body.PreferredEndpoint {
	case "auto", "codewhisperer", "amazonq":
		// supported
	default:
		w.WriteHeader(400)
		json.NewEncoder(w).Encode(map[string]string{"error": "Invalid endpoint, must be: auto, codewhisperer, or amazonq"})
		return
	}

	if saveErr := config.UpdatePreferredEndpoint(body.PreferredEndpoint); saveErr != nil {
		w.WriteHeader(500)
		json.NewEncoder(w).Encode(map[string]string{"error": saveErr.Error()})
		return
	}
	json.NewEncoder(w).Encode(map[string]bool{"success": true})
}
// applyProxyConfig applies the proxy setting to every outbound HTTP client:
// the Kiro API client (InitKiroHttpClient) and the auth module's client
// (auth.InitHttpClient).
// apiUpdateProxy calls this with an empty proxyURL to clear the proxy.
// Presumably "" means a direct connection; confirm in both Init functions.
func applyProxyConfig(proxyURL string) {
	InitKiroHttpClient(proxyURL)
	auth.InitHttpClient(proxyURL)
}
// apiGetProxy returns the currently configured outbound proxy URL
// (an empty string when none is set).
func (h *Handler) apiGetProxy(w http.ResponseWriter, r *http.Request) {
	current := config.GetProxyURL()
	json.NewEncoder(w).Encode(map[string]string{"proxyURL": current})
}
// apiUpdateProxy saves the outbound proxy URL and applies it immediately to all
// outbound HTTP clients.
// An empty proxyURL is accepted. A non-empty one must begin with http://,
// https://, socks5:// or socks5h://, otherwise the request responds 400.
func (h *Handler) apiUpdateProxy(w http.ResponseWriter, r *http.Request) {
	var body struct {
		ProxyURL string `json:"proxyURL"`
	}
	if decodeErr := json.NewDecoder(r.Body).Decode(&body); decodeErr != nil {
		w.WriteHeader(400)
		json.NewEncoder(w).Encode(map[string]string{"error": "Invalid JSON"})
		return
	}

	if body.ProxyURL != "" {
		supported := false
		for _, scheme := range []string{"http://", "https://", "socks5://", "socks5h://"} {
			if strings.HasPrefix(body.ProxyURL, scheme) {
				supported = true
				break
			}
		}
		if !supported {
			w.WriteHeader(400)
			json.NewEncoder(w).Encode(map[string]string{"error": "proxyURL must start with http://, https://, socks5://, or socks5h://"})
			return
		}
	}

	if saveErr := config.UpdateProxySettings(body.ProxyURL); saveErr != nil {
		w.WriteHeader(500)
		json.NewEncoder(w).Encode(map[string]string{"error": saveErr.Error()})
		return
	}

	// Take effect right away rather than on next restart.
	applyProxyConfig(body.ProxyURL)
	json.NewEncoder(w).Encode(map[string]bool{"success": true})
}
// apiGetGeneralConfig returns the general retry/timeout settings.
func (h *Handler) apiGetGeneralConfig(w http.ResponseWriter, r *http.Request) {
	invalidModelRetries := config.GetInvalidModelRetries()
	firstByteTimeout := config.GetFirstByteTimeoutSec()
	firstByteRetries := config.GetFirstByteRetries()

	json.NewEncoder(w).Encode(map[string]interface{}{
		"invalidModelRetries": invalidModelRetries,
		"firstByteTimeoutSec": firstByteTimeout,
		"firstByteRetries":    firstByteRetries,
	})
}
// apiUpdateGeneralConfig updates the general retry/timeout settings.
//
// Every field in the JSON body is optional; an omitted field leaves that setting
// unchanged. Accepted ranges (inclusive):
//   - invalidModelRetries: 0-20
//   - firstByteTimeoutSec: 0-300
//   - firstByteRetries:    0-10
//
// All supplied values are range-checked before anything is persisted. A request
// with one out-of-range value is therefore rejected with 400 as a whole, instead
// of leaving the fields checked before it already saved. A persistence failure
// from config still responds 500. Any settings saved before that failure remain
// in place.
func (h *Handler) apiUpdateGeneralConfig(w http.ResponseWriter, r *http.Request) {
	var req struct {
		InvalidModelRetries *int `json:"invalidModelRetries"`
		FirstByteTimeoutSec *int `json:"firstByteTimeoutSec"`
		FirstByteRetries    *int `json:"firstByteRetries"`
	}
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		w.WriteHeader(400)
		json.NewEncoder(w).Encode(map[string]string{"error": "Invalid JSON"})
		return
	}

	// Each entry: requested value (nil = unchanged), inclusive upper bound
	// (the lower bound is always 0), validation message, persistence call.
	settings := []struct {
		value  *int
		maxVal int
		errMsg string
		save   func(int) error
	}{
		{
			value:  req.InvalidModelRetries,
			maxVal: 20,
			errMsg: "invalidModelRetries must be 0-20",
			save:   func(n int) error { return config.UpdateInvalidModelRetries(n) },
		},
		{
			value:  req.FirstByteTimeoutSec,
			maxVal: 300,
			errMsg: "firstByteTimeoutSec must be 0-300",
			save:   func(n int) error { return config.UpdateFirstByteTimeoutSec(n) },
		},
		{
			value:  req.FirstByteRetries,
			maxVal: 10,
			errMsg: "firstByteRetries must be 0-10",
			save:   func(n int) error { return config.UpdateFirstByteRetries(n) },
		},
	}

	// Pass 1: validate everything before touching persisted state.
	for _, s := range settings {
		if s.value != nil && (*s.value < 0 || *s.value > s.maxVal) {
			w.WriteHeader(400)
			json.NewEncoder(w).Encode(map[string]string{"error": s.errMsg})
			return
		}
	}

	// Pass 2: persist only the supplied values.
	for _, s := range settings {
		if s.value == nil {
			continue
		}
		if err := s.save(*s.value); err != nil {
			w.WriteHeader(500)
			json.NewEncoder(w).Encode(map[string]string{"error": err.Error()})
			return
		}
	}
	json.NewEncoder(w).Encode(map[string]bool{"success": true})
}
// apiGetVersion reports the build version string from config.Version.
func (h *Handler) apiGetVersion(w http.ResponseWriter, r *http.Request) {
	resp := map[string]string{}
	resp["version"] = config.Version
	json.NewEncoder(w).Encode(resp)
}
// apiExportAccounts exports account credentials in a format compatible with
// Kiro Account Manager.
//
// The optional JSON body {"ids": [...]} restricts the export to those account
// IDs. An empty, missing or unparseable body exports every account. All
// timestamps in the output are Unix milliseconds. Credentials.expiresAt is
// converted from the stored seconds.
//
// A single time snapshot is taken before the loop, so exportedAt and every
// account's lastUpdated/createdAt/lastUsedAt are identical within one export
// instead of drifting per call.
//
// NOTE(review): createdAt/lastUsedAt are not taken from stored data. They are
// filled with the export time, and status is always "active" regardless of
// Enabled/BanStatus. Confirm this is what the importer expects.
func (h *Handler) apiExportAccounts(w http.ResponseWriter, r *http.Request) {
	var req struct {
		IDs []string `json:"ids"` // empty means export all
	}
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		// Empty or malformed body: fall back to exporting everything.
		req.IDs = nil
	}
	accounts := config.GetAccounts()
	// When IDs are given, keep only the matching accounts.
	if len(req.IDs) > 0 {
		wanted := make(map[string]struct{}, len(req.IDs))
		for _, id := range req.IDs {
			wanted[id] = struct{}{}
		}
		var filtered []config.Account
		for _, a := range accounts {
			if _, ok := wanted[a.ID]; ok {
				filtered = append(filtered, a)
			}
		}
		accounts = filtered
	}
	// Export schema compatible with Kiro Account Manager.
	type ExportCredentials struct {
		AccessToken  string `json:"accessToken"`
		CsrfToken    string `json:"csrfToken"`
		RefreshToken string `json:"refreshToken"`
		ClientID     string `json:"clientId,omitempty"`
		ClientSecret string `json:"clientSecret,omitempty"`
		Region       string `json:"region,omitempty"`
		ExpiresAt    int64  `json:"expiresAt"`
		AuthMethod   string `json:"authMethod,omitempty"`
		Provider     string `json:"provider,omitempty"`
	}
	type ExportSubscription struct {
		Type  string `json:"type"`
		Title string `json:"title,omitempty"`
	}
	type ExportUsage struct {
		Current     float64 `json:"current"`
		Limit       float64 `json:"limit"`
		PercentUsed float64 `json:"percentUsed"`
		LastUpdated int64   `json:"lastUpdated"`
	}
	type ExportAccount struct {
		ID           string             `json:"id"`
		Email        string             `json:"email"`
		Nickname     string             `json:"nickname,omitempty"`
		Idp          string             `json:"idp"`
		UserId       string             `json:"userId,omitempty"`
		MachineId    string             `json:"machineId,omitempty"`
		Credentials  ExportCredentials  `json:"credentials"`
		Subscription ExportSubscription `json:"subscription"`
		Usage        ExportUsage        `json:"usage"`
		Tags         []string           `json:"tags"`
		Status       string             `json:"status"`
		CreatedAt    int64              `json:"createdAt"`
		LastUsedAt   int64              `json:"lastUsedAt"`
	}
	type ExportData struct {
		Version    string          `json:"version"`
		ExportedAt int64           `json:"exportedAt"`
		Accounts   []ExportAccount `json:"accounts"`
		Groups     []interface{}   `json:"groups"`
		Tags       []interface{}   `json:"tags"`
	}

	// One snapshot for the whole export (loop-invariant).
	now := time.Now().UnixMilli()

	exportAccounts := make([]ExportAccount, 0, len(accounts))
	for _, a := range accounts {
		// Map provider to idp; without a provider, infer it from the auth method.
		idp := a.Provider
		if idp == "" {
			if a.AuthMethod == "social" {
				idp = "Google"
			} else {
				idp = "BuilderId"
			}
		}
		// The importer spells the IdC auth method with this casing.
		authMethod := a.AuthMethod
		if authMethod == "idc" {
			authMethod = "IdC"
		}
		// Map the subscription type. Order matters: PRO_PLUS/PROPLUS is checked
		// before the plain PRO substring, and POWER only applies when neither matched.
		subType := "Free"
		switch rawType := strings.ToUpper(a.SubscriptionType); {
		case strings.Contains(rawType, "PRO_PLUS") || strings.Contains(rawType, "PROPLUS"):
			subType = "Pro_Plus"
		case strings.Contains(rawType, "PRO"):
			subType = "Pro"
		case strings.Contains(rawType, "POWER"):
			subType = "Pro_Plus"
		}
		exportAccounts = append(exportAccounts, ExportAccount{
			ID:        a.ID,
			Email:     a.Email,
			Nickname:  a.Nickname,
			Idp:       idp,
			UserId:    a.UserId,
			MachineId: a.MachineId,
			Credentials: ExportCredentials{
				AccessToken:  a.AccessToken,
				CsrfToken:    "",
				RefreshToken: a.RefreshToken,
				ClientID:     a.ClientID,
				ClientSecret: a.ClientSecret,
				Region:       a.Region,
				ExpiresAt:    a.ExpiresAt * 1000, // seconds -> milliseconds
				AuthMethod:   authMethod,
				Provider:     a.Provider,
			},
			Subscription: ExportSubscription{
				Type:  subType,
				Title: a.SubscriptionTitle,
			},
			Usage: ExportUsage{
				Current:     a.UsageCurrent,
				Limit:       a.UsageLimit,
				PercentUsed: a.UsagePercent,
				LastUpdated: now,
			},
			Tags:       []string{},
			Status:     "active",
			CreatedAt:  now,
			LastUsedAt: now,
		})
	}
	data := ExportData{
		Version:    config.Version,
		ExportedAt: now,
		Accounts:   exportAccounts,
		Groups:     []interface{}{},
		Tags:       []interface{}{},
	}
	json.NewEncoder(w).Encode(data)
}