package proxy
import (
	"crypto/subtle"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"kiro-go/auth"
	"kiro-go/config"
	"kiro-go/pool"

	"github.com/google/uuid"
)
// Handler is the HTTP handler of the proxy. It owns the account pool, the
// runtime statistics counters, the cached model list and the prompt cache
// tracker.
type Handler struct {
pool *pool.AccountPool
// Runtime statistics (read and written with atomic operations).
totalRequests int64
successRequests int64
failedRequests int64
totalTokens int64
totalCredits float64 // float64 cannot use the int64 atomics; guarded by creditsMu
creditsMu sync.RWMutex
startTime int64 // Unix seconds at which the handler was created
stopRefresh chan struct{} // signals backgroundRefresh to exit
stopStatsSaver chan struct{} // signals backgroundStatsSaver to exit
// Model list cache; cachedModels is accessed under modelsCacheMu.
cachedModels []ModelInfo
modelsCacheMu sync.RWMutex
modelsCacheTime int64 // Unix seconds of the last successful cache refresh
promptCache *promptCacheTracker
}
// thinkingStreamSource identifies which channel supplies thinking output for
// a single streamed response. allowReasoningSource and allowTagSource use it
// so that the first source to appear is the only one accepted afterwards.
type thinkingStreamSource int
const (
// thinkingSourceUnknown means no thinking source has been chosen yet.
thinkingSourceUnknown thinkingStreamSource = iota
// thinkingSourceReasoningEvent means thinking arrives as dedicated reasoning events.
thinkingSourceReasoningEvent
// thinkingSourceTagBlock means thinking arrives as a tagged block inside the text.
thinkingSourceTagBlock
)
// allowReasoningSource reports whether thinking delivered through reasoning
// events may be emitted. It is refused once the stream has been claimed by a
// tag block; otherwise the stream is (re)claimed for reasoning events.
func allowReasoningSource(source *thinkingStreamSource) bool {
	switch *source {
	case thinkingSourceTagBlock:
		return false
	default:
		*source = thinkingSourceReasoningEvent
		return true
	}
}
// allowTagSource reports whether thinking delivered as an inline tag block may
// be emitted. An unclaimed stream is claimed for tag blocks; a stream already
// claimed by reasoning events (or holding any other value) is refused.
func allowTagSource(source *thinkingStreamSource) bool {
	switch *source {
	case thinkingSourceUnknown:
		*source = thinkingSourceTagBlock
		return true
	case thinkingSourceTagBlock:
		return true
	default:
		return false
	}
}
// validateClaudeRequestShape checks the structural requirements of a Claude
// request and returns a human-readable problem description, or "" when the
// request is acceptable.
//
// It requires a non-empty message list, a valid thinking configuration, a
// final role-bearing message that is not "assistant", and at least one user
// message carrying text, an image or a tool result.
func validateClaudeRequestShape(req *ClaudeRequest) string {
	if len(req.Messages) == 0 {
		return "messages must not be empty"
	}
	if problem := validateClaudeThinkingConfig(req.Thinking, req.MaxTokens); problem != "" {
		return problem
	}

	var (
		finalRole    string
		sawUserInput bool
	)
	for _, m := range req.Messages {
		role := strings.TrimSpace(m.Role)
		if role == "" {
			// Messages without a role are ignored entirely.
			continue
		}
		finalRole = role
		if role == "user" {
			text, images, toolResults := extractClaudeUserContent(m.Content)
			if normalizeUserContent(text, len(images) > 0) != "" || len(toolResults) > 0 {
				sawUserInput = true
			}
		}
	}

	switch {
	case finalRole == "assistant":
		return "assistant-prefill final message is not supported; last message must be user"
	case !sawUserInput:
		return "at least one non-empty user message is required"
	}
	return ""
}
// validateClaudeThinkingConfig validates the optional "thinking" section of a
// Claude request against max_tokens. It returns a description of the first
// problem found, or "" when the configuration is nil or valid.
//
// Type and display values are matched case-insensitively after trimming
// surrounding whitespace.
func validateClaudeThinkingConfig(thinking *ClaudeThinkingConfig, maxTokens int) string {
	if thinking == nil {
		return ""
	}

	mode := strings.ToLower(strings.TrimSpace(thinking.Type))
	budget := thinking.BudgetTokens

	switch mode {
	case "enabled":
		switch {
		case maxTokens == 0:
			return "thinking.type enabled cannot be used with max_tokens=0"
		case budget <= 0:
			return "thinking.budget_tokens is required when thinking.type is enabled"
		case budget < 1024:
			return "thinking.budget_tokens must be at least 1024"
		case maxTokens > 0 && budget >= maxTokens:
			return "thinking.budget_tokens must be less than max_tokens"
		}
	case "adaptive":
		if budget != 0 {
			return "thinking.budget_tokens is not supported when thinking.type is adaptive"
		}
	case "disabled":
		if budget != 0 {
			return "thinking.budget_tokens is not supported when thinking.type is disabled"
		}
	default:
		return "thinking.type must be one of: enabled, adaptive, disabled"
	}

	// An unrecognised display value is reported before the
	// "display with disabled thinking" conflict.
	switch strings.ToLower(strings.TrimSpace(thinking.Display)) {
	case "":
		return ""
	case "summarized", "omitted":
		if mode == "disabled" {
			return "thinking.display is not supported when thinking.type is disabled"
		}
		return ""
	default:
		return "thinking.display must be one of: summarized, omitted"
	}
}
// claudeThinkingResponseOptions describes how thinking output is rendered in
// a Claude-format response.
type claudeThinkingResponseOptions struct {
// Format names the rendering style; the response handlers special-case
// "think" and "reasoning_content" and treat anything else as native
// thinking blocks.
Format string
// OmitDisplay is set when the request asked for display "omitted".
OmitDisplay bool
}
// resolveClaudeThinkingResponseOptions derives the thinking rendering options
// for a request. The format starts as defaultFormat ("thinking" when empty).
// A request display of "summarized" or "omitted" forces the "thinking" format,
// and "omitted" additionally sets OmitDisplay.
func resolveClaudeThinkingResponseOptions(thinking *ClaudeThinkingConfig, defaultFormat string) claudeThinkingResponseOptions {
	format := defaultFormat
	if format == "" {
		format = "thinking"
	}

	omit := false
	if thinking != nil {
		requested := strings.ToLower(strings.TrimSpace(thinking.Display))
		if requested == "summarized" || requested == "omitted" {
			format = "thinking"
			omit = requested == "omitted"
		}
	}
	return claudeThinkingResponseOptions{Format: format, OmitDisplay: omit}
}
// validateOpenAIRequestShape checks the structural requirements of an
// OpenAI-style chat request and returns a problem description, or "" when the
// request is acceptable.
//
// It requires at least one non-system message, a final non-system message
// whose role is not "assistant", and at least one user message with
// non-empty content.
func validateOpenAIRequestShape(req *OpenAIRequest) string {
	if len(req.Messages) == 0 {
		return "messages must not be empty"
	}

	var (
		sawNonSystem bool
		sawUserInput bool
		finalRole    string
	)
	for _, m := range req.Messages {
		role := strings.TrimSpace(m.Role)
		switch role {
		case "":
			// Messages without a role are ignored entirely.
			continue
		case "system":
			// System messages never become the "final role".
		default:
			sawNonSystem = true
			finalRole = role
		}
		if role == "user" {
			text, images := extractOpenAIUserContent(m.Content)
			if normalizeUserContent(text, len(images) > 0) != "" {
				sawUserInput = true
			}
		}
	}

	switch {
	case !sawNonSystem:
		return "at least one non-system message is required"
	case finalRole == "assistant":
		return "assistant-prefill final message is not supported; last message must be user or tool"
	case !sawUserInput:
		return "at least one non-empty user message is required"
	}
	return ""
}
// NewHandler creates a Handler seeded with the persisted statistics and starts
// its two background goroutines (periodic account/model refresh and periodic
// statistics saving).
func NewHandler() *Handler {
	// Apply the proxy configuration at startup.
	applyProxyConfig(config.GetProxyURL())

	savedTotal, savedSuccess, savedFailed, savedTokens, savedCredits := config.GetStats()

	h := new(Handler)
	h.pool = pool.GetPool()
	h.totalRequests = int64(savedTotal)
	h.successRequests = int64(savedSuccess)
	h.failedRequests = int64(savedFailed)
	h.totalTokens = int64(savedTokens)
	h.totalCredits = savedCredits
	h.startTime = time.Now().Unix()
	h.stopRefresh = make(chan struct{})
	h.stopStatsSaver = make(chan struct{})
	h.promptCache = newPromptCacheTracker(defaultPromptCacheTTL)

	// Background account/model refresh.
	go h.backgroundRefresh()
	// Background statistics persistence (every 30 seconds).
	go h.backgroundStatsSaver()

	return h
}
// backgroundRefresh periodically refreshes the model cache and all account
// information. It performs a first refresh 10 seconds after start and then
// one every 30 minutes, until a signal arrives on h.stopRefresh.
//
// The startup delay waits on a timer inside a select instead of time.Sleep,
// so a stop signal is honored during the delay as well.
func (h *Handler) backgroundRefresh() {
	ticker := time.NewTicker(30 * time.Minute) // refresh every 30 minutes
	defer ticker.Stop()

	// Delay the first refresh by 10 seconds, but stay responsive to stop.
	startup := time.NewTimer(10 * time.Second)
	defer startup.Stop()
	select {
	case <-startup.C:
	case <-h.stopRefresh:
		return
	}

	h.refreshModelsCache()
	h.refreshAllAccounts()

	for {
		select {
		case <-ticker.C:
			h.refreshModelsCache()
			h.refreshAllAccounts()
		case <-h.stopRefresh:
			return
		}
	}
}
// refreshAllAccounts walks every configured account that is enabled and has an
// access token. It renews tokens that expire within five minutes, refreshes
// the account's subscription/usage information, and finally reloads the pool.
// A failure on one account is logged and does not stop the others.
func (h *Handler) refreshAllAccounts() {
	accounts := config.GetAccounts()
	for idx := range accounts {
		acct := &accounts[idx]
		if acct.AccessToken == "" || !acct.Enabled {
			continue
		}

		// Renew the token when it is within 300 seconds of expiry.
		nearExpiry := acct.ExpiresAt > 0 && time.Now().Unix() > acct.ExpiresAt-300
		if nearExpiry {
			access, refresh, expiresAt, refreshErr := auth.RefreshToken(acct)
			if refreshErr != nil {
				fmt.Printf("[BackgroundRefresh] Token refresh failed for %s: %v\n", acct.Email, refreshErr)
				continue
			}
			acct.AccessToken = access
			if refresh != "" {
				acct.RefreshToken = refresh
			}
			acct.ExpiresAt = expiresAt
			config.UpdateAccountToken(acct.ID, access, refresh, expiresAt)
			h.pool.UpdateToken(acct.ID, access, refresh, expiresAt)
		}

		// Refresh subscription / usage information.
		info, infoErr := RefreshAccountInfo(acct)
		if infoErr != nil {
			fmt.Printf("[BackgroundRefresh] Failed to refresh %s: %v\n", acct.Email, infoErr)
			continue
		}
		config.UpdateAccountInfo(acct.ID, *info)
		fmt.Printf("[BackgroundRefresh] Refreshed %s: %s %.1f/%.1f\n", acct.Email, info.SubscriptionType, info.UsageCurrent, info.UsageLimit)
	}
	h.pool.Reload()
}
// validateApiKey reports whether the request carries the configured API key.
//
// Every request is accepted when API keys are not required or no key is
// configured. Otherwise the key is taken from an "Authorization: Bearer ..."
// header, falling back to "X-Api-Key" only when no Bearer credential is
// present. The comparison uses subtle.ConstantTimeCompare so that response
// timing does not reveal how much of a guessed key matched.
func (h *Handler) validateApiKey(r *http.Request) bool {
	if !config.IsApiKeyRequired() {
		return true
	}
	expectedKey := config.GetApiKey()
	if expectedKey == "" {
		return true
	}

	// Read the key from the Authorization header or the X-Api-Key header.
	var providedKey string
	if authHeader := r.Header.Get("Authorization"); strings.HasPrefix(authHeader, "Bearer ") {
		providedKey = strings.TrimPrefix(authHeader, "Bearer ")
	} else if apiKeyHeader := r.Header.Get("X-Api-Key"); apiKeyHeader != "" {
		providedKey = apiKeyHeader
	}

	return subtle.ConstantTimeCompare([]byte(providedKey), []byte(expectedKey)) == 1
}
// ServeHTTP sets the CORS headers, answers OPTIONS preflight requests with
// 204, and dispatches every other request by URL path. The messages,
// count_tokens, chat-completions and stats endpoints require a valid API key;
// unknown paths receive 404.
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	// CORS: full header support.
	hdr := w.Header()
	hdr.Set("Access-Control-Allow-Origin", "*")
	hdr.Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
	hdr.Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-Api-Key, anthropic-version, anthropic-beta, x-api-key, x-stainless-os, x-stainless-lang, x-stainless-package-version, x-stainless-runtime, x-stainless-runtime-version, x-stainless-arch")
	hdr.Set("Access-Control-Expose-Headers", "x-request-id, x-ratelimit-limit-requests, x-ratelimit-limit-tokens, x-ratelimit-remaining-requests, x-ratelimit-remaining-tokens, x-ratelimit-reset-requests, x-ratelimit-reset-tokens")

	if r.Method == "OPTIONS" {
		w.WriteHeader(204)
		return
	}

	const authFailure = "Invalid or missing API key"

	switch path := r.URL.Path; path {
	// API endpoints (API key required).
	case "/v1/messages", "/messages", "/anthropic/v1/messages":
		if h.validateApiKey(r) {
			h.handleClaudeMessages(w, r)
			return
		}
		h.sendClaudeError(w, 401, "authentication_error", authFailure)

	case "/v1/messages/count_tokens", "/messages/count_tokens":
		if h.validateApiKey(r) {
			h.handleCountTokens(w, r)
			return
		}
		h.sendClaudeError(w, 401, "authentication_error", authFailure)

	case "/v1/chat/completions", "/chat/completions":
		if h.validateApiKey(r) {
			h.handleOpenAIChat(w, r)
			return
		}
		h.sendOpenAIError(w, 401, "authentication_error", authFailure)

	case "/v1/models", "/models":
		h.handleModels(w, r)

	case "/api/event_logging/batch":
		// Claude Code telemetry endpoint: acknowledge with 200 OK.
		w.Header().Set("Content-Type", "application/json; charset=utf-8")
		w.Write([]byte(`{"status":"ok"}`))

	// Admin page.
	case "/admin", "/admin/":
		h.serveAdminPage(w, r)

	// Health check.
	case "/health", "/":
		h.handleHealth(w, r)

	// Statistics endpoint (API key required).
	case "/v1/stats":
		if h.validateApiKey(r) {
			h.handleStats(w, r)
			return
		}
		w.Header().Set("Content-Type", "application/json; charset=utf-8")
		w.WriteHeader(401)
		json.NewEncoder(w).Encode(map[string]string{"error": authFailure})

	default:
		// Prefix routes; none of the exact paths above begins with
		// "/admin/" other than "/admin/" itself, so checking them last
		// does not change which handler is chosen.
		switch {
		case strings.HasPrefix(path, "/admin/api/"):
			h.handleAdminAPI(w, r)
		case strings.HasPrefix(path, "/admin/"):
			h.serveStaticFile(w, r)
		default:
			http.Error(w, "Not Found", 404)
		}
	}
}
// handleHealth answers the health check with status, version and uptime in
// seconds. It deliberately exposes no usage statistics.
func (h *Handler) handleHealth(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json; charset=utf-8")

	report := map[string]interface{}{
		"status":  "ok",
		"version": config.Version,
		"uptime":  time.Now().Unix() - h.startTime,
	}
	json.NewEncoder(w).Encode(report)
}
// handleStats writes the runtime statistics (account counts, request and
// token counters, credits, uptime) as JSON. The caller is responsible for
// API-key authentication.
func (h *Handler) handleStats(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json; charset=utf-8")

	report := map[string]interface{}{
		"status":          "ok",
		"version":         config.Version,
		"accounts":        h.pool.Count(),
		"available":       h.pool.AvailableCount(),
		"totalRequests":   atomic.LoadInt64(&h.totalRequests),
		"successRequests": atomic.LoadInt64(&h.successRequests),
		"failedRequests":  atomic.LoadInt64(&h.failedRequests),
		"totalTokens":     atomic.LoadInt64(&h.totalTokens),
		"totalCredits":    h.getCredits(),
		"uptime":          time.Now().Unix() - h.startTime,
	}
	json.NewEncoder(w).Encode(report)
}
// handleModels writes the model list as JSON. It serves the cached upstream
// model list (refreshing it once on demand when the cache is empty), falls
// back to a built-in list when nothing is available, and always appends the
// alias models "auto", "gpt-4o" and "gpt-4".
func (h *Handler) handleModels(w http.ResponseWriter, r *http.Request) {
	snapshot := func() []ModelInfo {
		h.modelsCacheMu.RLock()
		defer h.modelsCacheMu.RUnlock()
		return h.cachedModels
	}

	// Prefer the cached real model list.
	known := snapshot()
	if len(known) == 0 {
		h.refreshModelsCache()
		known = snapshot()
	}

	suffix := config.GetThinkingConfig().Suffix
	list := buildAnthropicModelsResponse(known, suffix)
	if len(list) == 0 {
		list = fallbackAnthropicModels(suffix)
	}

	// Alias models.
	for _, alias := range [...]string{"auto", "gpt-4o", "gpt-4"} {
		list = append(list, buildModelInfo(alias, "kiro-proxy", true))
	}

	w.Header().Set("Content-Type", "application/json; charset=utf-8")
	json.NewEncoder(w).Encode(map[string]interface{}{
		"object": "list",
		"data":   list,
	})
}
// buildAnthropicModelsResponse converts cached upstream models into response
// entries. Each model yields two entries: the model itself and a thinking
// variant whose id has thinkingSuffix appended. It returns nil when cached is
// empty.
func buildAnthropicModelsResponse(cached []ModelInfo, thinkingSuffix string) []map[string]interface{} {
	if len(cached) == 0 {
		return nil
	}

	entries := make([]map[string]interface{}, 0, len(cached)*2)
	for _, model := range cached {
		withImages := modelSupportsImage(model.InputTypes)
		entries = append(entries,
			buildModelInfo(model.ModelId, "anthropic", withImages),
			// Automatically generated thinking variant.
			buildModelInfo(model.ModelId+thinkingSuffix, "anthropic", withImages),
		)
	}
	return entries
}
// fallbackAnthropicModels returns the built-in model list used when no
// upstream list is available. Every base model is followed by its thinking
// variant (id + thinkingSuffix); all entries are marked image-capable.
func fallbackAnthropicModels(thinkingSuffix string) []map[string]interface{} {
	baseIDs := [...]string{
		"claude-sonnet-4.6",
		"claude-opus-4.6",
		"claude-opus-4.7",
		"claude-sonnet-4.5",
		"claude-sonnet-4",
		"claude-haiku-4.5",
		"claude-opus-4.5",
	}

	models := make([]map[string]interface{}, 0, 2*len(baseIDs))
	for _, id := range baseIDs {
		models = append(models,
			buildModelInfo(id, "anthropic", true),
			buildModelInfo(id+thinkingSuffix, "anthropic", true),
		)
	}
	return models
}
// modelSupportsImage reports whether any of the input type names contains
// "image" or "vision" (case-insensitive).
func modelSupportsImage(inputTypes []string) bool {
	for i := range inputTypes {
		kind := strings.ToLower(inputTypes[i])
		switch {
		case strings.Contains(kind, "image"), strings.Contains(kind, "vision"):
			return true
		}
	}
	return false
}
// buildModelInfo builds one entry of the model list response. Image support
// is advertised redundantly under several keys ("supports_image",
// "input_modalities", "modalities", "capabilities", "info.meta.capabilities").
func buildModelInfo(id, ownedBy string, supportsImage bool) map[string]interface{} {
	inputKinds := []string{"text"}
	if supportsImage {
		inputKinds = append(inputKinds, "image")
	}

	topLevelCaps := map[string]bool{
		"vision":       supportsImage,
		"image":        supportsImage,
		"image_vision": supportsImage,
	}
	metaCaps := map[string]bool{
		"vision":       supportsImage,
		"image_vision": supportsImage,
	}

	entry := make(map[string]interface{}, 8)
	entry["id"] = id
	entry["object"] = "model"
	entry["owned_by"] = ownedBy
	entry["supports_image"] = supportsImage
	entry["input_modalities"] = inputKinds
	entry["modalities"] = map[string][]string{
		"input":  inputKinds,
		"output": {"text"},
	}
	entry["capabilities"] = topLevelCaps
	entry["info"] = map[string]interface{}{
		"meta": map[string]interface{}{
			"capabilities": metaCaps,
		},
	}
	return entry
}
// refreshModelsCache fetches the available models through every enabled
// account, merges them into one de-duplicated list and stores it in the
// handler's cache. Accounts whose token cannot be validated or whose listing
// fails are logged and skipped. The cache is left untouched when nothing was
// fetched.
func (h *Handler) refreshModelsCache() {
	enabled := config.GetEnabledAccounts()
	if len(enabled) == 0 {
		return
	}

	collected := make([]ModelInfo, 0)
	for idx := range enabled {
		acct := &enabled[idx]
		if tokenErr := h.ensureValidToken(acct); tokenErr != nil {
			fmt.Printf("[ModelsCache] Skip %s token refresh failed: %v\n", acct.Email, tokenErr)
			continue
		}
		listed, listErr := ListAvailableModels(acct)
		if listErr != nil {
			fmt.Printf("[ModelsCache] Failed to refresh for %s: %v\n", acct.Email, listErr)
			continue
		}
		collected = mergeUniqueModels(collected, listed)
	}

	if len(collected) == 0 {
		return
	}
	h.modelsCacheMu.Lock()
	h.cachedModels = collected
	h.modelsCacheTime = time.Now().Unix()
	h.modelsCacheMu.Unlock()
	fmt.Printf("[ModelsCache] Cached %d models\n", len(collected))
}
// mergeUniqueModels returns existing extended by the models of incoming whose
// id (trimmed, case-insensitive) is not present yet. Models that are already
// present are combined via mergeModelInfo; incoming models with a blank id
// are dropped. existing itself is not modified, and it is returned as-is when
// incoming is empty.
func mergeUniqueModels(existing []ModelInfo, incoming []ModelInfo) []ModelInfo {
	if len(incoming) == 0 {
		return existing
	}

	normalize := func(id string) string {
		return strings.ToLower(strings.TrimSpace(id))
	}

	result := append(make([]ModelInfo, 0, len(existing)), existing...)
	positions := make(map[string]int, len(existing))
	for pos := range result {
		positions[normalize(result[pos].ModelId)] = pos
	}

	for _, candidate := range incoming {
		id := normalize(candidate.ModelId)
		if id == "" {
			continue
		}
		pos, known := positions[id]
		if !known {
			positions[id] = len(result)
			result = append(result, candidate)
			continue
		}
		result[pos] = mergeModelInfo(result[pos], candidate)
	}
	return result
}
// mergeModelInfo returns base with its unset fields (name, description, rate
// multiplier, token limits) filled in from extra, and with the input types of
// both combined.
func mergeModelInfo(base ModelInfo, extra ModelInfo) ModelInfo {
	combined := base
	if combined.ModelName == "" {
		combined.ModelName = extra.ModelName
	}
	if combined.Description == "" {
		combined.Description = extra.Description
	}
	if combined.RateMultiplier == 0 {
		combined.RateMultiplier = extra.RateMultiplier
	}
	if combined.TokenLimits == nil {
		combined.TokenLimits = extra.TokenLimits
	}
	combined.InputTypes = mergeStringLists(combined.InputTypes, extra.InputTypes)
	return combined
}
// mergeStringLists concatenates base and extra, keeping the first occurrence
// of each entry (compared trimmed and case-insensitively) and dropping blank
// entries. When extra is empty, base is returned unchanged, without any
// de-duplication.
func mergeStringLists(base []string, extra []string) []string {
	if len(extra) == 0 {
		return base
	}

	total := len(base) + len(extra)
	seen := make(map[string]bool, total)
	out := make([]string, 0, total)
	for _, list := range [2][]string{base, extra} {
		for _, item := range list {
			key := strings.ToLower(strings.TrimSpace(item))
			if key != "" && !seen[key] {
				seen[key] = true
				out = append(out, item)
			}
		}
	}
	return out
}
// handleCountTokens implements the token counting endpoint (called by Claude
// Code). It accepts POST only, validates the thinking configuration, and
// answers with an estimated input token count of at least 1.
func (h *Handler) handleCountTokens(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
		return
	}

	raw, readErr := io.ReadAll(r.Body)
	if readErr != nil {
		h.sendClaudeError(w, 400, "invalid_request_error", "Failed to read request body")
		return
	}

	var req ClaudeRequest
	if decodeErr := json.Unmarshal(raw, &req); decodeErr != nil {
		h.sendClaudeError(w, 400, "invalid_request_error", "Invalid JSON")
		return
	}
	if problem := validateClaudeThinkingConfig(req.Thinking, req.MaxTokens); problem != "" {
		h.sendClaudeError(w, 400, "invalid_request_error", problem)
		return
	}

	// Resolve the real model name and thinking mode before estimating.
	suffix := config.GetThinkingConfig().Suffix
	resolvedModel, thinkingOn := resolveClaudeThinkingMode(req.Model, req.Thinking, suffix)
	req.Model = resolvedModel

	count := estimateClaudeRequestInputTokens(cloneClaudeRequestForThinking(&req, thinkingOn))
	if count < 1 {
		count = 1
	}

	w.Header().Set("Content-Type", "application/json; charset=utf-8")
	json.NewEncoder(w).Encode(map[string]int{"input_tokens": count})
}
// handleClaudeMessages handles Claude API message requests by delegating to
// handleClaudeMessagesInternal.
func (h *Handler) handleClaudeMessages(w http.ResponseWriter, r *http.Request) {
h.handleClaudeMessagesInternal(w, r)
}
// handleClaudeMessagesInternal processes one Claude messages request: it
// decodes and validates the body, picks an account from the pool and ensures
// its token is valid, resolves the model / thinking mode, computes the
// prompt-cache usage, converts the request to a Kiro payload and hands it to
// the streaming or non-streaming responder.
func (h *Handler) handleClaudeMessagesInternal(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
		return
	}

	// Read and decode the request.
	raw, readErr := io.ReadAll(r.Body)
	if readErr != nil {
		h.sendClaudeError(w, 400, "invalid_request_error", "Failed to read request body")
		return
	}
	var req ClaudeRequest
	if decodeErr := json.Unmarshal(raw, &req); decodeErr != nil {
		h.sendClaudeError(w, 400, "invalid_request_error", "Invalid JSON: "+decodeErr.Error())
		return
	}
	if problem := validateClaudeRequestShape(&req); problem != "" {
		h.sendClaudeError(w, 400, "invalid_request_error", problem)
		return
	}

	// Pick an account.
	account := h.pool.GetNext()
	if account == nil {
		h.sendClaudeError(w, 503, "api_error", "No available accounts")
		return
	}
	// Check and, if necessary, refresh its token.
	if tokenErr := h.ensureValidToken(account); tokenErr != nil {
		h.sendClaudeError(w, 503, "api_error", "Token refresh failed: "+tokenErr.Error())
		return
	}

	// Resolve the model and thinking mode.
	thinkingCfg := config.GetThinkingConfig()
	resolvedModel, thinkingOn := resolveClaudeThinkingMode(req.Model, req.Thinking, thinkingCfg.Suffix)
	req.Model = resolvedModel
	effectiveReq := cloneClaudeRequestForThinking(&req, thinkingOn)
	responseOpts := resolveClaudeThinkingResponseOptions(req.Thinking, thinkingCfg.ClaudeFormat)

	estimatedInput := estimateClaudeRequestInputTokens(effectiveReq)
	profile := h.promptCache.BuildClaudeProfile(effectiveReq, estimatedInput)
	usage := h.promptCache.Compute(account.ID, profile)

	// Convert the request.
	kiroPayload := ClaudeToKiro(&req, thinkingOn)

	// Both responders share one signature, so select one and call it once.
	respond := h.handleClaudeNonStream
	if req.Stream {
		respond = h.handleClaudeStream
	}
	respond(w, account, kiroPayload, req.Model, thinkingOn, responseOpts, estimatedInput, usage, profile)
}
// handleClaudeStream sends a Kiro response to the client as Claude-style
// server-sent events: message_start, a sequence of content blocks (text,
// thinking, tool_use), message_delta and message_stop.
//
// Thinking output can reach this function in two ways: as dedicated reasoning
// events (OnText with isThinking=true) or inline in the text, delimited by
// <thinking> ... </thinking> tags. Whichever source appears first is the only
// one honored for the rest of the stream. How thinking is rendered depends on
// thinkingOpts.Format:
//   - "think": wrapped in <think> ... </think> inside the text block
//   - "reasoning_content": emitted as plain text
//   - anything else: emitted as native "thinking" content blocks, or as an
//     empty thinking block when thinkingOpts.OmitDisplay is set
//
// When thinking is false, all thinking output is dropped. If the upstream
// call fails, the failure is recorded and an "error" event is sent instead of
// the closing events.
func (h *Handler) handleClaudeStream(w http.ResponseWriter, account *config.Account, payload *KiroPayload, model string, thinking bool, thinkingOpts claudeThinkingResponseOptions, estimatedInputTokens int, cacheUsage promptCacheUsage, cacheProfile *promptCacheProfile) {
	// Inline thinking delimiters looked for in the text stream. Their
	// lengths (10 and 11 bytes) are the amounts skipped after a match.
	const (
		thinkingOpenTag  = "<thinking>"
		thinkingCloseTag = "</thinking>"
	)

	w.Header().Set("Content-Type", "text/event-stream; charset=utf-8")
	w.Header().Set("Cache-Control", "no-cache")
	w.Header().Set("Connection", "keep-alive")
	flusher, ok := w.(http.Flusher)
	if !ok {
		h.sendClaudeError(w, 500, "api_error", "Streaming not supported")
		return
	}

	// Configured output format for thinking content.
	thinkingFormat := thinkingOpts.Format
	msgID := "msg_" + uuid.New().String()

	var inputTokens, outputTokens int
	var credits float64
	var realInputTokens int
	var toolUses []KiroToolUse
	var nextContentIndex int
	var rawContentBuilder strings.Builder
	var rawThinkingBuilder strings.Builder
	activeBlockIndex := -1
	activeBlockType := ""
	startInputTokens := estimatedInputTokens

	// closeActiveBlock ends the currently open content block, if any.
	closeActiveBlock := func() {
		if activeBlockIndex < 0 {
			return
		}
		h.sendSSE(w, flusher, "content_block_stop", map[string]interface{}{
			"type":  "content_block_stop",
			"index": activeBlockIndex,
		})
		activeBlockIndex = -1
		activeBlockType = ""
	}

	// startContentBlock makes blockType the active block, closing a block
	// of a different type first. It does nothing when a block of that type
	// is already open.
	startContentBlock := func(blockType string) {
		if activeBlockType == blockType {
			return
		}
		closeActiveBlock()
		idx := nextContentIndex
		nextContentIndex++
		if blockType == "thinking" {
			h.sendSSE(w, flusher, "content_block_start", map[string]interface{}{
				"type":  "content_block_start",
				"index": idx,
				"content_block": map[string]string{
					"type":     "thinking",
					"thinking": "",
				},
			})
		} else {
			h.sendSSE(w, flusher, "content_block_start", map[string]interface{}{
				"type":  "content_block_start",
				"index": idx,
				"content_block": map[string]string{
					"type": "text",
					"text": "",
				},
			})
		}
		activeBlockIndex = idx
		activeBlockType = blockType
	}

	// State of the inline thinking-tag parser.
	var textBuffer string
	var inThinkingBlock bool
	var dropTagThinking bool
	var thinkingSource thinkingStreamSource

	// sendText emits one piece of output.
	// thinkingState: 0 = ordinary content, 1 = thinking start,
	// 2 = thinking middle, 3 = thinking end.
	sendText := func(text string, thinkingState int) {
		if thinkingState == 0 {
			// Ordinary content.
			if text == "" {
				return
			}
			startContentBlock("text")
			h.sendSSE(w, flusher, "content_block_delta", map[string]interface{}{
				"type":  "content_block_delta",
				"index": activeBlockIndex,
				"delta": map[string]string{"type": "text_delta", "text": text},
			})
			return
		}
		if !thinking {
			return
		}
		switch thinkingFormat {
		case "think":
			// Thinking is inlined into the text, wrapped in <think> tags.
			var outputText string
			switch thinkingState {
			case 1:
				outputText = "<think>" + text
			case 2:
				outputText = text
			case 3:
				outputText = text + "</think>"
			}
			if outputText == "" {
				return
			}
			startContentBlock("text")
			h.sendSSE(w, flusher, "content_block_delta", map[string]interface{}{
				"type":  "content_block_delta",
				"index": activeBlockIndex,
				"delta": map[string]string{"type": "text_delta", "text": outputText},
			})
		case "reasoning_content":
			if text == "" {
				return
			}
			startContentBlock("text")
			h.sendSSE(w, flusher, "content_block_delta", map[string]interface{}{
				"type":  "content_block_delta",
				"index": activeBlockIndex,
				"delta": map[string]string{"type": "text_delta", "text": text},
			})
		default:
			if thinkingOpts.OmitDisplay {
				// Emit an empty thinking block: open it on start,
				// close it on end, and never send the content.
				if thinkingState == 1 {
					startContentBlock("thinking")
					return
				}
				if thinkingState == 3 {
					if activeBlockType != "thinking" {
						startContentBlock("thinking")
					}
					closeActiveBlock()
				}
				return
			}
			if thinkingState == 3 && text == "" {
				if activeBlockType == "thinking" {
					closeActiveBlock()
				}
				return
			}
			if text != "" {
				startContentBlock("thinking")
				h.sendSSE(w, flusher, "content_block_delta", map[string]interface{}{
					"type":  "content_block_delta",
					"index": activeBlockIndex,
					"delta": map[string]string{"type": "thinking_delta", "thinking": text},
				})
			}
			if thinkingState == 3 && activeBlockType == "thinking" {
				closeActiveBlock()
			}
		}
	}

	// processClaudeText routes incoming text. Reasoning-event text is
	// forwarded directly; ordinary text is buffered and scanned for
	// <thinking> tags. Except on forceFlush, the last 15 runes are held
	// back so that a tag split across chunks is still recognised.
	var thinkingStarted bool
	var eventThinkingOpen bool
	processClaudeText := func(text string, isThinking bool, forceFlush bool) {
		if isThinking && !thinking {
			return
		}
		// Text from a reasoning event is emitted directly.
		if isThinking {
			if !allowReasoningSource(&thinkingSource) {
				return
			}
			if !thinkingStarted {
				sendText(text, 1)
				thinkingStarted = true
				eventThinkingOpen = true
			} else {
				sendText(text, 2)
			}
			return
		}
		// Ordinary text ends any thinking opened by reasoning events.
		if eventThinkingOpen {
			sendText("", 3)
			eventThinkingOpen = false
			thinkingStarted = false
		}
		textBuffer += text
		for {
			if !inThinkingBlock {
				thinkingStart := strings.Index(textBuffer, thinkingOpenTag)
				if thinkingStart != -1 {
					if thinkingStart > 0 {
						sendText(textBuffer[:thinkingStart], 0)
					}
					textBuffer = textBuffer[thinkingStart+len(thinkingOpenTag):]
					inThinkingBlock = true
					dropTagThinking = !allowTagSource(&thinkingSource)
					thinkingStarted = false
				} else if forceFlush || len([]rune(textBuffer)) > 50 {
					// Slice by runes so multi-byte characters are never split.
					runes := []rune(textBuffer)
					safeLen := len(runes)
					if !forceFlush {
						safeLen = max(0, len(runes)-15)
					}
					if safeLen > 0 {
						sendText(string(runes[:safeLen]), 0)
						textBuffer = string(runes[safeLen:])
					}
					break
				} else {
					break
				}
			} else {
				thinkingEnd := strings.Index(textBuffer, thinkingCloseTag)
				if thinkingEnd != -1 {
					content := textBuffer[:thinkingEnd]
					if !dropTagThinking {
						if !thinkingStarted {
							sendText(content, 1)
							sendText("", 3)
						} else {
							sendText(content, 3)
						}
					}
					textBuffer = textBuffer[thinkingEnd+len(thinkingCloseTag):]
					inThinkingBlock = false
					dropTagThinking = false
					thinkingStarted = false
				} else if forceFlush {
					// Unterminated thinking block: flush what is buffered
					// and treat the block as closed.
					if textBuffer != "" {
						if !dropTagThinking {
							if !thinkingStarted {
								sendText(textBuffer, 1)
								sendText("", 3)
							} else {
								sendText(textBuffer, 3)
							}
						}
						textBuffer = ""
					}
					inThinkingBlock = false
					dropTagThinking = false
					thinkingStarted = false
					break
				} else {
					// Stream out content that is inside the thinking block.
					runes := []rune(textBuffer)
					if len(runes) > 20 {
						safeLen := len(runes) - 15
						if safeLen > 0 {
							if !dropTagThinking {
								if !thinkingStarted {
									sendText(string(runes[:safeLen]), 1)
									thinkingStarted = true
								} else {
									sendText(string(runes[:safeLen]), 2)
								}
							}
							textBuffer = string(runes[safeLen:])
						}
					}
					break
				}
			}
		}
	}

	// Send message_start.
	h.sendSSE(w, flusher, "message_start", map[string]interface{}{
		"type": "message_start",
		"message": map[string]interface{}{
			"id":            msgID,
			"type":          "message",
			"role":          "assistant",
			"content":       []interface{}{},
			"model":         model,
			"stop_reason":   nil,
			"stop_sequence": nil,
			"usage":         buildClaudeUsageMap(startInputTokens, 0, cacheUsage, cacheProfile != nil),
		},
	})

	callback := &KiroStreamCallback{
		OnText: func(text string, isThinking bool) {
			if text == "" {
				return
			}
			if isThinking {
				rawThinkingBuilder.WriteString(text)
			} else {
				rawContentBuilder.WriteString(text)
			}
			processClaudeText(text, isThinking, false)
		},
		OnToolUse: func(tu KiroToolUse) {
			// Flush buffered text first.
			processClaudeText("", false, true)

			// Marshal the tool input once. It feeds both the raw content
			// used for output-token estimation and the delta sent to the
			// client; on failure neither receives any JSON.
			inputJSON, marshalErr := json.Marshal(tu.Input)
			rawContentBuilder.WriteString(tu.Name)
			if marshalErr == nil {
				rawContentBuilder.Write(inputJSON)
			} else {
				inputJSON = nil
			}
			toolUses = append(toolUses, tu)

			closeActiveBlock()
			idx := nextContentIndex
			nextContentIndex++
			h.sendSSE(w, flusher, "content_block_start", map[string]interface{}{
				"type":  "content_block_start",
				"index": idx,
				"content_block": map[string]interface{}{
					"type":  "tool_use",
					"id":    tu.ToolUseID,
					"name":  tu.Name,
					"input": map[string]interface{}{},
				},
			})
			h.sendSSE(w, flusher, "content_block_delta", map[string]interface{}{
				"type":  "content_block_delta",
				"index": idx,
				"delta": map[string]interface{}{
					"type":         "input_json_delta",
					"partial_json": string(inputJSON),
				},
			})
			h.sendSSE(w, flusher, "content_block_stop", map[string]interface{}{
				"type":  "content_block_stop",
				"index": idx,
			})
		},
		OnComplete: func(inTok, outTok int) {
			inputTokens = inTok
			outputTokens = outTok
		},
		OnError: func(err error) {
			h.pool.RecordError(account.ID, strings.Contains(err.Error(), "429") || strings.Contains(err.Error(), "quota"))
		},
		OnCredits: func(c float64) {
			credits = c
		},
		OnContextUsage: func(pct float64) {
			// Convert the reported context usage percentage into tokens.
			realInputTokens = int(pct * float64(getContextWindowSize(model)) / 100.0)
		},
	}

	err := CallKiroAPI(account, payload, callback)
	if err != nil {
		h.recordFailure()
		h.pool.RecordError(account.ID, strings.Contains(err.Error(), "429") || strings.Contains(err.Error(), "quota"))
		h.sendSSE(w, flusher, "error", map[string]interface{}{
			"type":  "error",
			"error": map[string]string{"type": "api_error", "message": err.Error()},
		})
		return
	}

	// Flush whatever is still buffered.
	processClaudeText("", false, true)
	if eventThinkingOpen {
		sendText("", 3)
		eventThinkingOpen = false
	}
	closeActiveBlock()

	// Prefer the token count derived from context usage, then the count
	// reported on completion, then the local estimate.
	if realInputTokens > 0 {
		inputTokens = realInputTokens
	} else if inputTokens <= 0 {
		inputTokens = estimatedInputTokens
	}

	outputContent, extractedReasoning := extractThinkingFromContent(rawContentBuilder.String())
	thinkingOutput := rawThinkingBuilder.String()
	if thinking && thinkingOutput == "" && extractedReasoning != "" {
		thinkingOutput = extractedReasoning
	}
	if !thinking {
		thinkingOutput = ""
	}
	// Output tokens are always estimated locally from the emitted content.
	outputTokens = estimateClaudeOutputTokens(outputContent, thinkingOutput, toolUses)

	h.recordSuccess(inputTokens, outputTokens, credits)
	h.pool.RecordSuccess(account.ID)
	h.pool.UpdateStats(account.ID, inputTokens+outputTokens, credits)
	h.promptCache.Update(account.ID, cacheProfile)

	// Send message_delta.
	stopReason := "end_turn"
	if len(toolUses) > 0 {
		stopReason = "tool_use"
	}
	h.sendSSE(w, flusher, "message_delta", map[string]interface{}{
		"type": "message_delta",
		"delta": map[string]interface{}{
			"stop_reason": stopReason,
		},
		"usage": buildClaudeUsageMap(inputTokens, outputTokens, cacheUsage, cacheProfile != nil),
	})
	h.sendSSE(w, flusher, "message_stop", map[string]interface{}{
		"type": "message_stop",
	})
}
// sendSSE writes one server-sent event named event, with data encoded as
// JSON, and flushes it to the client immediately.
//
// If data cannot be encoded, the failure is logged and nothing is written,
// rather than emitting an event with an empty data field.
func (h *Handler) sendSSE(w http.ResponseWriter, flusher http.Flusher, event string, data interface{}) {
	jsonData, err := json.Marshal(data)
	if err != nil {
		fmt.Printf("[SSE] Failed to encode %s event: %v\n", event, err)
		return
	}
	fmt.Fprintf(w, "event: %s\ndata: %s\n\n", event, jsonData)
	flusher.Flush()
}
// backgroundStatsSaver persists the statistics every 30 seconds until a
// signal arrives on h.stopStatsSaver, saving one final time before returning.
func (h *Handler) backgroundStatsSaver() {
	const saveInterval = 30 * time.Second

	ticker := time.NewTicker(saveInterval)
	defer ticker.Stop()

	for {
		select {
		case <-h.stopStatsSaver:
			// Save once more before exiting.
			h.saveStats()
			return
		case <-ticker.C:
			h.saveStats()
		}
	}
}
// saveStats hands a snapshot of the current counters and credits to
// config.UpdateStats.
func (h *Handler) saveStats() {
	total := int(atomic.LoadInt64(&h.totalRequests))
	succeeded := int(atomic.LoadInt64(&h.successRequests))
	failed := int(atomic.LoadInt64(&h.failedRequests))
	tokens := int(atomic.LoadInt64(&h.totalTokens))
	credits := h.getCredits()

	config.UpdateStats(total, succeeded, failed, tokens, credits)
}
// getCredits returns the accumulated credits, read under creditsMu.
func (h *Handler) getCredits() float64 {
	h.creditsMu.RLock()
	total := h.totalCredits
	h.creditsMu.RUnlock()
	return total
}
// addCredits adds credits to the accumulated total under creditsMu.
func (h *Handler) addCredits(credits float64) {
	h.creditsMu.Lock()
	defer h.creditsMu.Unlock()
	h.totalCredits += credits
}
// recordSuccess counts one successful request and adds its token usage
// (input + output) and credits to the totals. The counters are updated
// atomically.
func (h *Handler) recordSuccess(inputTokens, outputTokens int, credits float64) {
	usedTokens := int64(inputTokens + outputTokens)

	atomic.AddInt64(&h.totalRequests, 1)
	atomic.AddInt64(&h.successRequests, 1)
	atomic.AddInt64(&h.totalTokens, usedTokens)
	h.addCredits(credits)
}
// recordFailure counts one failed request, atomically incrementing both the
// total and the failed request counters.
func (h *Handler) recordFailure() {
	for _, counter := range [...]*int64{&h.totalRequests, &h.failedRequests} {
		atomic.AddInt64(counter, 1)
	}
}
// handleClaudeNonStream collects a complete Kiro response and writes it as a
// single Claude-format JSON message.
//
// Thinking content comes from reasoning events or, when there were none, from
// what extractThinkingFromContent extracts from the text. It is rendered
// according to thinkingOpts.Format:
//   - "think": prepended to the text, wrapped in <think> ... </think>
//   - "reasoning_content": prepended to the text as-is
//   - anything else: passed to KiroToClaudeResponse as thinking content, or
//     requested as an empty thinking block when thinkingOpts.OmitDisplay is
//     set
//
// When thinking is false, thinking content is discarded. An upstream failure
// is recorded and answered with a 500 api_error.
func (h *Handler) handleClaudeNonStream(w http.ResponseWriter, account *config.Account, payload *KiroPayload, model string, thinking bool, thinkingOpts claudeThinkingResponseOptions, estimatedInputTokens int, cacheUsage promptCacheUsage, cacheProfile *promptCacheProfile) {
	var contentBuilder strings.Builder
	var thinkingBuilder strings.Builder
	var toolUses []KiroToolUse
	var inputTokens, outputTokens int
	var credits float64
	var realInputTokens int

	// isQuotaError classifies upstream errors the same way the streaming
	// path does: HTTP 429 or a quota message.
	isQuotaError := func(err error) bool {
		msg := err.Error()
		return strings.Contains(msg, "429") || strings.Contains(msg, "quota")
	}

	callback := &KiroStreamCallback{
		OnText: func(text string, isThinking bool) {
			if isThinking {
				thinkingBuilder.WriteString(text)
			} else {
				contentBuilder.WriteString(text)
			}
		},
		OnToolUse: func(tu KiroToolUse) {
			toolUses = append(toolUses, tu)
		},
		OnComplete: func(inTok, outTok int) {
			inputTokens = inTok
			outputTokens = outTok
		},
		OnError: func(err error) {
			h.pool.RecordError(account.ID, isQuotaError(err))
		},
		OnCredits: func(c float64) {
			credits = c
		},
		OnContextUsage: func(pct float64) {
			// Convert the reported context usage percentage into tokens.
			realInputTokens = int(pct * float64(getContextWindowSize(model)) / 100.0)
		},
	}

	err := CallKiroAPI(account, payload, callback)
	if err != nil {
		h.recordFailure()
		h.pool.RecordError(account.ID, isQuotaError(err))
		h.sendClaudeError(w, 500, "api_error", err.Error())
		return
	}

	// Merge thinking content: prefer reasoning events, otherwise use what
	// was extracted from the text.
	thinkingFormat := thinkingOpts.Format
	finalContent, extractedReasoning := extractThinkingFromContent(contentBuilder.String())
	rawThinkingContent := thinkingBuilder.String()
	if thinking && rawThinkingContent == "" && extractedReasoning != "" {
		rawThinkingContent = extractedReasoning
	}
	if !thinking {
		rawThinkingContent = ""
	}

	// Prefer the token count derived from context usage, then the count
	// reported on completion, then the local estimate.
	if realInputTokens > 0 {
		inputTokens = realInputTokens
	} else if inputTokens <= 0 {
		inputTokens = estimatedInputTokens
	}
	// Output tokens are always estimated locally from the final content.
	outputTokens = estimateClaudeOutputTokens(finalContent, rawThinkingContent, toolUses)

	h.recordSuccess(inputTokens, outputTokens, credits)
	h.pool.RecordSuccess(account.ID)
	h.pool.UpdateStats(account.ID, inputTokens+outputTokens, credits)
	h.promptCache.Update(account.ID, cacheProfile)

	responseThinkingContent := rawThinkingContent
	includeEmptyThinkingBlock := thinking && thinkingOpts.OmitDisplay && rawThinkingContent != ""
	if includeEmptyThinkingBlock {
		responseThinkingContent = ""
	}
	if thinking && responseThinkingContent != "" {
		switch thinkingFormat {
		case "think":
			finalContent = "<think>" + responseThinkingContent + "</think>" + finalContent
			responseThinkingContent = ""
		case "reasoning_content":
			// The Claude format has no reasoning_content field, so the
			// thinking text is simply prepended.
			finalContent = responseThinkingContent + finalContent
			responseThinkingContent = ""
		default:
		}
	}

	resp := KiroToClaudeResponse(finalContent, responseThinkingContent, includeEmptyThinkingBlock, toolUses, inputTokens, outputTokens, model)
	resp.Usage.InputTokens = inputTokens
	resp.Usage.CacheCreationInputTokens = cacheUsage.CacheCreationInputTokens
	resp.Usage.CacheReadInputTokens = cacheUsage.CacheReadInputTokens
	if cacheProfile != nil {
		resp.Usage.CacheCreation = &ClaudeCacheCreationUsage{
			Ephemeral5mInputTokens: cacheUsage.CacheCreation5mInputTokens,
			Ephemeral1hInputTokens: cacheUsage.CacheCreation1hInputTokens,
		}
	}

	w.Header().Set("Content-Type", "application/json; charset=utf-8")
	json.NewEncoder(w).Encode(resp)
}
// sendClaudeError writes a Claude-style JSON error envelope
// ({"type":"error","error":{"type":...,"message":...}}) with the given
// HTTP status.
func (h *Handler) sendClaudeError(w http.ResponseWriter, status int, errType, message string) {
	detail := map[string]string{
		"type":    errType,
		"message": message,
	}
	envelope := map[string]interface{}{
		"type":  "error",
		"error": detail,
	}

	w.Header().Set("Content-Type", "application/json; charset=utf-8")
	w.WriteHeader(status)
	json.NewEncoder(w).Encode(envelope)
}
// handleOpenAIChat is the entry point for OpenAI-format chat requests.
//
// It accepts POST only, decodes and validates the JSON body, picks the next
// account from the pool, makes sure its token is valid, resolves the model
// name and thinking mode, converts the request to a Kiro payload and then
// hands off to the streaming or non-streaming handler depending on req.Stream.
func (h *Handler) handleOpenAIChat(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
		return
	}

	rawBody, readErr := io.ReadAll(r.Body)
	if readErr != nil {
		h.sendOpenAIError(w, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
		return
	}

	var req OpenAIRequest
	if json.Unmarshal(rawBody, &req) != nil {
		h.sendOpenAIError(w, http.StatusBadRequest, "invalid_request_error", "Invalid JSON")
		return
	}
	if problem := validateOpenAIRequestShape(&req); problem != "" {
		h.sendOpenAIError(w, http.StatusBadRequest, "invalid_request_error", problem)
		return
	}

	account := h.pool.GetNext()
	if account == nil {
		h.sendOpenAIError(w, http.StatusServiceUnavailable, "server_error", "No available accounts")
		return
	}
	if h.ensureValidToken(account) != nil {
		h.sendOpenAIError(w, http.StatusServiceUnavailable, "server_error", "Token refresh failed")
		return
	}

	// Resolve the real model name and whether thinking mode was requested.
	resolvedModel, thinking := ParseModelAndThinking(req.Model, config.GetThinkingConfig().Suffix)
	req.Model = resolvedModel

	estimatedInputTokens := estimateOpenAIRequestInputTokens(&req)
	kiroPayload := OpenAIToKiro(&req, thinking)

	if !req.Stream {
		h.handleOpenAINonStream(w, account, kiroPayload, req.Model, thinking, estimatedInputTokens)
		return
	}
	h.handleOpenAIStream(w, account, kiroPayload, req.Model, thinking, estimatedInputTokens)
}
// handleOpenAIStream serves a streaming (SSE) OpenAI-format chat completion.
//
// It sets the event-stream headers, then drives CallKiroAPI with callbacks that
// translate upstream text, reasoning and tool-use events into
// "chat.completion.chunk" frames that are written to w and flushed immediately.
// After a successful call it flushes any buffered text, records statistics and
// writes a final chunk carrying finish_reason and usage, followed by
// "data: [DONE]". If CallKiroAPI returns an error, the failure is recorded and
// the function returns without writing a final chunk or [DONE].
//
// NOTE(review): every thinking-tag literal in this function is an empty string
// in this view, while the +10 / +11 slice offsets below match the lengths of
// "<thinking>" and "</thinking>" — confirm the literals against the intended
// tags.
func (h *Handler) handleOpenAIStream(w http.ResponseWriter, account *config.Account, payload *KiroPayload, model string, thinking bool, estimatedInputTokens int) {
w.Header().Set("Content-Type", "text/event-stream; charset=utf-8")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
flusher, ok := w.(http.Flusher)
if !ok {
h.sendOpenAIError(w, 500, "server_error", "Streaming not supported")
return
}
// Read the configured output format for thinking content.
thinkingFormat := config.GetThinkingConfig().OpenAIFormat
chatID := "chatcmpl-" + uuid.New().String()
var toolCalls []ToolCall
var toolCallIndex int
var inputTokens, outputTokens int
var credits float64
var realInputTokens int
// Raw (unparsed) content and reasoning, kept for output-token estimation.
var rawContentBuilder strings.Builder
var rawReasoningBuilder strings.Builder
// Thinking-tag parsing state.
var textBuffer string
var inThinkingBlock bool
var dropTagThinking bool
var thinkingSource thinkingStreamSource
// sendChunk is the helper that writes one delta chunk and flushes it.
// thinkingState: 0 = regular content, 1 = thinking start, 2 = thinking middle, 3 = thinking end
sendChunk := func(content string, thinkingState int) {
if content == "" && thinkingState == 2 {
return
}
var chunk map[string]interface{}
if thinkingState > 0 {
// Thinking output is suppressed entirely when thinking mode is off.
if !thinking {
return
}
// Thinking content.
switch thinkingFormat {
case "thinking":
// Stream the content with the tag literals around it.
var text string
switch thinkingState {
case 1: // start
text = "" + content
case 2: // middle
text = content
case 3: // end
text = content + ""
}
if text == "" {
return
}
chunk = map[string]interface{}{
"id": chatID,
"object": "chat.completion.chunk",
"created": time.Now().Unix(),
"model": model,
"choices": []map[string]interface{}{{
"index": 0,
"delta": map[string]string{"content": text},
"finish_reason": nil,
}},
}
case "think":
var text string
switch thinkingState {
case 1:
text = "" + content
case 2:
text = content
case 3:
text = content + ""
}
if text == "" {
return
}
chunk = map[string]interface{}{
"id": chatID,
"object": "chat.completion.chunk",
"created": time.Now().Unix(),
"model": model,
"choices": []map[string]interface{}{{
"index": 0,
"delta": map[string]string{"content": text},
"finish_reason": nil,
}},
}
default: // "reasoning_content"
if content == "" {
return
}
chunk = map[string]interface{}{
"id": chatID,
"object": "chat.completion.chunk",
"created": time.Now().Unix(),
"model": model,
"choices": []map[string]interface{}{{
"index": 0,
"delta": map[string]string{"reasoning_content": content},
"finish_reason": nil,
}},
}
}
} else {
// Regular content.
if content == "" {
return
}
chunk = map[string]interface{}{
"id": chatID,
"object": "chat.completion.chunk",
"created": time.Now().Unix(),
"model": model,
"choices": []map[string]interface{}{{
"index": 0,
"delta": map[string]string{"content": content},
"finish_reason": nil,
}},
}
}
data, _ := json.Marshal(chunk)
fmt.Fprintf(w, "data: %s\n\n", string(data))
flusher.Flush()
}
// processText handles incoming text and parses thinking tags out of it.
// thinkingStarted tracks whether the start of the current thinking block has already been sent.
var thinkingStarted bool
var eventThinkingOpen bool
processText := func(text string, isThinking bool, forceFlush bool) {
if isThinking && !thinking {
return
}
// Reasoning delivered as a reasoningContentEvent is emitted directly.
if isThinking {
if !allowReasoningSource(&thinkingSource) {
return
}
if !thinkingStarted {
sendChunk(text, 1) // start
thinkingStarted = true
eventThinkingOpen = true
} else {
sendChunk(text, 2) // middle
}
return
}
// Regular text arrived: close an event-driven thinking block that is still open.
if eventThinkingOpen {
sendChunk("", 3)
eventThinkingOpen = false
thinkingStarted = false
}
textBuffer += text
for {
if !inThinkingBlock {
// Look for the opening thinking tag.
thinkingStart := strings.Index(textBuffer, "")
if thinkingStart != -1 {
// Emit whatever precedes the thinking tag.
if thinkingStart > 0 {
sendChunk(textBuffer[:thinkingStart], 0)
}
textBuffer = textBuffer[thinkingStart+10:] // drop the opening tag
inThinkingBlock = true
dropTagThinking = !allowTagSource(&thinkingSource)
thinkingStarted = false // reset so a fresh start marker is sent
} else if forceFlush || len([]rune(textBuffer)) > 50 {
// No tag found: emit what is safe, holding back a tail that may be a partial tag.
runes := []rune(textBuffer)
safeLen := len(runes)
if !forceFlush {
safeLen = max(0, len(runes)-15)
}
if safeLen > 0 {
sendChunk(string(runes[:safeLen]), 0)
textBuffer = string(runes[safeLen:])
}
break
} else {
break
}
} else {
// Inside a thinking block: look for the closing tag.
thinkingEnd := strings.Index(textBuffer, "")
if thinkingEnd != -1 {
// Emit the thinking content.
content := textBuffer[:thinkingEnd]
if !dropTagThinking {
if !thinkingStarted {
// Emit the complete block at once (start + content + end).
sendChunk(content, 1) // start
sendChunk("", 3) // end (empty content, closing marker only)
} else {
// Already started: send the remaining content together with the end.
sendChunk(content, 3) // end
}
}
textBuffer = textBuffer[thinkingEnd+11:] // drop the closing tag
inThinkingBlock = false
dropTagThinking = false
thinkingStarted = false
} else if forceFlush {
// Forced flush: emit whatever is left.
if textBuffer != "" {
if !dropTagThinking {
if !thinkingStarted {
sendChunk(textBuffer, 1) // start
sendChunk("", 3) // end
} else {
sendChunk(textBuffer, 3) // end
}
}
textBuffer = ""
}
inThinkingBlock = false
dropTagThinking = false
thinkingStarted = false
break
} else {
// Stream the content inside the thinking block.
runes := []rune(textBuffer)
if len(runes) > 20 {
safeLen := len(runes) - 15 // hold back a tail that may be a partial closing tag
if safeLen > 0 {
if !dropTagThinking {
if !thinkingStarted {
sendChunk(string(runes[:safeLen]), 1) // start
thinkingStarted = true
} else {
sendChunk(string(runes[:safeLen]), 2) // middle
}
}
textBuffer = string(runes[safeLen:])
}
}
break
}
}
}
}
callback := &KiroStreamCallback{
OnText: func(text string, isThinking bool) {
if text == "" {
return
}
if isThinking {
rawReasoningBuilder.WriteString(text)
} else {
rawContentBuilder.WriteString(text)
}
processText(text, isThinking, false)
},
OnToolUse: func(tu KiroToolUse) {
// Flush the buffer first.
processText("", false, true)
args, _ := json.Marshal(tu.Input)
rawContentBuilder.WriteString(tu.Name)
rawContentBuilder.Write(args)
tc := ToolCall{ID: tu.ToolUseID, Type: "function"}
tc.Function.Name = tu.Name
tc.Function.Arguments = string(args)
toolCalls = append(toolCalls, tc)
chunk := map[string]interface{}{
"id": chatID,
"object": "chat.completion.chunk",
"created": time.Now().Unix(),
"model": model,
"choices": []map[string]interface{}{{
"index": 0,
"delta": map[string]interface{}{
"tool_calls": []map[string]interface{}{{
"index": toolCallIndex,
"id": tu.ToolUseID,
"type": "function",
"function": map[string]string{
"name": tu.Name,
"arguments": string(args),
},
}},
},
"finish_reason": nil,
}},
}
toolCallIndex++
data, _ := json.Marshal(chunk)
fmt.Fprintf(w, "data: %s\n\n", string(data))
flusher.Flush()
},
OnComplete: func(inTok, outTok int) {
inputTokens = inTok
outputTokens = outTok
},
OnError: func(err error) {
h.pool.RecordError(account.ID, strings.Contains(err.Error(), "429"))
},
OnCredits: func(c float64) {
credits = c
},
OnContextUsage: func(pct float64) {
// Convert the reported context-usage percentage into a token count.
realInputTokens = int(pct * float64(getContextWindowSize(model)) / 100.0)
},
}
err := CallKiroAPI(account, payload, callback)
if err != nil {
h.recordFailure()
h.pool.RecordError(account.ID, strings.Contains(err.Error(), "429"))
return
}
// Flush whatever is still buffered.
processText("", false, true)
if eventThinkingOpen {
sendChunk("", 3)
eventThinkingOpen = false
}
// Input tokens: prefer the context-usage-derived figure, then the count
// reported on completion, then the caller's estimate.
if realInputTokens > 0 {
inputTokens = realInputTokens
} else if inputTokens <= 0 {
inputTokens = estimatedInputTokens
}
outputContent, extractedReasoning := extractThinkingFromContent(rawContentBuilder.String())
reasoningOutput := rawReasoningBuilder.String()
if thinking && reasoningOutput == "" && extractedReasoning != "" {
reasoningOutput = extractedReasoning
}
if !thinking {
reasoningOutput = ""
}
// Output tokens are re-estimated here; the value reported on completion is overwritten.
outputTokens = estimateApproxTokens(outputContent) + estimateApproxTokens(reasoningOutput)
for _, tc := range toolCalls {
outputTokens += estimateApproxTokens(tc.Function.Name)
outputTokens += estimateApproxTokens(tc.Function.Arguments)
}
h.recordSuccess(inputTokens, outputTokens, credits)
h.pool.RecordSuccess(account.ID)
h.pool.UpdateStats(account.ID, inputTokens+outputTokens, credits)
// Send the terminating chunk.
finishReason := "stop"
if len(toolCalls) > 0 {
finishReason = "tool_calls"
}
chunk := map[string]interface{}{
"id": chatID,
"object": "chat.completion.chunk",
"created": time.Now().Unix(),
"model": model,
"choices": []map[string]interface{}{{
"index": 0,
"delta": map[string]interface{}{},
"finish_reason": finishReason,
}},
"usage": map[string]int{
"prompt_tokens": inputTokens,
"completion_tokens": outputTokens,
"total_tokens": inputTokens + outputTokens,
},
}
data, _ := json.Marshal(chunk)
fmt.Fprintf(w, "data: %s\n\n", string(data))
fmt.Fprintf(w, "data: [DONE]\n\n")
flusher.Flush()
}
// handleOpenAINonStream serves a non-streaming OpenAI-format chat completion.
//
// It drives CallKiroAPI with callbacks that buffer the whole upstream reply.
// On error it records the failure on the handler and the pool and answers with
// a 500 "server_error". On success it records statistics and writes a single
// JSON body built by KiroToOpenAIResponseWithReasoning using the configured
// OpenAI thinking format.
func (h *Handler) handleOpenAINonStream(w http.ResponseWriter, account *config.Account, payload *KiroPayload, model string, thinking bool, estimatedInputTokens int) {
	// Builders avoid the quadratic copying that repeated string += incurs
	// when the upstream delivers the reply in many small fragments.
	var contentBuilder, reasoningBuilder strings.Builder
	var toolUses []KiroToolUse
	var inputTokens, outputTokens int
	var credits float64
	var realInputTokens int

	callback := &KiroStreamCallback{
		OnText: func(text string, isThinking bool) {
			if isThinking {
				reasoningBuilder.WriteString(text)
				return
			}
			contentBuilder.WriteString(text)
		},
		OnToolUse: func(tu KiroToolUse) { toolUses = append(toolUses, tu) },
		OnComplete: func(inTok, outTok int) {
			inputTokens = inTok
			outputTokens = outTok
		},
		OnError: func(err error) {
			h.pool.RecordError(account.ID, strings.Contains(err.Error(), "429"))
		},
		OnCredits: func(c float64) { credits = c },
		OnContextUsage: func(pct float64) {
			// Convert the reported context-usage percentage into a token count.
			realInputTokens = int(pct * float64(getContextWindowSize(model)) / 100.0)
		},
	}

	if err := CallKiroAPI(account, payload, callback); err != nil {
		h.recordFailure()
		h.pool.RecordError(account.ID, strings.Contains(err.Error(), "429"))
		h.sendOpenAIError(w, 500, "server_error", err.Error())
		return
	}

	// Separate thinking-tag content from the text. Reasoning delivered via
	// dedicated events wins; the extracted reasoning is only a fallback.
	finalContent, extractedReasoning := extractThinkingFromContent(contentBuilder.String())
	reasoningContent := reasoningBuilder.String()
	if thinking && reasoningContent == "" && extractedReasoning != "" {
		reasoningContent = extractedReasoning
	} else if !thinking {
		reasoningContent = ""
	}

	// Input tokens: prefer the context-usage-derived figure, then the count
	// reported on completion, then the caller's estimate.
	if realInputTokens > 0 {
		inputTokens = realInputTokens
	} else if inputTokens <= 0 {
		inputTokens = estimatedInputTokens
	}
	// Output tokens are always re-estimated; the value reported on completion
	// is overwritten here.
	outputTokens = estimateOpenAIOutputTokens(finalContent, reasoningContent, toolUses)

	h.recordSuccess(inputTokens, outputTokens, credits)
	h.pool.RecordSuccess(account.ID)
	h.pool.UpdateStats(account.ID, inputTokens+outputTokens, credits)

	thinkingFormat := config.GetThinkingConfig().OpenAIFormat
	resp := KiroToOpenAIResponseWithReasoning(finalContent, reasoningContent, toolUses, inputTokens, outputTokens, model, thinkingFormat)

	w.Header().Set("Content-Type", "application/json; charset=utf-8")
	json.NewEncoder(w).Encode(resp)
}
// sendOpenAIError writes an OpenAI-style JSON error envelope
// ({"error":{"type":...,"message":...}}) with the given HTTP status.
func (h *Handler) sendOpenAIError(w http.ResponseWriter, status int, errType, message string) {
	detail := map[string]interface{}{
		"type":    errType,
		"message": message,
	}
	envelope := map[string]interface{}{"error": detail}

	w.Header().Set("Content-Type", "application/json; charset=utf-8")
	w.WriteHeader(status)
	json.NewEncoder(w).Encode(envelope)
}
// ensureValidToken refreshes the account's token when it is within 300 seconds
// of expiry (or already expired). An ExpiresAt of 0 is treated as "no refresh
// needed". After a successful refresh the new token is stored in the pool, on
// the passed-in account and in the persisted config. A refresh error is
// returned unchanged.
func (h *Handler) ensureValidToken(account *config.Account) error {
	const refreshLeadSeconds = 300

	stillFresh := account.ExpiresAt == 0 || time.Now().Unix() < account.ExpiresAt-refreshLeadSeconds
	if stillFresh {
		return nil
	}

	newAccess, newRefresh, newExpiry, refreshErr := auth.RefreshToken(account)
	if refreshErr != nil {
		return refreshErr
	}

	// Update the in-memory state: pool first, then the caller's account.
	h.pool.UpdateToken(account.ID, newAccess, newRefresh, newExpiry)
	account.AccessToken = newAccess
	// An empty refresh token in the reply keeps the existing one.
	if newRefresh != "" {
		account.RefreshToken = newRefresh
	}
	account.ExpiresAt = newExpiry

	// Persist.
	config.UpdateAccountToken(account.ID, newAccess, newRefresh, newExpiry)
	return nil
}
// ==================== Admin API ====================

// handleAdminAPI authenticates and routes requests under /admin/api.
//
// The admin password is taken from the X-Admin-Password header, falling back
// to the admin_password cookie, and compared with config.GetPassword(). A
// mismatch yields 401 {"error":"Unauthorized"}. Otherwise the request is
// dispatched on the path (with the "/admin/api" prefix removed) and method;
// anything unmatched yields 404 {"error":"Not Found"}.
func (h *Handler) handleAdminAPI(w http.ResponseWriter, r *http.Request) {
	// Set the content type before anything is written so that the 401 body
	// below is labelled as JSON too; it used to be set only after the
	// password check had passed.
	w.Header().Set("Content-Type", "application/json; charset=utf-8")

	// Verify the password.
	password := r.Header.Get("X-Admin-Password")
	if password == "" {
		if cookie, err := r.Cookie("admin_password"); err == nil {
			password = cookie.Value
		}
	}
	if password != config.GetPassword() {
		w.WriteHeader(401)
		json.NewEncoder(w).Encode(map[string]string{"error": "Unauthorized"})
		return
	}

	path := strings.TrimPrefix(r.URL.Path, "/admin/api")

	// Case order matters: exact "/accounts/batch" and the suffix routes must
	// be tested before the generic "/accounts/" prefix routes.
	switch {
	case path == "/accounts" && r.Method == "GET":
		h.apiGetAccounts(w, r)
	case path == "/accounts" && r.Method == "POST":
		h.apiAddAccount(w, r)
	case path == "/accounts/batch" && r.Method == "POST":
		h.apiBatchAccounts(w, r)
	case strings.HasPrefix(path, "/accounts/") && strings.HasSuffix(path, "/refresh") && r.Method == "POST":
		id := strings.TrimSuffix(strings.TrimPrefix(path, "/accounts/"), "/refresh")
		h.apiRefreshAccount(w, r, id)
	case strings.HasPrefix(path, "/accounts/") && strings.HasSuffix(path, "/models") && r.Method == "GET":
		id := strings.TrimSuffix(strings.TrimPrefix(path, "/accounts/"), "/models")
		h.apiGetAccountModels(w, r, id)
	case strings.HasPrefix(path, "/accounts/") && strings.HasSuffix(path, "/full") && r.Method == "GET":
		id := strings.TrimSuffix(strings.TrimPrefix(path, "/accounts/"), "/full")
		h.apiGetAccountFull(w, r, id)
	case strings.HasPrefix(path, "/accounts/") && r.Method == "DELETE":
		h.apiDeleteAccount(w, r, strings.TrimPrefix(path, "/accounts/"))
	case strings.HasPrefix(path, "/accounts/") && r.Method == "PUT":
		h.apiUpdateAccount(w, r, strings.TrimPrefix(path, "/accounts/"))
	case path == "/auth/iam-sso/start" && r.Method == "POST":
		h.apiStartIamSso(w, r)
	case path == "/auth/iam-sso/complete" && r.Method == "POST":
		h.apiCompleteIamSso(w, r)
	case path == "/auth/builderid/start" && r.Method == "POST":
		h.apiStartBuilderIdLogin(w, r)
	case path == "/auth/builderid/poll" && r.Method == "POST":
		h.apiPollBuilderIdAuth(w, r)
	case path == "/auth/sso-token" && r.Method == "POST":
		h.apiImportSsoToken(w, r)
	case path == "/auth/credentials" && r.Method == "POST":
		h.apiImportCredentials(w, r)
	case path == "/status" && r.Method == "GET":
		h.apiGetStatus(w, r)
	case path == "/settings" && r.Method == "GET":
		h.apiGetSettings(w, r)
	case path == "/settings" && r.Method == "POST":
		h.apiUpdateSettings(w, r)
	case path == "/stats" && r.Method == "GET":
		h.apiGetStats(w, r)
	case path == "/stats/reset" && r.Method == "POST":
		h.apiResetStats(w, r)
	case path == "/generate-machine-id" && r.Method == "GET":
		h.apiGenerateMachineId(w, r)
	case path == "/thinking" && r.Method == "GET":
		h.apiGetThinkingConfig(w, r)
	case path == "/thinking" && r.Method == "POST":
		h.apiUpdateThinkingConfig(w, r)
	case path == "/endpoint" && r.Method == "GET":
		h.apiGetEndpointConfig(w, r)
	case path == "/endpoint" && r.Method == "POST":
		h.apiUpdateEndpointConfig(w, r)
	case path == "/proxy" && r.Method == "GET":
		h.apiGetProxy(w, r)
	case path == "/proxy" && r.Method == "POST":
		h.apiUpdateProxy(w, r)
	case path == "/general" && r.Method == "GET":
		h.apiGetGeneralConfig(w, r)
	case path == "/general" && r.Method == "POST":
		h.apiUpdateGeneralConfig(w, r)
	case path == "/version" && r.Method == "GET":
		h.apiGetVersion(w, r)
	case path == "/export" && r.Method == "POST":
		h.apiExportAccounts(w, r)
	default:
		w.WriteHeader(404)
		json.NewEncoder(w).Encode(map[string]string{"error": "Not Found"})
	}
}
// apiGetAccounts writes the configured accounts as a JSON array, merging in
// the per-account runtime counters held by the pool. Token values are not
// included; only a "hasToken" flag is reported.
func (h *Handler) apiGetAccounts(w http.ResponseWriter, r *http.Request) {
	configured := config.GetAccounts()

	// Index the pool's copies by ID so runtime counters can be looked up.
	runtimeByID := make(map[string]config.Account)
	for _, live := range h.pool.GetAllAccounts() {
		runtimeByID[live.ID] = live
	}

	// Build the sanitized view.
	view := make([]map[string]interface{}, 0, len(configured))
	for _, acct := range configured {
		// Zero value when the pool has no entry for this account.
		counters := runtimeByID[acct.ID]
		entry := map[string]interface{}{
			"id":                acct.ID,
			"email":             acct.Email,
			"userId":            acct.UserId,
			"nickname":          acct.Nickname,
			"authMethod":        acct.AuthMethod,
			"provider":          acct.Provider,
			"region":            acct.Region,
			"enabled":           acct.Enabled,
			"banStatus":         acct.BanStatus,
			"banReason":         acct.BanReason,
			"banTime":           acct.BanTime,
			"expiresAt":         acct.ExpiresAt,
			"hasToken":          acct.AccessToken != "",
			"machineId":         acct.MachineId,
			"weight":            acct.Weight,
			"subscriptionType":  acct.SubscriptionType,
			"subscriptionTitle": acct.SubscriptionTitle,
			"daysRemaining":     acct.DaysRemaining,
			"usageCurrent":      acct.UsageCurrent,
			"usageLimit":        acct.UsageLimit,
			"usagePercent":      acct.UsagePercent,
			"nextResetDate":     acct.NextResetDate,
			"lastRefresh":       acct.LastRefresh,
			"trialUsageCurrent": acct.TrialUsageCurrent,
			"trialUsageLimit":   acct.TrialUsageLimit,
			"trialUsagePercent": acct.TrialUsagePercent,
			"trialStatus":       acct.TrialStatus,
			"trialExpiresAt":    acct.TrialExpiresAt,
			"requestCount":      counters.RequestCount,
			"errorCount":        counters.ErrorCount,
			"totalTokens":       counters.TotalTokens,
			"totalCredits":      counters.TotalCredits,
			"lastUsed":          counters.LastUsed,
		}
		view = append(view, entry)
	}

	json.NewEncoder(w).Encode(view)
}
// apiAddAccount decodes an account from the request body, fills in a generated
// ID and the "us-east-1" region when they are missing, stores it via
// config.AddAccount and reloads the pool. It replies 400 on undecodable JSON
// and 500 when the account cannot be added.
func (h *Handler) apiAddAccount(w http.ResponseWriter, r *http.Request) {
	fail := func(status int, msg string) {
		w.WriteHeader(status)
		json.NewEncoder(w).Encode(map[string]string{"error": msg})
	}

	var account config.Account
	if json.NewDecoder(r.Body).Decode(&account) != nil {
		fail(400, "Invalid JSON")
		return
	}

	if account.ID == "" {
		account.ID = auth.GenerateAccountID()
	}
	if account.Region == "" {
		account.Region = "us-east-1"
	}

	if addErr := config.AddAccount(account); addErr != nil {
		fail(500, addErr.Error())
		return
	}

	h.pool.Reload()
	json.NewEncoder(w).Encode(map[string]interface{}{"success": true, "id": account.ID})
}
// apiDeleteAccount removes the account with the given id via
// config.DeleteAccount and reloads the pool. A deletion error is reported as
// 500 with the error text.
func (h *Handler) apiDeleteAccount(w http.ResponseWriter, r *http.Request, id string) {
	out := json.NewEncoder(w)

	delErr := config.DeleteAccount(id)
	if delErr != nil {
		w.WriteHeader(500)
		out.Encode(map[string]string{"error": delErr.Error()})
		return
	}

	h.pool.Reload()
	out.Encode(map[string]bool{"success": true})
}
// apiUpdateAccount applies a partial update to the account with the given id.
//
// The body is a JSON object; only the keys "enabled" (bool), "nickname"
// (string), "machineId" (string) and "weight" (number) are honoured, and only
// when the value has the expected JSON type. It replies 400 on undecodable
// JSON, 404 when the account does not exist and 500 when saving fails.
func (h *Handler) apiUpdateAccount(w http.ResponseWriter, r *http.Request, id string) {
	fail := func(status int, msg string) {
		w.WriteHeader(status)
		json.NewEncoder(w).Encode(map[string]string{"error": msg})
	}

	var updates map[string]interface{}
	if json.NewDecoder(r.Body).Decode(&updates) != nil {
		fail(400, "Invalid JSON")
		return
	}

	// Locate the existing account.
	accounts := config.GetAccounts()
	var target *config.Account
	for i := range accounts {
		if accounts[i].ID != id {
			continue
		}
		target = &accounts[i]
		break
	}
	if target == nil {
		fail(404, "Account not found")
		return
	}

	// Apply only the fields that were supplied with the right type.
	for key, raw := range updates {
		switch key {
		case "enabled":
			if v, ok := raw.(bool); ok {
				target.Enabled = v
			}
		case "nickname":
			if v, ok := raw.(string); ok {
				target.Nickname = v
			}
		case "machineId":
			if v, ok := raw.(string); ok {
				target.MachineId = v
			}
		case "weight":
			// JSON numbers decode as float64.
			if v, ok := raw.(float64); ok {
				target.Weight = int(v)
			}
		}
	}

	if saveErr := config.UpdateAccount(id, *target); saveErr != nil {
		fail(500, saveErr.Error())
		return
	}

	h.pool.Reload()
	json.NewEncoder(w).Encode(map[string]bool{"success": true})
}
// apiBatchAccounts performs a bulk operation (enable / disable / refresh) on
// the accounts listed in the request body's "ids".
//
// "enable" and "disable" set the Enabled flag; enabling additionally resets a
// non-empty, non-"ACTIVE" ban status to "ACTIVE" and clears the ban reason and
// time. The reply's "count" is the number of IDs submitted. "refresh" refreshes
// the token (when a refresh token exists) and the account info of each ID and
// reports how many succeeded and failed. Any other action yields 400.
func (h *Handler) apiBatchAccounts(w http.ResponseWriter, r *http.Request) {
	badRequest := func(msg string) {
		w.WriteHeader(400)
		json.NewEncoder(w).Encode(map[string]string{"error": msg})
	}

	var req struct {
		IDs    []string `json:"ids"`
		Action string   `json:"action"` // "enable", "disable", "refresh"
	}
	if json.NewDecoder(r.Body).Decode(&req) != nil {
		badRequest("Invalid JSON")
		return
	}
	if len(req.IDs) == 0 {
		badRequest("No account IDs provided")
		return
	}

	switch req.Action {
	case "enable", "disable":
		enable := req.Action == "enable"
		all := config.GetAccounts()

		selected := make(map[string]struct{})
		for _, id := range req.IDs {
			selected[id] = struct{}{}
		}

		for _, acct := range all {
			if _, wanted := selected[acct.ID]; !wanted {
				continue
			}
			acct.Enabled = enable
			banned := acct.BanStatus != "" && acct.BanStatus != "ACTIVE"
			if enable && banned {
				acct.BanStatus = "ACTIVE"
				acct.BanReason = ""
				acct.BanTime = 0
			}
			config.UpdateAccount(acct.ID, acct)
		}

		h.pool.Reload()
		json.NewEncoder(w).Encode(map[string]interface{}{"success": true, "count": len(req.IDs)})

	case "refresh":
		var refreshed, failed int
		for _, id := range req.IDs {
			// The account list is re-read on every iteration.
			all := config.GetAccounts()
			var target *config.Account
			for i := range all {
				if all[i].ID == id {
					target = &all[i]
					break
				}
			}
			if target == nil {
				failed++
				continue
			}

			// Refresh the token; a failure here is not counted and the info
			// refresh below is still attempted.
			if target.RefreshToken != "" {
				newAccess, newRefresh, newExpires, tokenErr := auth.RefreshToken(target)
				if tokenErr == nil {
					target.AccessToken = newAccess
					if newRefresh != "" {
						target.RefreshToken = newRefresh
					}
					target.ExpiresAt = newExpires
					config.UpdateAccountToken(id, newAccess, newRefresh, newExpires)
					h.pool.UpdateToken(id, newAccess, newRefresh, newExpires)
				}
			}

			// Refresh the account info.
			info, infoErr := RefreshAccountInfo(target)
			if infoErr != nil {
				failed++
				continue
			}
			config.UpdateAccountInfo(id, *info)
			refreshed++
		}

		h.pool.Reload()
		json.NewEncoder(w).Encode(map[string]interface{}{
			"success":   true,
			"refreshed": refreshed,
			"failed":    failed,
		})

	default:
		badRequest("Invalid action: " + req.Action)
	}
}
// apiStartIamSso starts an IAM SSO login for the given startUrl and region and
// replies with the session ID, the authorize URL and the expiry returned by
// auth.StartIamSsoLogin. It replies 400 on undecodable JSON or a missing
// startUrl and 500 when the login cannot be started.
func (h *Handler) apiStartIamSso(w http.ResponseWriter, r *http.Request) {
	fail := func(status int, msg string) {
		w.WriteHeader(status)
		json.NewEncoder(w).Encode(map[string]string{"error": msg})
	}

	var req struct {
		StartUrl string `json:"startUrl"`
		Region   string `json:"region"`
	}
	switch decodeErr := json.NewDecoder(r.Body).Decode(&req); {
	case decodeErr != nil:
		fail(400, "Invalid JSON")
		return
	case req.StartUrl == "":
		fail(400, "startUrl is required")
		return
	}

	sessionID, authorizeUrl, expiresIn, startErr := auth.StartIamSsoLogin(req.StartUrl, req.Region)
	if startErr != nil {
		fail(500, startErr.Error())
		return
	}

	reply := map[string]interface{}{
		"sessionId":    sessionID,
		"authorizeUrl": authorizeUrl,
		"expiresIn":    expiresIn,
	}
	json.NewEncoder(w).Encode(reply)
}
// apiCompleteIamSso finishes an IAM SSO login from the session ID and callback
// URL, then stores the resulting credentials as a new enabled "idc" account
// with a generated ID and machine ID and reloads the pool. It replies 400 on
// undecodable JSON or a failed login and 500 when the account cannot be added.
func (h *Handler) apiCompleteIamSso(w http.ResponseWriter, r *http.Request) {
	fail := func(status int, msg string) {
		w.WriteHeader(status)
		json.NewEncoder(w).Encode(map[string]string{"error": msg})
	}

	var req struct {
		SessionID   string `json:"sessionId"`
		CallbackUrl string `json:"callbackUrl"`
	}
	if json.NewDecoder(r.Body).Decode(&req) != nil {
		fail(400, "Invalid JSON")
		return
	}

	accessToken, refreshToken, clientID, clientSecret, region, expiresIn, loginErr := auth.CompleteIamSsoLogin(req.SessionID, req.CallbackUrl)
	if loginErr != nil {
		fail(400, loginErr.Error())
		return
	}

	// Look up the user's email; the other results and the error are ignored.
	email, _, _ := auth.GetUserInfo(accessToken)

	// Create the account.
	var account config.Account
	account.ID = auth.GenerateAccountID()
	account.Email = email
	account.AccessToken = accessToken
	account.RefreshToken = refreshToken
	account.ClientID = clientID
	account.ClientSecret = clientSecret
	account.AuthMethod = "idc"
	account.Region = region
	account.ExpiresAt = time.Now().Unix() + int64(expiresIn)
	account.Enabled = true
	account.MachineId = config.GenerateMachineId()

	if addErr := config.AddAccount(account); addErr != nil {
		fail(500, addErr.Error())
		return
	}

	h.pool.Reload()
	summary := map[string]interface{}{
		"id":    account.ID,
		"email": account.Email,
	}
	json.NewEncoder(w).Encode(map[string]interface{}{
		"success": true,
		"account": summary,
	})
}
// apiStartBuilderIdLogin starts a Builder ID login for the region given in the
// request body and replies with the session ID, user code, verification URI
// and polling interval. A failure to start the login is reported as 500.
func (h *Handler) apiStartBuilderIdLogin(w http.ResponseWriter, r *http.Request) {
	var req struct {
		Region string `json:"region"`
	}
	// The decode error is deliberately ignored; req keeps whatever was decoded.
	_ = json.NewDecoder(r.Body).Decode(&req)

	out := json.NewEncoder(w)
	session, startErr := auth.StartBuilderIdLogin(req.Region)
	if startErr != nil {
		w.WriteHeader(500)
		out.Encode(map[string]string{"error": startErr.Error()})
		return
	}

	reply := map[string]interface{}{
		"sessionId":       session.ID,
		"userCode":        session.UserCode,
		"verificationUri": session.VerificationUri,
		"interval":        session.Interval,
	}
	out.Encode(reply)
}
// apiPollBuilderIdAuth polls a pending Builder ID authorization.
//
// While the status is "pending" or "slow_down" it replies with
// completed=false plus the polling interval (the session's interval, or 5 when
// the session is not found). Once authorization has completed it stores the
// credentials as a new enabled "idc" / "BuilderId" account, reloads the pool
// and replies with completed=true. It replies 400 on undecodable JSON or a
// polling error and 500 when the account cannot be added.
func (h *Handler) apiPollBuilderIdAuth(w http.ResponseWriter, r *http.Request) {
	out := json.NewEncoder(w)

	var req struct {
		SessionID string `json:"sessionId"`
	}
	if json.NewDecoder(r.Body).Decode(&req) != nil {
		w.WriteHeader(400)
		out.Encode(map[string]string{"error": "Invalid JSON"})
		return
	}

	accessToken, refreshToken, clientID, clientSecret, region, expiresIn, status, pollErr := auth.PollBuilderIdAuth(req.SessionID)
	if pollErr != nil {
		w.WriteHeader(400)
		out.Encode(map[string]interface{}{
			"success": false,
			"error":   pollErr.Error(),
		})
		return
	}

	switch status {
	case "pending", "slow_down":
		// Report the current polling interval.
		interval := 5
		session := auth.GetBuilderIdSession(req.SessionID)
		if session != nil {
			interval = session.Interval
		}
		out.Encode(map[string]interface{}{
			"success":   true,
			"completed": false,
			"status":    status,
			"interval":  interval,
		})
		return
	}

	// Authorization finished: look up the user's email (other results and the
	// error are ignored).
	email, _, _ := auth.GetUserInfo(accessToken)

	// Create the account.
	var account config.Account
	account.ID = auth.GenerateAccountID()
	account.Email = email
	account.AccessToken = accessToken
	account.RefreshToken = refreshToken
	account.ClientID = clientID
	account.ClientSecret = clientSecret
	account.AuthMethod = "idc"
	account.Provider = "BuilderId"
	account.Region = region
	account.ExpiresAt = time.Now().Unix() + int64(expiresIn)
	account.Enabled = true
	account.MachineId = config.GenerateMachineId()

	if addErr := config.AddAccount(account); addErr != nil {
		w.WriteHeader(500)
		out.Encode(map[string]string{"error": addErr.Error()})
		return
	}

	h.pool.Reload()
	out.Encode(map[string]interface{}{
		"success":   true,
		"completed": true,
		"account": map[string]interface{}{
			"id":    account.ID,
			"email": account.Email,
		},
	})
}
// apiImportSsoToken imports one or more SSO bearer tokens (one per line in
// "bearerToken") as new enabled "idc" accounts.
//
// Each non-blank line is exchanged via auth.ImportFromSsoToken and stored with
// a generated ID and machine ID; per-token failures are collected instead of
// aborting the batch. The pool is reloaded afterwards. When nothing was
// imported and at least one error occurred the reply is 500 with the errors
// joined by "; "; otherwise the imported accounts and the collected errors are
// returned. It replies 400 on undecodable JSON or an empty bearerToken.
func (h *Handler) apiImportSsoToken(w http.ResponseWriter, r *http.Request) {
	out := json.NewEncoder(w)
	badRequest := func(msg string) {
		w.WriteHeader(400)
		out.Encode(map[string]string{"error": msg})
	}

	var req struct {
		BearerToken string `json:"bearerToken"`
		Region      string `json:"region"`
	}
	if json.NewDecoder(r.Body).Decode(&req) != nil {
		badRequest("Invalid JSON")
		return
	}
	if req.BearerToken == "" {
		badRequest("bearerToken is required")
		return
	}

	// Batch import: one token per line.
	var imported []map[string]interface{}
	var failures []string
	for _, line := range strings.Split(strings.TrimSpace(req.BearerToken), "\n") {
		token := strings.TrimSpace(line)
		if token == "" {
			continue
		}

		accessToken, refreshToken, clientID, clientSecret, expiresIn, importErr := auth.ImportFromSsoToken(token, req.Region)
		if importErr != nil {
			failures = append(failures, importErr.Error())
			continue
		}

		// Look up the user's email; other results and the error are ignored.
		email, _, _ := auth.GetUserInfo(accessToken)

		// Create the account.
		var account config.Account
		account.ID = auth.GenerateAccountID()
		account.Email = email
		account.AccessToken = accessToken
		account.RefreshToken = refreshToken
		account.ClientID = clientID
		account.ClientSecret = clientSecret
		account.AuthMethod = "idc"
		account.Region = req.Region
		account.ExpiresAt = time.Now().Unix() + int64(expiresIn)
		account.Enabled = true
		account.MachineId = config.GenerateMachineId()

		if addErr := config.AddAccount(account); addErr != nil {
			failures = append(failures, addErr.Error())
			continue
		}
		imported = append(imported, map[string]interface{}{
			"id":    account.ID,
			"email": account.Email,
		})
	}

	h.pool.Reload()

	nothingImported := len(imported) == 0 && len(failures) > 0
	if nothingImported {
		w.WriteHeader(500)
		out.Encode(map[string]interface{}{
			"success": false,
			"error":   strings.Join(failures, "; "),
		})
		return
	}
	out.Encode(map[string]interface{}{
		"success":  true,
		"accounts": imported,
		"errors":   failures,
	})
}
// apiImportCredentials imports an account from raw credentials.
//
// refreshToken is mandatory. The region defaults to "us-east-1". The auth
// method is normalized to "idc" or "social": known aliases map directly, and
// an empty or unknown value is inferred from the presence of client
// credentials. The refresh token is always exchanged for a fresh access token;
// if that fails, a supplied accessToken is used with a 300-second expiry, and
// without one the request fails with 400. The account is stored enabled with a
// generated ID and machine ID and the pool is reloaded. It replies 400 on
// undecodable JSON and 500 when the account cannot be added.
func (h *Handler) apiImportCredentials(w http.ResponseWriter, r *http.Request) {
	fail := func(status int, msg string) {
		w.WriteHeader(status)
		json.NewEncoder(w).Encode(map[string]string{"error": msg})
	}

	var req struct {
		AccessToken  string `json:"accessToken"`
		RefreshToken string `json:"refreshToken"`
		ClientID     string `json:"clientId"`
		ClientSecret string `json:"clientSecret"`
		AuthMethod   string `json:"authMethod"`
		Provider     string `json:"provider"`
		Region       string `json:"region"`
	}
	if json.NewDecoder(r.Body).Decode(&req) != nil {
		fail(400, "Invalid JSON")
		return
	}
	if req.RefreshToken == "" {
		fail(400, "refreshToken is required")
		return
	}

	// Defaults.
	if req.Region == "" {
		req.Region = "us-east-1"
	}
	if req.AuthMethod == "" {
		req.AuthMethod = "social"
		if req.ClientID != "" {
			req.AuthMethod = "idc"
		}
	}

	// Normalize authMethod.
	switch strings.ToLower(req.AuthMethod) {
	case "idc", "builderid", "enterprise":
		req.AuthMethod = "idc"
	case "social", "google", "github":
		req.AuthMethod = "social"
	default:
		// Unknown value: infer from whether both client credentials are set.
		req.AuthMethod = "social"
		if req.ClientID != "" && req.ClientSecret != "" {
			req.AuthMethod = "idc"
		}
	}

	// Always try to obtain a fresh access token from the refresh token.
	probe := &config.Account{
		RefreshToken: req.RefreshToken,
		ClientID:     req.ClientID,
		ClientSecret: req.ClientSecret,
		AuthMethod:   req.AuthMethod,
		Region:       req.Region,
	}
	newAccessToken, newRefreshToken, newExpiresAt, refreshErr := auth.RefreshToken(probe)

	var accessToken string
	var expiresAt int64
	switch {
	case refreshErr == nil:
		accessToken = newAccessToken
		if newRefreshToken != "" {
			req.RefreshToken = newRefreshToken
		}
		expiresAt = newExpiresAt
	case req.AccessToken != "":
		// Refresh failed: fall back to the supplied access token. It may
		// already be stale, so give it a short lifetime.
		accessToken = req.AccessToken
		expiresAt = time.Now().Unix() + 300
	default:
		fail(400, "Token refresh failed: "+refreshErr.Error())
		return
	}

	// Look up the user's email; other results and the error are ignored.
	email, _, _ := auth.GetUserInfo(accessToken)

	// Create the account.
	var account config.Account
	account.ID = auth.GenerateAccountID()
	account.Email = email
	account.AccessToken = accessToken
	account.RefreshToken = req.RefreshToken
	account.ClientID = req.ClientID
	account.ClientSecret = req.ClientSecret
	account.AuthMethod = req.AuthMethod
	account.Provider = req.Provider
	account.Region = req.Region
	account.ExpiresAt = expiresAt
	account.Enabled = true
	account.MachineId = config.GenerateMachineId()

	if addErr := config.AddAccount(account); addErr != nil {
		fail(500, addErr.Error())
		return
	}

	h.pool.Reload()
	json.NewEncoder(w).Encode(map[string]interface{}{
		"success": true,
		"account": map[string]interface{}{
			"id":    account.ID,
			"email": account.Email,
		},
	})
}
// apiGetStatus writes a JSON summary of the account pool and the runtime
// counters: total and available account counts, request/token/credit totals
// and uptime in seconds.
//
// The integer counters are updated elsewhere through sync/atomic and
// totalCredits is guarded by creditsMu, so they are read here with atomic
// loads and through h.getCredits() (the same accessors apiGetStats uses)
// rather than as plain field reads, which would race with concurrent writers.
func (h *Handler) apiGetStatus(w http.ResponseWriter, r *http.Request) {
	json.NewEncoder(w).Encode(map[string]interface{}{
		"accounts":        h.pool.Count(),
		"available":       h.pool.AvailableCount(),
		"totalRequests":   atomic.LoadInt64(&h.totalRequests),
		"successRequests": atomic.LoadInt64(&h.successRequests),
		"failedRequests":  atomic.LoadInt64(&h.failedRequests),
		"totalTokens":     atomic.LoadInt64(&h.totalTokens),
		"totalCredits":    h.getCredits(),
		"uptime":          time.Now().Unix() - h.startTime,
	})
}
// apiGetSettings writes the current settings as a JSON object with the keys
// apiKey, requireApiKey, port and host, each read from the config package.
func (h *Handler) apiGetSettings(w http.ResponseWriter, r *http.Request) {
	settings := make(map[string]interface{}, 4)
	settings["apiKey"] = config.GetApiKey()
	settings["requireApiKey"] = config.IsApiKeyRequired()
	settings["port"] = config.GetPort()
	settings["host"] = config.GetHost()

	json.NewEncoder(w).Encode(settings)
}
// apiUpdateSettings decodes {apiKey, requireApiKey, password} from the request
// body and hands the values to config.UpdateSettings.
//
// Responses: 400 with {"error": "Invalid JSON"} when the body cannot be
// decoded, 500 with the error text when the update fails, otherwise
// {"success": true}.
func (h *Handler) apiUpdateSettings(w http.ResponseWriter, r *http.Request) {
	type settingsPayload struct {
		ApiKey        string `json:"apiKey"`
		RequireApiKey bool   `json:"requireApiKey"`
		Password      string `json:"password"`
	}

	reply := json.NewEncoder(w)

	var payload settingsPayload
	if decodeErr := json.NewDecoder(r.Body).Decode(&payload); decodeErr != nil {
		w.WriteHeader(400)
		reply.Encode(map[string]string{"error": "Invalid JSON"})
		return
	}

	updateErr := config.UpdateSettings(payload.ApiKey, payload.RequireApiKey, payload.Password)
	if updateErr != nil {
		w.WriteHeader(500)
		reply.Encode(map[string]string{"error": updateErr.Error()})
		return
	}

	reply.Encode(map[string]bool{"success": true})
}
// apiGetStats writes the runtime counters as JSON. The integer counters are
// read with atomic loads, the credit total through h.getCredits(), and uptime
// is the number of seconds elapsed since h.startTime.
func (h *Handler) apiGetStats(w http.ResponseWriter, r *http.Request) {
	requests := atomic.LoadInt64(&h.totalRequests)
	succeeded := atomic.LoadInt64(&h.successRequests)
	failed := atomic.LoadInt64(&h.failedRequests)
	tokens := atomic.LoadInt64(&h.totalTokens)
	credits := h.getCredits()
	uptime := time.Now().Unix() - h.startTime

	snapshot := map[string]interface{}{
		"totalRequests":   requests,
		"successRequests": succeeded,
		"failedRequests":  failed,
		"totalTokens":     tokens,
		"totalCredits":    credits,
		"uptime":          uptime,
	}
	json.NewEncoder(w).Encode(snapshot)
}
// apiResetStats zeroes every in-memory counter (the atomic integers and the
// mutex-guarded credit total), passes the zeroed values to config.UpdateStats,
// and replies {"success": true}.
func (h *Handler) apiResetStats(w http.ResponseWriter, r *http.Request) {
	counters := []*int64{
		&h.totalRequests,
		&h.successRequests,
		&h.failedRequests,
		&h.totalTokens,
	}
	for _, counter := range counters {
		atomic.StoreInt64(counter, 0)
	}

	h.creditsMu.Lock()
	h.totalCredits = 0
	h.creditsMu.Unlock()

	config.UpdateStats(0, 0, 0, 0, 0)

	ok := map[string]bool{"success": true}
	json.NewEncoder(w).Encode(ok)
}
// apiGenerateMachineId generates a new machine ID via config.GenerateMachineId
// and returns it as {"machineId": "..."}.
func (h *Handler) apiGenerateMachineId(w http.ResponseWriter, r *http.Request) {
	payload := map[string]string{
		"machineId": config.GenerateMachineId(),
	}
	json.NewEncoder(w).Encode(payload)
}
// apiRefreshAccount refreshes the stored account information (usage,
// subscription, etc.) for the account identified by id.
//
// Flow:
//  1. Look the account up in config.GetAccounts(); reply 404 when it is absent.
//  2. When ExpiresAt is set and the token has expired or expires within 300
//     seconds, refresh the token first; a failure is reported as 500.
//  3. Call RefreshAccountInfo. An error mentioning a suspension is answered
//     with a success payload. An error that looks like an auth failure
//     (contains "403", "401", "invalid" or "expired") triggers one token
//     refresh and, if that refresh reports no error, one retry.
//  4. Persist the result with config.UpdateAccountInfo and return it.
func (h *Handler) apiRefreshAccount(w http.ResponseWriter, r *http.Request, id string) {
	// fail writes an HTTP status followed by a {"error": msg} body.
	fail := func(status int, msg string) {
		w.WriteHeader(status)
		json.NewEncoder(w).Encode(map[string]string{"error": msg})
	}

	accounts := config.GetAccounts()
	var account *config.Account
	for idx := range accounts {
		if accounts[idx].ID != id {
			continue
		}
		account = &accounts[idx]
		break
	}
	if account == nil {
		fail(404, "Account not found")
		return
	}

	// renewToken exchanges the refresh token for new credentials and stores
	// them on the local copy, in the config, and in the pool. It is a no-op
	// (returning nil) when the account has no refresh token.
	renewToken := func() error {
		if account.RefreshToken == "" {
			return nil
		}
		access, refresh, expires, err := auth.RefreshToken(account)
		if err != nil {
			return err
		}
		account.AccessToken = access
		if refresh != "" {
			account.RefreshToken = refresh
		}
		account.ExpiresAt = expires
		config.UpdateAccountToken(id, access, refresh, expires)
		h.pool.UpdateToken(id, access, refresh, expires)
		return nil
	}

	// containsAny reports whether msg contains at least one of the needles.
	containsAny := func(msg string, needles ...string) bool {
		for _, needle := range needles {
			if strings.Contains(msg, needle) {
				return true
			}
		}
		return false
	}
	isSuspended := func(msg string) bool {
		return containsAny(msg, "TEMPORARILY_SUSPENDED", "Account suspended")
	}
	// The original notes that the ban state is already recorded inside
	// RefreshAccountInfo, so a suspension is reported to the client as success.
	replyStatusUpdated := func() {
		json.NewEncoder(w).Encode(map[string]interface{}{
			"success": true,
			"message": "Account status updated",
		})
	}

	// Refresh up front when the token is expired or about to expire.
	if account.ExpiresAt > 0 && time.Now().Unix() > account.ExpiresAt-300 {
		if err := renewToken(); err != nil {
			fail(500, "Token refresh failed: "+err.Error())
			return
		}
	}

	info, err := RefreshAccountInfo(account)
	if err != nil {
		firstMsg := err.Error()
		if isSuspended(firstMsg) {
			replyStatusUpdated()
			return
		}

		// A 403/401-style error suggests an invalid token: refresh, then retry once.
		if containsAny(firstMsg, "403", "401", "invalid", "expired") && renewToken() == nil {
			info, err = RefreshAccountInfo(account)
			if err != nil && isSuspended(err.Error()) {
				replyStatusUpdated()
				return
			}
		}

		// Anything still failing at this point is surfaced to the client.
		if err != nil {
			fail(500, err.Error())
			return
		}
	}

	// Save the refreshed information to the config.
	if err := config.UpdateAccountInfo(id, *info); err != nil {
		fail(500, err.Error())
		return
	}

	json.NewEncoder(w).Encode(map[string]interface{}{
		"success": true,
		"info":    info,
	})
}
// apiGetAccountFull returns the complete record of a single account,
// including sensitive fields (access/refresh tokens, client secret).
//
// Static fields come from config.GetAccounts(); the runtime statistics
// (requestCount, errorCount, totalTokens, totalCredits, lastUsed) come from
// the matching entry of h.pool.GetAllAccounts() and are zero values when the
// pool has no entry with this id. Replies 404 when the id is not in the
// config.
func (h *Handler) apiGetAccountFull(w http.ResponseWriter, r *http.Request, id string) {
	stored := config.GetAccounts()
	pooled := h.pool.GetAllAccounts()

	// Locate the requested account in the persisted configuration.
	var account *config.Account
	for idx := range stored {
		if stored[idx].ID != id {
			continue
		}
		account = &stored[idx]
		break
	}
	if account == nil {
		w.WriteHeader(404)
		json.NewEncoder(w).Encode(map[string]string{"error": "Account not found"})
		return
	}

	// Pick up the runtime statistics from the pool's view of the account.
	var stats config.Account
	for idx := range pooled {
		if pooled[idx].ID == id {
			stats = pooled[idx]
			break
		}
	}

	full := map[string]interface{}{
		// Identity
		"id":       account.ID,
		"email":    account.Email,
		"userId":   account.UserId,
		"nickname": account.Nickname,

		// Credentials (sensitive)
		"accessToken":  account.AccessToken,
		"refreshToken": account.RefreshToken,
		"clientId":     account.ClientID,
		"clientSecret": account.ClientSecret,
		"authMethod":   account.AuthMethod,
		"provider":     account.Provider,
		"region":       account.Region,
		"expiresAt":    account.ExpiresAt,
		"machineId":    account.MachineId,

		// State
		"enabled":   account.Enabled,
		"banStatus": account.BanStatus,
		"banReason": account.BanReason,
		"banTime":   account.BanTime,

		// Subscription and usage
		"subscriptionType":  account.SubscriptionType,
		"subscriptionTitle": account.SubscriptionTitle,
		"daysRemaining":     account.DaysRemaining,
		"usageCurrent":      account.UsageCurrent,
		"usageLimit":        account.UsageLimit,
		"usagePercent":      account.UsagePercent,
		"nextResetDate":     account.NextResetDate,
		"lastRefresh":       account.LastRefresh,

		// Trial
		"trialUsageCurrent": account.TrialUsageCurrent,
		"trialUsageLimit":   account.TrialUsageLimit,
		"trialUsagePercent": account.TrialUsagePercent,
		"trialStatus":       account.TrialStatus,
		"trialExpiresAt":    account.TrialExpiresAt,

		// Runtime statistics from the pool
		"requestCount": stats.RequestCount,
		"errorCount":   stats.ErrorCount,
		"totalTokens":  stats.TotalTokens,
		"totalCredits": stats.TotalCredits,
		"lastUsed":     stats.LastUsed,
	}
	json.NewEncoder(w).Encode(full)
}
// apiGetAccountModels lists the models available to the account with the given
// id by calling ListAvailableModels.
//
// Responses: 404 when the account is not in the config, 500 with the error
// text when the listing fails, otherwise {"success": true, "models": [...]}.
func (h *Handler) apiGetAccountModels(w http.ResponseWriter, r *http.Request, id string) {
	fail := func(status int, msg string) {
		w.WriteHeader(status)
		json.NewEncoder(w).Encode(map[string]string{"error": msg})
	}

	stored := config.GetAccounts()
	var account *config.Account
	for idx := range stored {
		if stored[idx].ID != id {
			continue
		}
		account = &stored[idx]
		break
	}
	if account == nil {
		fail(404, "Account not found")
		return
	}

	models, listErr := ListAvailableModels(account)
	if listErr != nil {
		fail(500, listErr.Error())
		return
	}

	json.NewEncoder(w).Encode(map[string]interface{}{
		"success": true,
		"models":  models,
	})
}
// ==================== 静态文件服务 ====================
// serveAdminPage serves the admin page, web/index.html, relative to the
// process working directory.
func (h *Handler) serveAdminPage(w http.ResponseWriter, r *http.Request) {
	const adminIndex = "web/index.html"
	http.ServeFile(w, r, adminIndex)
}
// serveStaticFile maps a request path under /admin/ onto the web/ directory
// and serves the resulting file with http.ServeFile.
func (h *Handler) serveStaticFile(w http.ResponseWriter, r *http.Request) {
	http.ServeFile(w, r, "web/"+strings.TrimPrefix(r.URL.Path, "/admin/"))
}
// apiGetThinkingConfig returns the thinking configuration (suffix,
// openaiFormat, claudeFormat) as JSON.
func (h *Handler) apiGetThinkingConfig(w http.ResponseWriter, r *http.Request) {
	thinking := config.GetThinkingConfig()

	payload := make(map[string]interface{}, 3)
	payload["suffix"] = thinking.Suffix
	payload["openaiFormat"] = thinking.OpenAIFormat
	payload["claudeFormat"] = thinking.ClaudeFormat

	json.NewEncoder(w).Encode(payload)
}
// apiUpdateThinkingConfig updates the thinking configuration from a JSON body
// of the form {suffix, openaiFormat, claudeFormat}.
//
// A non-empty openaiFormat or claudeFormat must be one of "reasoning_content",
// "thinking" or "think"; otherwise the request is rejected with 400. Empty
// format values skip validation and are passed through unchanged. A failure in
// config.UpdateThinkingConfig is reported as 500.
func (h *Handler) apiUpdateThinkingConfig(w http.ResponseWriter, r *http.Request) {
	fail := func(status int, msg string) {
		w.WriteHeader(status)
		json.NewEncoder(w).Encode(map[string]string{"error": msg})
	}

	var body struct {
		Suffix       string `json:"suffix"`
		OpenAIFormat string `json:"openaiFormat"`
		ClaudeFormat string `json:"claudeFormat"`
	}
	if decodeErr := json.NewDecoder(r.Body).Decode(&body); decodeErr != nil {
		fail(400, "Invalid JSON")
		return
	}

	// Validate the formats.
	isKnownFormat := func(name string) bool {
		switch name {
		case "reasoning_content", "thinking", "think":
			return true
		}
		return false
	}
	checks := []struct {
		value   string
		message string
	}{
		{body.OpenAIFormat, "Invalid openaiFormat, must be: reasoning_content, thinking, or think"},
		{body.ClaudeFormat, "Invalid claudeFormat, must be: reasoning_content, thinking, or think"},
	}
	for _, check := range checks {
		if check.value != "" && !isKnownFormat(check.value) {
			fail(400, check.message)
			return
		}
	}

	updateErr := config.UpdateThinkingConfig(body.Suffix, body.OpenAIFormat, body.ClaudeFormat)
	if updateErr != nil {
		fail(500, updateErr.Error())
		return
	}
	json.NewEncoder(w).Encode(map[string]bool{"success": true})
}
// apiGetEndpointConfig returns the endpoint configuration as
// {"preferredEndpoint": "..."}.
func (h *Handler) apiGetEndpointConfig(w http.ResponseWriter, r *http.Request) {
	payload := map[string]string{}
	payload["preferredEndpoint"] = config.GetPreferredEndpoint()
	json.NewEncoder(w).Encode(payload)
}
// apiUpdateEndpointConfig updates the preferred endpoint from a JSON body of
// the form {"preferredEndpoint": "..."}.
//
// The value must be exactly "auto", "codewhisperer" or "amazonq" (an empty or
// missing value is rejected too). Responses: 400 for an undecodable body or an
// unknown endpoint, 500 when config.UpdatePreferredEndpoint fails, otherwise
// {"success": true}.
func (h *Handler) apiUpdateEndpointConfig(w http.ResponseWriter, r *http.Request) {
	fail := func(status int, msg string) {
		w.WriteHeader(status)
		json.NewEncoder(w).Encode(map[string]string{"error": msg})
	}

	var body struct {
		PreferredEndpoint string `json:"preferredEndpoint"`
	}
	if decodeErr := json.NewDecoder(r.Body).Decode(&body); decodeErr != nil {
		fail(400, "Invalid JSON")
		return
	}

	switch body.PreferredEndpoint {
	case "auto", "codewhisperer", "amazonq":
		// Accepted value; continue below.
	default:
		fail(400, "Invalid endpoint, must be: auto, codewhisperer, or amazonq")
		return
	}

	if updateErr := config.UpdatePreferredEndpoint(body.PreferredEndpoint); updateErr != nil {
		fail(500, updateErr.Error())
		return
	}
	json.NewEncoder(w).Encode(map[string]bool{"success": true})
}
// applyProxyConfig applies the proxy configuration to all outbound HTTP
// clients (Kiro API + auth module) by passing proxyURL to InitKiroHttpClient
// and then to auth.InitHttpClient.
//
// NOTE(review): apiUpdateProxy may pass an empty proxyURL; presumably that
// disables the proxy in both initialisers — confirm against their
// implementations.
func applyProxyConfig(proxyURL string) {
InitKiroHttpClient(proxyURL)
auth.InitHttpClient(proxyURL)
}
// apiGetProxy returns the currently configured proxy as {"proxyURL": "..."}.
func (h *Handler) apiGetProxy(w http.ResponseWriter, r *http.Request) {
	payload := map[string]string{}
	payload["proxyURL"] = config.GetProxyURL()
	json.NewEncoder(w).Encode(payload)
}
// apiUpdateProxy updates the proxy configuration and applies it immediately.
//
// The JSON body is {"proxyURL": "..."}. A non-empty URL must start with
// http://, https://, socks5:// or socks5h://; an empty URL skips this check.
// After config.UpdateProxySettings succeeds, the new value is pushed to the
// outbound HTTP clients through applyProxyConfig.
//
// Responses: 400 for an undecodable body or an unsupported scheme, 500 when
// saving fails, otherwise {"success": true}.
func (h *Handler) apiUpdateProxy(w http.ResponseWriter, r *http.Request) {
	fail := func(status int, msg string) {
		w.WriteHeader(status)
		json.NewEncoder(w).Encode(map[string]string{"error": msg})
	}

	var body struct {
		ProxyURL string `json:"proxyURL"`
	}
	if decodeErr := json.NewDecoder(r.Body).Decode(&body); decodeErr != nil {
		fail(400, "Invalid JSON")
		return
	}

	// Validate the proxy URL scheme (only when a URL was supplied).
	if body.ProxyURL != "" {
		schemeOK := false
		for _, scheme := range []string{"http://", "https://", "socks5://", "socks5h://"} {
			if strings.HasPrefix(body.ProxyURL, scheme) {
				schemeOK = true
				break
			}
		}
		if !schemeOK {
			fail(400, "proxyURL must start with http://, https://, socks5://, or socks5h://")
			return
		}
	}

	if saveErr := config.UpdateProxySettings(body.ProxyURL); saveErr != nil {
		fail(500, saveErr.Error())
		return
	}

	// Apply the new proxy configuration right away.
	applyProxyConfig(body.ProxyURL)
	json.NewEncoder(w).Encode(map[string]bool{"success": true})
}
// apiGetGeneralConfig returns the general settings (invalidModelRetries,
// firstByteTimeoutSec, firstByteRetries) as JSON.
func (h *Handler) apiGetGeneralConfig(w http.ResponseWriter, r *http.Request) {
	general := make(map[string]interface{}, 3)
	general["invalidModelRetries"] = config.GetInvalidModelRetries()
	general["firstByteTimeoutSec"] = config.GetFirstByteTimeoutSec()
	general["firstByteRetries"] = config.GetFirstByteRetries()

	json.NewEncoder(w).Encode(general)
}
// apiUpdateGeneralConfig updates the general settings from a JSON body. Every
// field is optional; a field that is absent (nil) is left unchanged.
//
// Accepted ranges:
//   - invalidModelRetries: 0-20
//   - firstByteTimeoutSec: 0-300
//   - firstByteRetries:    0-10
//
// All supplied fields are validated before any of them is persisted, so a
// request that is rejected with 400 never leaves a partially applied update
// behind. Responses: 400 for an undecodable body or an out-of-range value, 500
// when one of the config update calls fails, otherwise {"success": true}.
func (h *Handler) apiUpdateGeneralConfig(w http.ResponseWriter, r *http.Request) {
	fail := func(status int, msg string) {
		w.WriteHeader(status)
		json.NewEncoder(w).Encode(map[string]string{"error": msg})
	}

	var req struct {
		InvalidModelRetries *int `json:"invalidModelRetries"`
		FirstByteTimeoutSec *int `json:"firstByteTimeoutSec"`
		FirstByteRetries    *int `json:"firstByteRetries"`
	}
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		fail(400, "Invalid JSON")
		return
	}

	// Phase 1: validate everything that was supplied, in field order.
	checks := []struct {
		value *int
		max   int
		msg   string
	}{
		{req.InvalidModelRetries, 20, "invalidModelRetries must be 0-20"},
		{req.FirstByteTimeoutSec, 300, "firstByteTimeoutSec must be 0-300"},
		{req.FirstByteRetries, 10, "firstByteRetries must be 0-10"},
	}
	for _, c := range checks {
		if c.value == nil {
			continue
		}
		if *c.value < 0 || *c.value > c.max {
			fail(400, c.msg)
			return
		}
	}

	// Phase 2: persist the supplied fields now that the whole request is valid.
	if req.InvalidModelRetries != nil {
		if err := config.UpdateInvalidModelRetries(*req.InvalidModelRetries); err != nil {
			fail(500, err.Error())
			return
		}
	}
	if req.FirstByteTimeoutSec != nil {
		if err := config.UpdateFirstByteTimeoutSec(*req.FirstByteTimeoutSec); err != nil {
			fail(500, err.Error())
			return
		}
	}
	if req.FirstByteRetries != nil {
		if err := config.UpdateFirstByteRetries(*req.FirstByteRetries); err != nil {
			fail(500, err.Error())
			return
		}
	}

	json.NewEncoder(w).Encode(map[string]bool{"success": true})
}
// apiGetVersion returns the version as {"version": config.Version}.
func (h *Handler) apiGetVersion(w http.ResponseWriter, r *http.Request) {
	payload := map[string]string{}
	payload["version"] = config.Version
	json.NewEncoder(w).Encode(payload)
}
// apiExportAccounts exports account credentials in a format the original
// describes as compatible with Kiro Account Manager.
//
// The optional JSON body {"ids": [...]} restricts the export to the listed
// account IDs. An empty, missing or undecodable body exports every account;
// this best-effort behaviour is deliberate.
//
// Field mapping performed here:
//   - idp: the account's Provider, or "Google" for social auth / "BuilderId"
//     otherwise when Provider is empty.
//   - authMethod: "idc" is exported as "IdC"; other values pass through.
//   - subscription.type: derived from SubscriptionType (case-insensitive):
//     PRO_PLUS/PROPLUS/POWER -> "Pro_Plus", PRO -> "Pro", anything else "Free".
//   - credentials.expiresAt: converted from seconds to milliseconds.
//   - status is always "active" and tags/groups are always empty.
//   - usage.lastUpdated, createdAt, lastUsedAt and exportedAt are all set to
//     the export time, not to per-account values. A single timestamp is
//     captured once so every entry of one export carries the same value.
//
// NOTE(review): the response contains access tokens, refresh tokens and client
// secrets; make sure the route is only reachable by authenticated admins.
func (h *Handler) apiExportAccounts(w http.ResponseWriter, r *http.Request) {
	var req struct {
		IDs []string `json:"ids"` // empty means "export everything"
	}
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		// Empty or malformed body: fall back to exporting all accounts.
		req.IDs = nil
	}

	accounts := config.GetAccounts()

	// When IDs were given, keep only the matching accounts.
	if len(req.IDs) > 0 {
		wanted := make(map[string]struct{}, len(req.IDs))
		for _, id := range req.IDs {
			wanted[id] = struct{}{}
		}
		capHint := len(req.IDs)
		if len(accounts) < capHint {
			capHint = len(accounts)
		}
		selected := make([]config.Account, 0, capHint)
		for i := range accounts {
			if _, ok := wanted[accounts[i].ID]; ok {
				selected = append(selected, accounts[i])
			}
		}
		accounts = selected
	}

	// Export schema.
	type ExportCredentials struct {
		AccessToken  string `json:"accessToken"`
		CsrfToken    string `json:"csrfToken"`
		RefreshToken string `json:"refreshToken"`
		ClientID     string `json:"clientId,omitempty"`
		ClientSecret string `json:"clientSecret,omitempty"`
		Region       string `json:"region,omitempty"`
		ExpiresAt    int64  `json:"expiresAt"`
		AuthMethod   string `json:"authMethod,omitempty"`
		Provider     string `json:"provider,omitempty"`
	}
	type ExportSubscription struct {
		Type  string `json:"type"`
		Title string `json:"title,omitempty"`
	}
	type ExportUsage struct {
		Current     float64 `json:"current"`
		Limit       float64 `json:"limit"`
		PercentUsed float64 `json:"percentUsed"`
		LastUpdated int64   `json:"lastUpdated"`
	}
	type ExportAccount struct {
		ID           string             `json:"id"`
		Email        string             `json:"email"`
		Nickname     string             `json:"nickname,omitempty"`
		Idp          string             `json:"idp"`
		UserId       string             `json:"userId,omitempty"`
		MachineId    string             `json:"machineId,omitempty"`
		Credentials  ExportCredentials  `json:"credentials"`
		Subscription ExportSubscription `json:"subscription"`
		Usage        ExportUsage        `json:"usage"`
		Tags         []string           `json:"tags"`
		Status       string             `json:"status"`
		CreatedAt    int64              `json:"createdAt"`
		LastUsedAt   int64              `json:"lastUsedAt"`
	}
	type ExportData struct {
		Version    string          `json:"version"`
		ExportedAt int64           `json:"exportedAt"`
		Accounts   []ExportAccount `json:"accounts"`
		Groups     []interface{}   `json:"groups"`
		Tags       []interface{}   `json:"tags"`
	}

	// One timestamp for the whole export keeps all entries consistent and
	// avoids re-reading the clock several times per account.
	nowMs := time.Now().UnixMilli()

	exportAccounts := make([]ExportAccount, 0, len(accounts))
	for i := range accounts {
		a := &accounts[i]

		// Map provider to idp, falling back on the auth method.
		idp := a.Provider
		if idp == "" {
			if a.AuthMethod == "social" {
				idp = "Google"
			} else {
				idp = "BuilderId"
			}
		}

		// Map authMethod to the spelling used by the export format.
		authMethod := a.AuthMethod
		if authMethod == "idc" {
			authMethod = "IdC"
		}

		// Map the subscription type. PRO_PLUS must be tested before PRO
		// because every PRO_PLUS string also contains "PRO".
		subType := "Free"
		rawType := strings.ToUpper(a.SubscriptionType)
		switch {
		case strings.Contains(rawType, "PRO_PLUS"), strings.Contains(rawType, "PROPLUS"):
			subType = "Pro_Plus"
		case strings.Contains(rawType, "PRO"):
			subType = "Pro"
		case strings.Contains(rawType, "POWER"):
			subType = "Pro_Plus"
		}

		exportAccounts = append(exportAccounts, ExportAccount{
			ID:        a.ID,
			Email:     a.Email,
			Nickname:  a.Nickname,
			Idp:       idp,
			UserId:    a.UserId,
			MachineId: a.MachineId,
			Credentials: ExportCredentials{
				AccessToken:  a.AccessToken,
				CsrfToken:    "",
				RefreshToken: a.RefreshToken,
				ClientID:     a.ClientID,
				ClientSecret: a.ClientSecret,
				Region:       a.Region,
				ExpiresAt:    a.ExpiresAt * 1000, // seconds -> milliseconds
				AuthMethod:   authMethod,
				Provider:     a.Provider,
			},
			Subscription: ExportSubscription{
				Type:  subType,
				Title: a.SubscriptionTitle,
			},
			Usage: ExportUsage{
				Current:     a.UsageCurrent,
				Limit:       a.UsageLimit,
				PercentUsed: a.UsagePercent,
				LastUpdated: nowMs,
			},
			Tags:       []string{}, // non-nil so JSON encodes [] rather than null
			Status:     "active",
			CreatedAt:  nowMs,
			LastUsedAt: nowMs,
		})
	}

	data := ExportData{
		Version:    config.Version,
		ExportedAt: nowMs,
		Accounts:   exportAccounts,
		Groups:     []interface{}{},
		Tags:       []interface{}{},
	}
	json.NewEncoder(w).Encode(data)
}