// Source: kirogo/proxy/handler.go (1512 lines, 42 KiB).
// The hosting viewer warned that this file contains ambiguous Unicode
// characters, i.e. characters that may be confused with similar-looking ones.

package proxy
import (
	"crypto/subtle"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strings"
	"sync"
	"time"

	"github.com/google/uuid"

	"kiro-api-proxy/auth"
	"kiro-api-proxy/config"
	"kiro-api-proxy/pool"
)
// Handler is the proxy's HTTP handler: it serves the Claude- and
// OpenAI-compatible API endpoints, the admin page/API and the health check.
type Handler struct {
	// pool supplies the accounts used for upstream requests.
	pool *pool.AccountPool
	// Runtime statistics. NewHandler seeds them from config.GetStats; they are
	// updated per request and persisted through config.UpdateStats.
	totalRequests   int
	successRequests int
	failedRequests  int
	totalTokens     int
	totalCredits    float64
	// startTime is the Unix time (seconds) at which the handler was created;
	// handleHealth derives the reported uptime from it.
	startTime int64
	// stopRefresh makes the backgroundRefresh goroutine return.
	stopRefresh chan struct{}
}
// NewHandler builds a Handler bound to the shared account pool, restores the
// request statistics persisted in config, and launches the background
// goroutine that periodically refreshes account tokens and usage info.
func NewHandler() *Handler {
	h := new(Handler)
	h.totalRequests, h.successRequests, h.failedRequests, h.totalTokens, h.totalCredits = config.GetStats()
	h.pool = pool.GetPool()
	h.startTime = time.Now().Unix()
	h.stopRefresh = make(chan struct{})

	// Runs until stopRefresh fires.
	go h.backgroundRefresh()
	return h
}
// backgroundRefresh 后台定时刷新账户信息
func (h *Handler) backgroundRefresh() {
ticker := time.NewTicker(30 * time.Minute) // 每 30 分钟刷新一次
defer ticker.Stop()
// 启动时延迟 10 秒后执行一次
time.Sleep(10 * time.Second)
h.refreshAllAccounts()
for {
select {
case <-ticker.C:
h.refreshAllAccounts()
case <-h.stopRefresh:
return
}
}
}
// refreshAllAccounts sweeps every enabled account that has an access token.
// For each one it first renews the token if it expires within 5 minutes, then
// fetches the account's subscription/usage info. Results are persisted through
// config (and new tokens pushed into the pool). Failures are logged and only
// skip the affected account. The pool is reloaded once the sweep finishes.
func (h *Handler) refreshAllAccounts() {
	refreshOne := func(acc *config.Account) {
		expiringSoon := acc.ExpiresAt > 0 && time.Now().Unix() > acc.ExpiresAt-300
		if expiringSoon {
			accessToken, refreshToken, expiresAt, err := auth.RefreshToken(acc)
			if err != nil {
				fmt.Printf("[BackgroundRefresh] Token refresh failed for %s: %v\n", acc.Email, err)
				return
			}
			acc.AccessToken = accessToken
			if refreshToken != "" {
				acc.RefreshToken = refreshToken
			}
			acc.ExpiresAt = expiresAt
			config.UpdateAccountToken(acc.ID, accessToken, refreshToken, expiresAt)
			h.pool.UpdateToken(acc.ID, accessToken, refreshToken, expiresAt)
		}

		info, err := RefreshAccountInfo(acc)
		if err != nil {
			fmt.Printf("[BackgroundRefresh] Failed to refresh %s: %v\n", acc.Email, err)
			return
		}
		config.UpdateAccountInfo(acc.ID, *info)
		fmt.Printf("[BackgroundRefresh] Refreshed %s: %s %.1f/%.1f\n", acc.Email, info.SubscriptionType, info.UsageCurrent, info.UsageLimit)
	}

	accounts := config.GetAccounts()
	for i := range accounts {
		if acc := &accounts[i]; acc.Enabled && acc.AccessToken != "" {
			refreshOne(acc)
		}
	}
	h.pool.Reload()
}
// validateApiKey reports whether r carries the configured proxy API key.
//
// It returns true without checking anything when API keys are disabled in the
// config or no key is configured. Otherwise the key is taken from an
// "Authorization: Bearer <key>" header or, when that header is absent or is
// not a Bearer header, from the X-Api-Key header.
func (h *Handler) validateApiKey(r *http.Request) bool {
	if !config.IsApiKeyRequired() {
		return true
	}
	expectedKey := config.GetApiKey()
	if expectedKey == "" {
		return true
	}

	var providedKey string
	if authHeader := r.Header.Get("Authorization"); strings.HasPrefix(authHeader, "Bearer ") {
		providedKey = strings.TrimPrefix(authHeader, "Bearer ")
	} else {
		providedKey = r.Header.Get("X-Api-Key")
	}

	// Constant-time comparison so response timing does not reveal how many
	// leading bytes of a guessed key were correct.
	return subtle.ConstantTimeCompare([]byte(providedKey), []byte(expectedKey)) == 1
}
// ServeHTTP sets permissive CORS headers, answers OPTIONS preflights with 204
// and routes by path:
//   - the Claude Messages API (/v1/messages, /messages, /anthropic/v1/messages),
//     its count_tokens endpoint under the same prefixes, and OpenAI chat
//     completions; these require the proxy API key (see validateApiKey);
//   - the model list and Claude Code's telemetry endpoint, without a key;
//   - the admin page, the admin JSON API and admin static files;
//   - the health check at /health and /.
//
// Any other path gets a 404.
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	path := r.URL.Path

	// CORS: allow any origin and the request headers used by the Anthropic and
	// OpenAI SDKs; expose the request-id and rate-limit response headers.
	w.Header().Set("Access-Control-Allow-Origin", "*")
	w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
	w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-Api-Key, anthropic-version, anthropic-beta, x-api-key, x-stainless-os, x-stainless-lang, x-stainless-package-version, x-stainless-runtime, x-stainless-runtime-version, x-stainless-arch")
	w.Header().Set("Access-Control-Expose-Headers", "x-request-id, x-ratelimit-limit-requests, x-ratelimit-limit-tokens, x-ratelimit-remaining-requests, x-ratelimit-remaining-tokens, x-ratelimit-reset-requests, x-ratelimit-reset-tokens")
	if r.Method == http.MethodOptions {
		w.WriteHeader(http.StatusNoContent)
		return
	}

	switch {
	// API endpoints (require the proxy API key).
	case path == "/v1/messages" || path == "/messages" || path == "/anthropic/v1/messages":
		if !h.validateApiKey(r) {
			h.sendClaudeError(w, 401, "authentication_error", "Invalid or missing API key")
			return
		}
		h.handleClaudeMessages(w, r)
	// count_tokens accepts the same prefixes as /messages, including the
	// /anthropic/v1 alias, so clients configured with that base URL can use it.
	case path == "/v1/messages/count_tokens" || path == "/messages/count_tokens" || path == "/anthropic/v1/messages/count_tokens":
		if !h.validateApiKey(r) {
			h.sendClaudeError(w, 401, "authentication_error", "Invalid or missing API key")
			return
		}
		h.handleCountTokens(w, r)
	case path == "/v1/chat/completions" || path == "/chat/completions":
		if !h.validateApiKey(r) {
			h.sendOpenAIError(w, 401, "authentication_error", "Invalid or missing API key")
			return
		}
		h.handleOpenAIChat(w, r)
	case path == "/v1/models" || path == "/models":
		h.handleModels(w, r)
	case path == "/api/event_logging/batch":
		// Claude Code telemetry endpoint: acknowledge with 200 OK without
		// reading the payload.
		w.Header().Set("Content-Type", "application/json; charset=utf-8")
		w.Write([]byte(`{"status":"ok"}`))

	// Admin endpoints.
	case path == "/admin" || path == "/admin/":
		h.serveAdminPage(w, r)
	case strings.HasPrefix(path, "/admin/api/"):
		h.handleAdminAPI(w, r)
	case strings.HasPrefix(path, "/admin/"):
		h.serveStaticFile(w, r)

	// Health check.
	case path == "/health" || path == "/":
		h.handleHealth(w, r)
	default:
		http.Error(w, "Not Found", 404)
	}
}
// handleHealth reports liveness as JSON, together with the pool size, the
// number of available accounts, the cumulative request statistics and the
// uptime in seconds.
func (h *Handler) handleHealth(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json; charset=utf-8")

	report := make(map[string]interface{}, 9)
	report["status"] = "ok"
	report["accounts"] = h.pool.Count()
	report["available"] = h.pool.AvailableCount()
	report["totalRequests"] = h.totalRequests
	report["successRequests"] = h.successRequests
	report["failedRequests"] = h.failedRequests
	report["totalTokens"] = h.totalTokens
	report["totalCredits"] = h.totalCredits
	report["uptime"] = time.Now().Unix() - h.startTime

	json.NewEncoder(w).Encode(report)
}
// handleModels 模型列表
func (h *Handler) handleModels(w http.ResponseWriter, r *http.Request) {
models := []map[string]interface{}{
{"id": "claude-sonnet-4.5", "object": "model", "owned_by": "anthropic"},
{"id": "claude-sonnet-4", "object": "model", "owned_by": "anthropic"},
{"id": "claude-haiku-4.5", "object": "model", "owned_by": "anthropic"},
{"id": "claude-opus-4.5", "object": "model", "owned_by": "anthropic"},
{"id": "auto", "object": "model", "owned_by": "kiro-api"},
{"id": "gpt-4o", "object": "model", "owned_by": "kiro-proxy"},
{"id": "gpt-4", "object": "model", "owned_by": "kiro-proxy"},
}
w.Header().Set("Content-Type", "application/json; charset=utf-8")
json.NewEncoder(w).Encode(map[string]interface{}{
"object": "list",
"data": models,
})
}
// handleCountTokens serves the Anthropic count_tokens endpoint, which Claude
// Code calls to size its context.
//
// No tokenizer is available, so the result is an estimate. It sums the byte
// length of:
//   - the system prompt;
//   - all message text (text and thinking blocks, tool_use inputs and nested
//     tool_result content);
//   - the raw JSON of the tool definitions.
//
// The sum is divided by 4 and rounded up, with a minimum of 1. Tool results and
// tool schemas usually dominate agentic conversations, so leaving them out made
// the estimate far too low.
func (h *Handler) handleCountTokens(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method Not Allowed", 405)
		return
	}

	body, err := io.ReadAll(r.Body)
	if err != nil {
		h.sendClaudeError(w, 400, "invalid_request_error", "Failed to read request body")
		return
	}

	var req struct {
		Messages []struct {
			Role    string      `json:"role"`
			Content interface{} `json:"content"`
		} `json:"messages"`
		System interface{} `json:"system"`
		// Kept raw: only its size matters, and any tool shape is accepted.
		Tools json.RawMessage `json:"tools"`
	}
	if err := json.Unmarshal(body, &req); err != nil {
		h.sendClaudeError(w, 400, "invalid_request_error", "Invalid JSON")
		return
	}

	// textLen returns the number of text bytes carried by a content value:
	// either a plain string or an array of content blocks.
	var textLen func(v interface{}) int
	textLen = func(v interface{}) int {
		switch c := v.(type) {
		case string:
			return len(c)
		case []interface{}:
			n := 0
			for _, part := range c {
				block, ok := part.(map[string]interface{})
				if !ok {
					continue
				}
				if text, ok := block["text"].(string); ok {
					n += len(text)
				}
				if thinking, ok := block["thinking"].(string); ok {
					n += len(thinking)
				}
				// tool_result blocks nest their payload (string or blocks) under "content".
				if nested, ok := block["content"]; ok {
					n += textLen(nested)
				}
				// tool_use blocks carry their arguments as a JSON object under "input".
				if input, ok := block["input"]; ok {
					if raw, err := json.Marshal(input); err == nil {
						n += len(raw)
					}
				}
			}
			return n
		}
		return 0
	}

	totalChars := textLen(req.System) + len(req.Tools)
	for _, msg := range req.Messages {
		totalChars += textLen(msg.Content)
	}

	estimatedTokens := (totalChars + 3) / 4 // round up
	if estimatedTokens < 1 {
		estimatedTokens = 1
	}

	w.Header().Set("Content-Type", "application/json; charset=utf-8")
	json.NewEncoder(w).Encode(map[string]int{"input_tokens": estimatedTokens})
}
// handleClaudeMessages serves POST requests to the Claude Messages API. It:
//  1. decodes the ClaudeRequest;
//  2. takes the next account from the pool and makes sure its token is fresh;
//  3. converts the request to a Kiro payload;
//  4. answers as an SSE stream or a single JSON body, depending on req.Stream.
func (h *Handler) handleClaudeMessages(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
		return
	}

	body, readErr := io.ReadAll(r.Body)
	if readErr != nil {
		h.sendClaudeError(w, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
		return
	}
	var req ClaudeRequest
	if jsonErr := json.Unmarshal(body, &req); jsonErr != nil {
		h.sendClaudeError(w, http.StatusBadRequest, "invalid_request_error", "Invalid JSON: "+jsonErr.Error())
		return
	}

	account := h.pool.GetNext()
	if account == nil {
		h.sendClaudeError(w, http.StatusServiceUnavailable, "api_error", "No available accounts")
		return
	}
	if tokenErr := h.ensureValidToken(account); tokenErr != nil {
		h.sendClaudeError(w, http.StatusServiceUnavailable, "api_error", "Token refresh failed: "+tokenErr.Error())
		return
	}

	respond := h.handleClaudeNonStream
	if req.Stream {
		respond = h.handleClaudeStream
	}
	respond(w, account, ClaudeToKiro(&req), req.Model)
}
// handleClaudeStream relays a Kiro response to the client as Anthropic Messages
// API server-sent events.
//
// Event order: message_start; then, for each content block, content_block_start,
// one or more content_block_delta and content_block_stop; finally message_delta
// (stop_reason and usage) and message_stop. If the upstream call fails, an
// "error" event is sent instead of the closing events.
//
// Content-block indexes are assigned sequentially as blocks are opened. Text
// that arrives after a tool call therefore opens a new text block; it is never
// written to an already-stopped block or to an index used by a tool block.
// Thinking text is wrapped in one <thinking>…</thinking> pair per contiguous
// run of thinking chunks, not one pair per chunk.
func (h *Handler) handleClaudeStream(w http.ResponseWriter, account *config.Account, payload *KiroPayload, model string) {
	w.Header().Set("Content-Type", "text/event-stream; charset=utf-8")
	w.Header().Set("Cache-Control", "no-cache")
	w.Header().Set("Connection", "keep-alive")

	flusher, ok := w.(http.Flusher)
	if !ok {
		h.sendClaudeError(w, 500, "api_error", "Streaming not supported")
		return
	}

	msgID := "msg_" + uuid.New().String()
	var (
		blockIndex   int  // index the next (or currently open) content block uses
		textOpen     bool // a text block is open at blockIndex
		inThinking   bool // "<thinking>" has been emitted but not yet closed
		toolUses     []KiroToolUse
		inputTokens  int
		outputTokens int
		credits      float64
	)

	// isQuotaErr reports whether err looks like a rate-limit/quota failure; the
	// result is passed to pool.RecordError as its second argument.
	isQuotaErr := func(err error) bool {
		msg := err.Error()
		return strings.Contains(msg, "429") || strings.Contains(msg, "quota")
	}

	sendTextDelta := func(text string) {
		h.sendSSE(w, flusher, "content_block_delta", map[string]interface{}{
			"type":  "content_block_delta",
			"index": blockIndex,
			"delta": map[string]string{"type": "text_delta", "text": text},
		})
	}

	// closeText closes a pending <thinking> tag and stops the open text block.
	closeText := func() {
		if !textOpen {
			return
		}
		if inThinking {
			sendTextDelta("</thinking>")
			inThinking = false
		}
		h.sendSSE(w, flusher, "content_block_stop", map[string]interface{}{
			"type":  "content_block_stop",
			"index": blockIndex,
		})
		textOpen = false
		blockIndex++
	}

	h.sendSSE(w, flusher, "message_start", map[string]interface{}{
		"type": "message_start",
		"message": map[string]interface{}{
			"id":            msgID,
			"type":          "message",
			"role":          "assistant",
			"content":       []interface{}{},
			"model":         model,
			"stop_reason":   nil,
			"stop_sequence": nil,
			// The Anthropic streaming format carries usage in message_start, and
			// SDK clients that accumulate the message update message.usage on
			// message_delta, so it must exist. Real counts are only known once
			// the upstream call completes.
			"usage": map[string]int{"input_tokens": 0, "output_tokens": 0},
		},
	})

	callback := &KiroStreamCallback{
		OnText: func(text string, isThinking bool) {
			if text == "" {
				return
			}
			if !textOpen {
				h.sendSSE(w, flusher, "content_block_start", map[string]interface{}{
					"type":          "content_block_start",
					"index":         blockIndex,
					"content_block": map[string]string{"type": "text", "text": ""},
				})
				textOpen = true
			}
			// Forward immediately (no buffering); thinking tags are only emitted
			// at the boundaries of a thinking run.
			switch {
			case isThinking && !inThinking:
				text = "<thinking>" + text
				inThinking = true
			case !isThinking && inThinking:
				text = "</thinking>" + text
				inThinking = false
			}
			sendTextDelta(text)
		},
		OnToolUse: func(tu KiroToolUse) {
			toolUses = append(toolUses, tu)
			closeText()

			idx := blockIndex
			blockIndex++
			h.sendSSE(w, flusher, "content_block_start", map[string]interface{}{
				"type":  "content_block_start",
				"index": idx,
				"content_block": map[string]interface{}{
					"type":  "tool_use",
					"id":    tu.ToolUseID,
					"name":  tu.Name,
					"input": map[string]interface{}{},
				},
			})
			// The whole input is sent as one input_json_delta.
			inputJSON, _ := json.Marshal(tu.Input)
			h.sendSSE(w, flusher, "content_block_delta", map[string]interface{}{
				"type":  "content_block_delta",
				"index": idx,
				"delta": map[string]interface{}{
					"type":         "input_json_delta",
					"partial_json": string(inputJSON),
				},
			})
			h.sendSSE(w, flusher, "content_block_stop", map[string]interface{}{
				"type":  "content_block_stop",
				"index": idx,
			})
		},
		OnComplete: func(inTok, outTok int) {
			inputTokens = inTok
			outputTokens = outTok
		},
		OnError: func(err error) {
			h.pool.RecordError(account.ID, isQuotaErr(err))
		},
		OnCredits: func(c float64) {
			credits = c
		},
	}

	// NOTE(review): if CallKiroAPI both invokes OnError and returns the same
	// error, the failure is recorded on the account twice — confirm.
	if err := CallKiroAPI(account, payload, callback); err != nil {
		h.recordFailure()
		h.pool.RecordError(account.ID, isQuotaErr(err))
		h.sendSSE(w, flusher, "error", map[string]interface{}{
			"type":  "error",
			"error": map[string]string{"type": "api_error", "message": err.Error()},
		})
		return
	}

	h.recordSuccess(inputTokens, outputTokens, credits)
	h.pool.RecordSuccess(account.ID)
	h.pool.UpdateStats(account.ID, inputTokens+outputTokens, credits)

	// Close a trailing text block, if one is still open.
	closeText()

	stopReason := "end_turn"
	if len(toolUses) > 0 {
		stopReason = "tool_use"
	}
	h.sendSSE(w, flusher, "message_delta", map[string]interface{}{
		"type": "message_delta",
		"delta": map[string]interface{}{
			"stop_reason":   stopReason,
			"stop_sequence": nil,
		},
		"usage": map[string]int{
			"input_tokens":  inputTokens,
			"output_tokens": outputTokens,
		},
	})
	h.sendSSE(w, flusher, "message_stop", map[string]interface{}{
		"type": "message_stop",
	})
}
// sendSSE writes a single server-sent event named event, whose data line is the
// JSON encoding of data, and flushes it so the client receives it at once.
// A marshalling error is ignored and yields an empty data line.
func (h *Handler) sendSSE(w http.ResponseWriter, flusher http.Flusher, event string, data interface{}) {
	payload, _ := json.Marshal(data)
	io.WriteString(w, "event: "+event+"\ndata: "+string(payload)+"\n\n")
	flusher.Flush()
}
// handlerStatsMu serialises updates to the Handler request statistics
// (totalRequests, successRequests, failedRequests, totalTokens, totalCredits).
// Requests are handled on concurrent goroutines, so updating these counters
// without a lock is a data race and can lose increments.
// NOTE(review): other code that reads or writes these fields (for example
// handleHealth or the admin stats endpoints) should take this lock as well.
var handlerStatsMu sync.Mutex

// recordSuccess counts one successful request, adds its token and credit usage
// to the running totals, and persists a snapshot of all counters asynchronously.
func (h *Handler) recordSuccess(inputTokens, outputTokens int, credits float64) {
	handlerStatsMu.Lock()
	defer handlerStatsMu.Unlock()

	h.totalRequests++
	h.successRequests++
	h.totalTokens += inputTokens + outputTokens
	h.totalCredits += credits
	// The go statement evaluates its arguments here, under the lock, so the
	// goroutine persists a consistent snapshot.
	// NOTE(review): snapshots from concurrent requests may still be persisted
	// out of order.
	go config.UpdateStats(h.totalRequests, h.successRequests, h.failedRequests, h.totalTokens, h.totalCredits)
}

// recordFailure counts one failed request and persists a snapshot of all
// counters asynchronously.
func (h *Handler) recordFailure() {
	handlerStatsMu.Lock()
	defer handlerStatsMu.Unlock()

	h.totalRequests++
	h.failedRequests++
	go config.UpdateStats(h.totalRequests, h.successRequests, h.failedRequests, h.totalTokens, h.totalCredits)
}
// handleClaudeNonStream runs the Kiro request to completion and replies with a
// single Anthropic Messages API JSON body built by KiroToClaudeResponse.
//
// Thinking text is inlined into the text content. Each contiguous run of
// thinking chunks is wrapped in one <thinking>…</thinking> pair, matching the
// streaming handler. Rate-limit/quota errors ("429" or "quota" in the message)
// are flagged to pool.RecordError the same way as in the streaming handler.
func (h *Handler) handleClaudeNonStream(w http.ResponseWriter, account *config.Account, payload *KiroPayload, model string) {
	var (
		content      strings.Builder // avoids quadratic += across many chunks
		inThinking   bool            // "<thinking>" written but not yet closed
		toolUses     []KiroToolUse
		inputTokens  int
		outputTokens int
		credits      float64
	)

	isQuotaErr := func(err error) bool {
		msg := err.Error()
		return strings.Contains(msg, "429") || strings.Contains(msg, "quota")
	}

	callback := &KiroStreamCallback{
		OnText: func(text string, isThinking bool) {
			if text == "" {
				return
			}
			switch {
			case isThinking && !inThinking:
				content.WriteString("<thinking>")
				inThinking = true
			case !isThinking && inThinking:
				content.WriteString("</thinking>")
				inThinking = false
			}
			content.WriteString(text)
		},
		OnToolUse: func(tu KiroToolUse) {
			toolUses = append(toolUses, tu)
		},
		OnComplete: func(inTok, outTok int) {
			inputTokens = inTok
			outputTokens = outTok
		},
		OnError: func(err error) {
			h.pool.RecordError(account.ID, isQuotaErr(err))
		},
		OnCredits: func(c float64) {
			credits = c
		},
	}

	if err := CallKiroAPI(account, payload, callback); err != nil {
		h.recordFailure()
		h.pool.RecordError(account.ID, isQuotaErr(err))
		h.sendClaudeError(w, 500, "api_error", err.Error())
		return
	}
	if inThinking {
		content.WriteString("</thinking>")
	}

	h.recordSuccess(inputTokens, outputTokens, credits)
	h.pool.RecordSuccess(account.ID)
	h.pool.UpdateStats(account.ID, inputTokens+outputTokens, credits)

	resp := KiroToClaudeResponse(content.String(), toolUses, inputTokens, outputTokens, model)
	w.Header().Set("Content-Type", "application/json; charset=utf-8")
	json.NewEncoder(w).Encode(resp)
}
// sendClaudeError writes an Anthropic-style error envelope,
// {"type":"error","error":{"type":errType,"message":message}}, with the given
// HTTP status code.
func (h *Handler) sendClaudeError(w http.ResponseWriter, status int, errType, message string) {
	envelope := map[string]interface{}{
		"type":  "error",
		"error": map[string]string{"type": errType, "message": message},
	}
	w.Header().Set("Content-Type", "application/json; charset=utf-8")
	w.WriteHeader(status)
	json.NewEncoder(w).Encode(envelope)
}
// handleOpenAIChat serves POST requests to the OpenAI chat completions API. It:
//  1. decodes the OpenAIRequest;
//  2. takes the next account from the pool and refreshes its token if needed;
//  3. converts the request to a Kiro payload;
//  4. answers as a chunk stream or a single JSON body, depending on req.Stream.
func (h *Handler) handleOpenAIChat(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
		return
	}

	body, readErr := io.ReadAll(r.Body)
	if readErr != nil {
		h.sendOpenAIError(w, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
		return
	}
	var req OpenAIRequest
	if json.Unmarshal(body, &req) != nil {
		h.sendOpenAIError(w, http.StatusBadRequest, "invalid_request_error", "Invalid JSON")
		return
	}

	account := h.pool.GetNext()
	if account == nil {
		h.sendOpenAIError(w, http.StatusServiceUnavailable, "server_error", "No available accounts")
		return
	}
	if h.ensureValidToken(account) != nil {
		h.sendOpenAIError(w, http.StatusServiceUnavailable, "server_error", "Token refresh failed")
		return
	}

	kiroPayload := OpenAIToKiro(&req)
	if !req.Stream {
		h.handleOpenAINonStream(w, account, kiroPayload, req.Model)
		return
	}
	h.handleOpenAIStream(w, account, kiroPayload, req.Model)
}
// handleOpenAIStream relays a Kiro response as OpenAI chat.completion.chunk
// server-sent events, terminated by "data: [DONE]".
//
// Output format:
//   - Text goes out as delta.content, and thinking text as
//     delta.reasoning_content.
//   - Each tool call is sent whole as one delta.tool_calls entry with a running
//     index.
//   - The final chunk carries finish_reason "tool_calls" or "stop".
//
// If the upstream call fails, an {"error":{...}} data line followed by [DONE]
// is written, so the client does not get a silently empty stream.
// Rate-limit/quota errors ("429" or "quota" in the message) are flagged to
// pool.RecordError as in the Claude handlers.
func (h *Handler) handleOpenAIStream(w http.ResponseWriter, account *config.Account, payload *KiroPayload, model string) {
	w.Header().Set("Content-Type", "text/event-stream; charset=utf-8")
	w.Header().Set("Cache-Control", "no-cache")
	w.Header().Set("Connection", "keep-alive")

	flusher, ok := w.(http.Flusher)
	if !ok {
		h.sendOpenAIError(w, 500, "server_error", "Streaming not supported")
		return
	}

	chatID := "chatcmpl-" + uuid.New().String()
	var (
		toolCalls     []ToolCall
		toolCallIndex int
		inputTokens   int
		outputTokens  int
		credits       float64
	)

	isQuotaErr := func(err error) bool {
		msg := err.Error()
		return strings.Contains(msg, "429") || strings.Contains(msg, "quota")
	}

	// writeData sends v as one SSE data line and flushes it to the client.
	writeData := func(v interface{}) {
		data, _ := json.Marshal(v)
		fmt.Fprintf(w, "data: %s\n\n", string(data))
		flusher.Flush()
	}

	// chunk wraps delta in a chat.completion.chunk with a single choice.
	chunk := func(delta interface{}, finishReason interface{}) map[string]interface{} {
		return map[string]interface{}{
			"id":      chatID,
			"object":  "chat.completion.chunk",
			"created": time.Now().Unix(),
			"model":   model,
			"choices": []map[string]interface{}{{
				"index":         0,
				"delta":         delta,
				"finish_reason": finishReason,
			}},
		}
	}

	callback := &KiroStreamCallback{
		OnText: func(text string, isThinking bool) {
			if text == "" {
				return
			}
			// Forwarded immediately, without buffering.
			deltaKey := "content"
			if isThinking {
				deltaKey = "reasoning_content"
			}
			writeData(chunk(map[string]string{deltaKey: text}, nil))
		},
		OnToolUse: func(tu KiroToolUse) {
			args, _ := json.Marshal(tu.Input)
			tc := ToolCall{ID: tu.ToolUseID, Type: "function"}
			tc.Function.Name = tu.Name
			tc.Function.Arguments = string(args)
			toolCalls = append(toolCalls, tc)

			writeData(chunk(map[string]interface{}{
				"tool_calls": []map[string]interface{}{{
					"index": toolCallIndex,
					"id":    tu.ToolUseID,
					"type":  "function",
					"function": map[string]string{
						"name":      tu.Name,
						"arguments": string(args),
					},
				}},
			}, nil))
			toolCallIndex++
		},
		OnComplete: func(inTok, outTok int) {
			inputTokens = inTok
			outputTokens = outTok
		},
		OnError: func(err error) {
			h.pool.RecordError(account.ID, isQuotaErr(err))
		},
		OnCredits: func(c float64) {
			credits = c
		},
	}

	if err := CallKiroAPI(account, payload, callback); err != nil {
		h.recordFailure()
		h.pool.RecordError(account.ID, isQuotaErr(err))
		// Chunks may already have been flushed with a 200 status, so report the
		// failure in-band and terminate the stream properly.
		writeData(map[string]interface{}{
			"error": map[string]interface{}{
				"type":    "server_error",
				"message": err.Error(),
			},
		})
		fmt.Fprintf(w, "data: [DONE]\n\n")
		flusher.Flush()
		return
	}

	h.recordSuccess(inputTokens, outputTokens, credits)
	h.pool.RecordSuccess(account.ID)
	h.pool.UpdateStats(account.ID, inputTokens+outputTokens, credits)

	finishReason := "stop"
	if len(toolCalls) > 0 {
		finishReason = "tool_calls"
	}
	writeData(chunk(map[string]interface{}{}, finishReason))
	fmt.Fprintf(w, "data: [DONE]\n\n")
	flusher.Flush()
}
// handleOpenAINonStream runs the Kiro request to completion and replies with a
// single OpenAI chat.completion JSON body built by KiroToOpenAIResponse.
//
// Thinking text is dropped in this mode; only regular text and tool calls are
// returned. Rate-limit/quota errors ("429" or "quota" in the message) are
// flagged to pool.RecordError as in the other handlers.
func (h *Handler) handleOpenAINonStream(w http.ResponseWriter, account *config.Account, payload *KiroPayload, model string) {
	var (
		content      strings.Builder // avoids quadratic += across many chunks
		toolUses     []KiroToolUse
		inputTokens  int
		outputTokens int
		credits      float64
	)

	isQuotaErr := func(err error) bool {
		msg := err.Error()
		return strings.Contains(msg, "429") || strings.Contains(msg, "quota")
	}

	callback := &KiroStreamCallback{
		OnText: func(text string, isThinking bool) {
			// The plain chat.completion message has no slot for thinking text.
			if !isThinking {
				content.WriteString(text)
			}
		},
		OnToolUse:  func(tu KiroToolUse) { toolUses = append(toolUses, tu) },
		OnComplete: func(inTok, outTok int) { inputTokens = inTok; outputTokens = outTok },
		OnError:    func(err error) { h.pool.RecordError(account.ID, isQuotaErr(err)) },
		OnCredits:  func(c float64) { credits = c },
	}

	if err := CallKiroAPI(account, payload, callback); err != nil {
		h.recordFailure()
		h.pool.RecordError(account.ID, isQuotaErr(err))
		h.sendOpenAIError(w, 500, "server_error", err.Error())
		return
	}

	h.recordSuccess(inputTokens, outputTokens, credits)
	h.pool.RecordSuccess(account.ID)
	h.pool.UpdateStats(account.ID, inputTokens+outputTokens, credits)

	resp := KiroToOpenAIResponse(content.String(), toolUses, inputTokens, outputTokens, model)
	w.Header().Set("Content-Type", "application/json; charset=utf-8")
	json.NewEncoder(w).Encode(resp)
}
// sendOpenAIError writes an OpenAI-style error body,
// {"error":{"message":message,"type":errType}}, with the given HTTP status code.
func (h *Handler) sendOpenAIError(w http.ResponseWriter, status int, errType, message string) {
	detail := map[string]interface{}{
		"message": message,
		"type":    errType,
	}
	w.Header().Set("Content-Type", "application/json; charset=utf-8")
	w.WriteHeader(status)
	json.NewEncoder(w).Encode(map[string]interface{}{"error": detail})
}
// ensureValidToken 确保 token 有效
func (h *Handler) ensureValidToken(account *config.Account) error {
if account.ExpiresAt == 0 || time.Now().Unix() < account.ExpiresAt-300 {
return nil
}
accessToken, refreshToken, expiresAt, err := auth.RefreshToken(account)
if err != nil {
return err
}
// 更新内存
h.pool.UpdateToken(account.ID, accessToken, refreshToken, expiresAt)
account.AccessToken = accessToken
if refreshToken != "" {
account.RefreshToken = refreshToken
}
account.ExpiresAt = expiresAt
// 持久化
config.UpdateAccountToken(account.ID, accessToken, refreshToken, expiresAt)
return nil
}
// ==================== Admin API ====================

// handleAdminAPI authenticates a request under /admin/api/ and dispatches it by
// sub-path and method.
//
// The admin password is read from the X-Admin-Password header, falling back to
// the admin_password cookie, and compared in constant time with
// config.GetPassword(). Every response, including 401 and 404, is JSON.
func (h *Handler) handleAdminAPI(w http.ResponseWriter, r *http.Request) {
	// Set before any WriteHeader call so the 401 response below is also
	// labelled as JSON.
	w.Header().Set("Content-Type", "application/json; charset=utf-8")

	password := r.Header.Get("X-Admin-Password")
	if password == "" {
		if cookie, err := r.Cookie("admin_password"); err == nil {
			password = cookie.Value
		}
	}
	// NOTE(review): when config.GetPassword() returns "", a request with no
	// password is accepted, i.e. the admin API is unauthenticated. Confirm that
	// this is intended.
	if subtle.ConstantTimeCompare([]byte(password), []byte(config.GetPassword())) != 1 {
		w.WriteHeader(http.StatusUnauthorized)
		json.NewEncoder(w).Encode(map[string]string{"error": "Unauthorized"})
		return
	}

	path := strings.TrimPrefix(r.URL.Path, "/admin/api")
	switch {
	case path == "/accounts" && r.Method == http.MethodGet:
		h.apiGetAccounts(w, r)
	case path == "/accounts" && r.Method == http.MethodPost:
		h.apiAddAccount(w, r)
	case strings.HasPrefix(path, "/accounts/") && strings.HasSuffix(path, "/refresh") && r.Method == http.MethodPost:
		id := strings.TrimSuffix(strings.TrimPrefix(path, "/accounts/"), "/refresh")
		h.apiRefreshAccount(w, r, id)
	case strings.HasPrefix(path, "/accounts/") && strings.HasSuffix(path, "/models") && r.Method == http.MethodGet:
		id := strings.TrimSuffix(strings.TrimPrefix(path, "/accounts/"), "/models")
		h.apiGetAccountModels(w, r, id)
	case strings.HasPrefix(path, "/accounts/") && r.Method == http.MethodDelete:
		h.apiDeleteAccount(w, r, strings.TrimPrefix(path, "/accounts/"))
	case strings.HasPrefix(path, "/accounts/") && r.Method == http.MethodPut:
		h.apiUpdateAccount(w, r, strings.TrimPrefix(path, "/accounts/"))
	case path == "/auth/iam-sso/start" && r.Method == http.MethodPost:
		h.apiStartIamSso(w, r)
	case path == "/auth/iam-sso/complete" && r.Method == http.MethodPost:
		h.apiCompleteIamSso(w, r)
	case path == "/auth/builderid/start" && r.Method == http.MethodPost:
		h.apiStartBuilderIdLogin(w, r)
	case path == "/auth/builderid/poll" && r.Method == http.MethodPost:
		h.apiPollBuilderIdAuth(w, r)
	case path == "/auth/sso-token" && r.Method == http.MethodPost:
		h.apiImportSsoToken(w, r)
	case path == "/auth/credentials" && r.Method == http.MethodPost:
		h.apiImportCredentials(w, r)
	case path == "/status" && r.Method == http.MethodGet:
		h.apiGetStatus(w, r)
	case path == "/settings" && r.Method == http.MethodGet:
		h.apiGetSettings(w, r)
	case path == "/settings" && r.Method == http.MethodPost:
		h.apiUpdateSettings(w, r)
	case path == "/stats" && r.Method == http.MethodGet:
		h.apiGetStats(w, r)
	case path == "/stats/reset" && r.Method == http.MethodPost:
		h.apiResetStats(w, r)
	case path == "/generate-machine-id" && r.Method == http.MethodGet:
		h.apiGenerateMachineId(w, r)
	default:
		w.WriteHeader(http.StatusNotFound)
		json.NewEncoder(w).Encode(map[string]string{"error": "Not Found"})
	}
}
// apiGetAccounts lists every configured account for the admin UI, merged with
// the pool's per-account runtime counters (requests, errors, tokens, credits,
// last use). Access tokens, refresh tokens and client secrets are not included;
// hasToken only reports whether an access token is present.
func (h *Handler) apiGetAccounts(w http.ResponseWriter, r *http.Request) {
	configured := config.GetAccounts()

	// The runtime counters live on the pool's copies of the accounts.
	live := make(map[string]config.Account)
	for _, acc := range h.pool.GetAllAccounts() {
		live[acc.ID] = acc
	}

	views := make([]map[string]interface{}, 0, len(configured))
	for _, acc := range configured {
		rt := live[acc.ID] // zero-valued counters if the pool lacks this account
		views = append(views, map[string]interface{}{
			"id":                acc.ID,
			"email":             acc.Email,
			"userId":            acc.UserId,
			"nickname":          acc.Nickname,
			"authMethod":        acc.AuthMethod,
			"provider":          acc.Provider,
			"region":            acc.Region,
			"enabled":           acc.Enabled,
			"expiresAt":         acc.ExpiresAt,
			"hasToken":          acc.AccessToken != "",
			"machineId":         acc.MachineId,
			"subscriptionType":  acc.SubscriptionType,
			"subscriptionTitle": acc.SubscriptionTitle,
			"daysRemaining":     acc.DaysRemaining,
			"usageCurrent":      acc.UsageCurrent,
			"usageLimit":        acc.UsageLimit,
			"usagePercent":      acc.UsagePercent,
			"nextResetDate":     acc.NextResetDate,
			"lastRefresh":       acc.LastRefresh,
			"requestCount":      rt.RequestCount,
			"errorCount":        rt.ErrorCount,
			"totalTokens":       rt.TotalTokens,
			"totalCredits":      rt.TotalCredits,
			"lastUsed":          rt.LastUsed,
		})
	}
	json.NewEncoder(w).Encode(views)
}
// apiAddAccount creates an account from a JSON-encoded config.Account and
// reloads the pool.
//
// Missing fields are defaulted:
//   - ID is generated;
//   - Region becomes "us-east-1";
//   - MachineId is generated.
//
// Responds {"success":true,"id":...}, 400 on malformed JSON, or 500 with the
// config error.
func (h *Handler) apiAddAccount(w http.ResponseWriter, r *http.Request) {
	var account config.Account
	if err := json.NewDecoder(r.Body).Decode(&account); err != nil {
		w.WriteHeader(400)
		json.NewEncoder(w).Encode(map[string]string{"error": "Invalid JSON"})
		return
	}

	if account.ID == "" {
		account.ID = auth.GenerateAccountID()
	}
	if account.Region == "" {
		account.Region = "us-east-1"
	}
	// The IAM SSO, Builder ID and SSO-token import paths always assign a
	// machine ID. Do the same here so manually added accounts are not left
	// without one.
	if account.MachineId == "" {
		account.MachineId = config.GenerateMachineId()
	}

	if err := config.AddAccount(account); err != nil {
		w.WriteHeader(500)
		json.NewEncoder(w).Encode(map[string]string{"error": err.Error()})
		return
	}
	h.pool.Reload()
	json.NewEncoder(w).Encode(map[string]interface{}{"success": true, "id": account.ID})
}
// apiDeleteAccount removes the account with the given ID from config and
// reloads the pool. Responds {"success":true}, or 500 with the config error.
func (h *Handler) apiDeleteAccount(w http.ResponseWriter, r *http.Request, id string) {
	err := config.DeleteAccount(id)
	if err == nil {
		h.pool.Reload()
		json.NewEncoder(w).Encode(map[string]bool{"success": true})
		return
	}
	w.WriteHeader(http.StatusInternalServerError)
	json.NewEncoder(w).Encode(map[string]string{"error": err.Error()})
}
// apiUpdateAccount applies a partial update to the account with the given ID
// and reloads the pool.
//
// Only "enabled" (bool), "nickname" (string) and "machineId" (string) are
// honoured. Other keys, and values of the wrong JSON type, are ignored.
// Responds {"success":true}, 400 on malformed JSON, 404 for an unknown ID, or
// 500 with the config error.
func (h *Handler) apiUpdateAccount(w http.ResponseWriter, r *http.Request, id string) {
	writeErr := func(status int, msg string) {
		w.WriteHeader(status)
		json.NewEncoder(w).Encode(map[string]string{"error": msg})
	}

	var updates map[string]interface{}
	if err := json.NewDecoder(r.Body).Decode(&updates); err != nil {
		writeErr(http.StatusBadRequest, "Invalid JSON")
		return
	}

	accounts := config.GetAccounts()
	idx := -1
	for i := range accounts {
		if accounts[i].ID == id {
			idx = i
			break
		}
	}
	if idx < 0 {
		writeErr(http.StatusNotFound, "Account not found")
		return
	}
	target := accounts[idx]

	if enabled, ok := updates["enabled"].(bool); ok {
		target.Enabled = enabled
	}
	if nickname, ok := updates["nickname"].(string); ok {
		target.Nickname = nickname
	}
	if machineID, ok := updates["machineId"].(string); ok {
		target.MachineId = machineID
	}

	if err := config.UpdateAccount(id, target); err != nil {
		writeErr(http.StatusInternalServerError, err.Error())
		return
	}
	h.pool.Reload()
	json.NewEncoder(w).Encode(map[string]bool{"success": true})
}
// apiStartIamSso starts an IAM SSO login through auth.StartIamSsoLogin for the
// given start URL and region. It returns the session ID, the authorize URL for
// the user to open, and the expiresIn value reported by auth. startUrl is
// required.
func (h *Handler) apiStartIamSso(w http.ResponseWriter, r *http.Request) {
	fail := func(status int, msg string) {
		w.WriteHeader(status)
		json.NewEncoder(w).Encode(map[string]string{"error": msg})
	}

	var req struct {
		StartUrl string `json:"startUrl"`
		Region   string `json:"region"`
	}
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		fail(http.StatusBadRequest, "Invalid JSON")
		return
	}
	if req.StartUrl == "" {
		fail(http.StatusBadRequest, "startUrl is required")
		return
	}

	sessionID, authorizeUrl, expiresIn, err := auth.StartIamSsoLogin(req.StartUrl, req.Region)
	if err != nil {
		fail(http.StatusInternalServerError, err.Error())
		return
	}
	json.NewEncoder(w).Encode(map[string]interface{}{
		"sessionId":    sessionID,
		"authorizeUrl": authorizeUrl,
		"expiresIn":    expiresIn,
	})
}
// apiCompleteIamSso finishes an IAM SSO login started by apiStartIamSso, using
// the session ID and the callback URL supplied by the client. It stores the
// resulting credentials as a new enabled "idc" account with a freshly generated
// machine ID, reloads the pool and returns the new account's ID and e-mail.
func (h *Handler) apiCompleteIamSso(w http.ResponseWriter, r *http.Request) {
	fail := func(status int, msg string) {
		w.WriteHeader(status)
		json.NewEncoder(w).Encode(map[string]string{"error": msg})
	}

	var req struct {
		SessionID   string `json:"sessionId"`
		CallbackUrl string `json:"callbackUrl"`
	}
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		fail(http.StatusBadRequest, "Invalid JSON")
		return
	}

	accessToken, refreshToken, clientID, clientSecret, region, expiresIn, err := auth.CompleteIamSsoLogin(req.SessionID, req.CallbackUrl)
	if err != nil {
		fail(http.StatusBadRequest, err.Error())
		return
	}

	// Best effort: the lookup error is ignored and email is used as returned.
	email, _, _ := auth.GetUserInfo(accessToken)

	account := config.Account{
		ID:           auth.GenerateAccountID(),
		Email:        email,
		AccessToken:  accessToken,
		RefreshToken: refreshToken,
		ClientID:     clientID,
		ClientSecret: clientSecret,
		AuthMethod:   "idc",
		Region:       region,
		ExpiresAt:    time.Now().Unix() + int64(expiresIn),
		Enabled:      true,
		MachineId:    config.GenerateMachineId(),
	}
	if err := config.AddAccount(account); err != nil {
		fail(http.StatusInternalServerError, err.Error())
		return
	}
	h.pool.Reload()

	summary := map[string]interface{}{"id": account.ID, "email": account.Email}
	json.NewEncoder(w).Encode(map[string]interface{}{
		"success": true,
		"account": summary,
	})
}
// apiStartBuilderIdLogin starts a Builder ID login through
// auth.StartBuilderIdLogin. It returns the session ID, the user code and
// verification URI to show the user, and the polling interval to use with
// apiPollBuilderIdAuth.
//
// The JSON body ({"region": "..."}) is optional. An empty body leaves Region
// empty, and "" is passed to auth. A body that is present but malformed is
// rejected with 400 instead of being silently ignored.
func (h *Handler) apiStartBuilderIdLogin(w http.ResponseWriter, r *http.Request) {
	var req struct {
		Region string `json:"region"`
	}
	// json.Decoder reports a completely empty body as io.EOF; that is allowed.
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil && err != io.EOF {
		w.WriteHeader(400)
		json.NewEncoder(w).Encode(map[string]string{"error": "Invalid JSON"})
		return
	}

	session, err := auth.StartBuilderIdLogin(req.Region)
	if err != nil {
		w.WriteHeader(500)
		json.NewEncoder(w).Encode(map[string]string{"error": err.Error()})
		return
	}
	json.NewEncoder(w).Encode(map[string]interface{}{
		"sessionId":       session.ID,
		"userCode":        session.UserCode,
		"verificationUri": session.VerificationUri,
		"interval":        session.Interval,
	})
}
// apiPollBuilderIdAuth polls a Builder ID login started by
// apiStartBuilderIdLogin.
//
// While auth reports "pending" or "slow_down", it answers completed=false with
// the interval to wait before the next poll: the session's Interval, or 5 when
// the session cannot be found. Once authorization completes, it stores a new
// enabled "idc" account with Provider "BuilderId" and a fresh machine ID,
// reloads the pool and returns the account's ID and e-mail. Errors from auth
// yield 400 with success=false.
func (h *Handler) apiPollBuilderIdAuth(w http.ResponseWriter, r *http.Request) {
	var req struct {
		SessionID string `json:"sessionId"`
	}
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		w.WriteHeader(http.StatusBadRequest)
		json.NewEncoder(w).Encode(map[string]string{"error": "Invalid JSON"})
		return
	}

	accessToken, refreshToken, clientID, clientSecret, region, expiresIn, status, err := auth.PollBuilderIdAuth(req.SessionID)
	if err != nil {
		w.WriteHeader(http.StatusBadRequest)
		json.NewEncoder(w).Encode(map[string]interface{}{
			"success": false,
			"error":   err.Error(),
		})
		return
	}

	switch status {
	case "pending", "slow_down":
		interval := 5
		if session := auth.GetBuilderIdSession(req.SessionID); session != nil {
			interval = session.Interval
		}
		json.NewEncoder(w).Encode(map[string]interface{}{
			"success":   true,
			"completed": false,
			"status":    status,
			"interval":  interval,
		})
		return
	}

	// NOTE(review): any other status without an error is treated as success.
	// Confirm that auth.PollBuilderIdAuth never returns such a status with
	// empty tokens.
	email, _, _ := auth.GetUserInfo(accessToken) // best effort; error ignored

	account := config.Account{
		ID:           auth.GenerateAccountID(),
		Email:        email,
		AccessToken:  accessToken,
		RefreshToken: refreshToken,
		ClientID:     clientID,
		ClientSecret: clientSecret,
		AuthMethod:   "idc",
		Provider:     "BuilderId",
		Region:       region,
		ExpiresAt:    time.Now().Unix() + int64(expiresIn),
		Enabled:      true,
		MachineId:    config.GenerateMachineId(),
	}
	if addErr := config.AddAccount(account); addErr != nil {
		w.WriteHeader(http.StatusInternalServerError)
		json.NewEncoder(w).Encode(map[string]string{"error": addErr.Error()})
		return
	}
	h.pool.Reload()

	json.NewEncoder(w).Encode(map[string]interface{}{
		"success":   true,
		"completed": true,
		"account": map[string]interface{}{
			"id":    account.ID,
			"email": account.Email,
		},
	})
}
// apiImportSsoToken imports one or more accounts from SSO bearer tokens.
//
// bearerToken may hold several tokens, one per line; blank lines are skipped.
// Each token is exchanged through auth.ImportFromSsoToken and stored as an
// enabled "idc" account with a fresh machine ID. The response depends on the
// outcome:
//   - nothing imported and at least one error: 500 with the errors joined by "; ";
//   - otherwise: the imported accounts plus any per-token errors.
//
// Both lists in the success response are always JSON arrays, never null.
func (h *Handler) apiImportSsoToken(w http.ResponseWriter, r *http.Request) {
	var req struct {
		BearerToken string `json:"bearerToken"`
		Region      string `json:"region"`
	}
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		w.WriteHeader(400)
		json.NewEncoder(w).Encode(map[string]string{"error": "Invalid JSON"})
		return
	}
	if req.BearerToken == "" {
		w.WriteHeader(400)
		json.NewEncoder(w).Encode(map[string]string{"error": "bearerToken is required"})
		return
	}

	// Store a region even when the client omitted one, using the same default
	// as apiAddAccount. The raw req.Region is still what auth receives.
	accountRegion := req.Region
	if accountRegion == "" {
		accountRegion = "us-east-1"
	}

	// Non-nil so both encode as [] rather than null in the response.
	imported := []map[string]interface{}{}
	failures := []string{}

	for _, line := range strings.Split(strings.TrimSpace(req.BearerToken), "\n") {
		token := strings.TrimSpace(line)
		if token == "" {
			continue
		}

		accessToken, refreshToken, clientID, clientSecret, expiresIn, err := auth.ImportFromSsoToken(token, req.Region)
		if err != nil {
			failures = append(failures, err.Error())
			continue
		}

		// Best effort: the lookup error is ignored and email is used as returned.
		email, _, _ := auth.GetUserInfo(accessToken)

		account := config.Account{
			ID:           auth.GenerateAccountID(),
			Email:        email,
			AccessToken:  accessToken,
			RefreshToken: refreshToken,
			ClientID:     clientID,
			ClientSecret: clientSecret,
			AuthMethod:   "idc",
			Region:       accountRegion,
			ExpiresAt:    time.Now().Unix() + int64(expiresIn),
			Enabled:      true,
			MachineId:    config.GenerateMachineId(),
		}
		if err := config.AddAccount(account); err != nil {
			failures = append(failures, err.Error())
			continue
		}
		imported = append(imported, map[string]interface{}{
			"id":    account.ID,
			"email": account.Email,
		})
	}

	h.pool.Reload()

	if len(imported) == 0 && len(failures) > 0 {
		w.WriteHeader(500)
		json.NewEncoder(w).Encode(map[string]interface{}{
			"success": false,
			"error":   strings.Join(failures, "; "),
		})
		return
	}
	json.NewEncoder(w).Encode(map[string]interface{}{
		"success":  true,
		"accounts": imported,
		"errors":   failures,
	})
}
// apiImportCredentials creates a single account from raw OAuth credentials.
//
// refreshToken is mandatory (400 otherwise). region defaults to "us-east-1".
// authMethod defaults to "idc" when a clientId is supplied, and to "social"
// otherwise. When accessToken is omitted, one is obtained through
// auth.RefreshToken, and any rotated refresh token replaces the submitted one.
// When accessToken is supplied, it is assumed to expire one hour from now.
// On success the new account's id and email are returned.
func (h *Handler) apiImportCredentials(w http.ResponseWriter, r *http.Request) {
	var req struct {
		AccessToken  string `json:"accessToken"`
		RefreshToken string `json:"refreshToken"`
		ClientID     string `json:"clientId"`
		ClientSecret string `json:"clientSecret"`
		AuthMethod   string `json:"authMethod"`
		Provider     string `json:"provider"`
		Region       string `json:"region"`
	}
	if decodeErr := json.NewDecoder(r.Body).Decode(&req); decodeErr != nil {
		w.WriteHeader(400)
		json.NewEncoder(w).Encode(map[string]string{"error": "Invalid JSON"})
		return
	}
	if req.RefreshToken == "" {
		w.WriteHeader(400)
		json.NewEncoder(w).Encode(map[string]string{"error": "refreshToken is required"})
		return
	}

	// Fill in defaults for optional fields.
	if req.Region == "" {
		req.Region = "us-east-1"
	}
	if req.AuthMethod == "" {
		req.AuthMethod = "social"
		if req.ClientID != "" {
			req.AuthMethod = "idc"
		}
	}

	token := req.AccessToken
	var expiry int64
	if token != "" {
		// Caller-supplied token: its real expiry is unknown, so assume one hour.
		expiry = time.Now().Unix() + 3600
	} else {
		// No access token given: mint one from the refresh token.
		refreshed, rotated, refreshedExpiry, refreshErr := auth.RefreshToken(&config.Account{
			RefreshToken: req.RefreshToken,
			ClientID:     req.ClientID,
			ClientSecret: req.ClientSecret,
			AuthMethod:   req.AuthMethod,
			Region:       req.Region,
		})
		if refreshErr != nil {
			w.WriteHeader(400)
			json.NewEncoder(w).Encode(map[string]string{"error": "Token refresh failed: " + refreshErr.Error()})
			return
		}
		token = refreshed
		if rotated != "" {
			req.RefreshToken = rotated
		}
		expiry = refreshedExpiry
	}

	// Best effort: GetUserInfo errors are discarded, so the email may be empty.
	email, _, _ := auth.GetUserInfo(token)
	newAccount := config.Account{
		ID:           auth.GenerateAccountID(),
		Email:        email,
		AccessToken:  token,
		RefreshToken: req.RefreshToken,
		ClientID:     req.ClientID,
		ClientSecret: req.ClientSecret,
		AuthMethod:   req.AuthMethod,
		Provider:     req.Provider,
		Region:       req.Region,
		ExpiresAt:    expiry,
		Enabled:      true,
		MachineId:    config.GenerateMachineId(),
	}
	if addErr := config.AddAccount(newAccount); addErr != nil {
		w.WriteHeader(500)
		json.NewEncoder(w).Encode(map[string]string{"error": addErr.Error()})
		return
	}
	h.pool.Reload()

	summary := map[string]interface{}{
		"id":    newAccount.ID,
		"email": newAccount.Email,
	}
	json.NewEncoder(w).Encode(map[string]interface{}{
		"success": true,
		"account": summary,
	})
}
// apiGetStatus reports pool size, the pool's available count, the in-memory
// request/token/credit counters, and seconds elapsed since the handler was created.
func (h *Handler) apiGetStatus(w http.ResponseWriter, r *http.Request) {
	status := map[string]interface{}{}
	status["accounts"] = h.pool.Count()
	status["available"] = h.pool.AvailableCount()
	status["totalRequests"] = h.totalRequests
	status["successRequests"] = h.successRequests
	status["failedRequests"] = h.failedRequests
	status["totalTokens"] = h.totalTokens
	status["totalCredits"] = h.totalCredits
	status["uptime"] = time.Now().Unix() - h.startTime
	json.NewEncoder(w).Encode(status)
}
// apiGetSettings returns the current proxy settings read from config: API key
// (included verbatim in the response), whether an API key is required, and the
// listen host/port.
func (h *Handler) apiGetSettings(w http.ResponseWriter, r *http.Request) {
	settings := map[string]interface{}{
		"host":          config.GetHost(),
		"port":          config.GetPort(),
		"requireApiKey": config.IsApiKeyRequired(),
		"apiKey":        config.GetApiKey(),
	}
	json.NewEncoder(w).Encode(settings)
}
// apiUpdateSettings applies the JSON body {"apiKey", "requireApiKey", "password"}
// via config.UpdateSettings. Fields missing from the body are passed as their
// zero values ("" / false). Responds 400 on malformed JSON, 500 if saving fails,
// and {"success": true} otherwise.
func (h *Handler) apiUpdateSettings(w http.ResponseWriter, r *http.Request) {
	var payload struct {
		ApiKey        string `json:"apiKey"`
		RequireApiKey bool   `json:"requireApiKey"`
		Password      string `json:"password"`
	}
	enc := json.NewEncoder(w)

	if decodeErr := json.NewDecoder(r.Body).Decode(&payload); decodeErr != nil {
		w.WriteHeader(400)
		enc.Encode(map[string]string{"error": "Invalid JSON"})
		return
	}

	saveErr := config.UpdateSettings(payload.ApiKey, payload.RequireApiKey, payload.Password)
	if saveErr != nil {
		w.WriteHeader(500)
		enc.Encode(map[string]string{"error": saveErr.Error()})
		return
	}

	enc.Encode(map[string]bool{"success": true})
}
// apiGetStats returns the in-memory request/token/credit counters and the
// number of seconds since the handler was created.
func (h *Handler) apiGetStats(w http.ResponseWriter, r *http.Request) {
	uptime := time.Now().Unix() - h.startTime
	json.NewEncoder(w).Encode(map[string]interface{}{
		"uptime":          uptime,
		"totalCredits":    h.totalCredits,
		"totalTokens":     h.totalTokens,
		"failedRequests":  h.failedRequests,
		"successRequests": h.successRequests,
		"totalRequests":   h.totalRequests,
	})
}
// apiResetStats zeroes the in-memory counters and persists the zeroed values
// through config.UpdateStats.
//
// NOTE(review): Handler has no visible lock around these fields. If other
// request handlers update them concurrently, these writes race — confirm.
func (h *Handler) apiResetStats(w http.ResponseWriter, r *http.Request) {
	h.totalRequests, h.successRequests, h.failedRequests, h.totalTokens = 0, 0, 0, 0
	h.totalCredits = 0
	config.UpdateStats(0, 0, 0, 0, 0)

	ok := map[string]bool{"success": true}
	json.NewEncoder(w).Encode(ok)
}
// apiGenerateMachineId returns a freshly generated machine ID as {"machineId": ...}.
// It only generates the value; nothing is stored here.
func (h *Handler) apiGenerateMachineId(w http.ResponseWriter, r *http.Request) {
	json.NewEncoder(w).Encode(map[string]string{
		"machineId": config.GenerateMachineId(),
	})
}
// apiRefreshAccount re-fetches an account's remote info (usage, subscription,
// etc.) via RefreshAccountInfo and saves it with config.UpdateAccountInfo.
//
// The account is looked up by id among config.GetAccounts(); unknown ids get 404.
// If the access token expires within the next 60 seconds, it is refreshed first,
// and the new tokens are pushed to both config and the pool. A token-refresh,
// info-fetch, or save failure yields 500 with an "error" field. On success the
// fetched info is returned under "info".
func (h *Handler) apiRefreshAccount(w http.ResponseWriter, r *http.Request, id string) {
	accounts := config.GetAccounts()
	idx := -1
	for i := range accounts {
		if accounts[i].ID != id {
			continue
		}
		idx = i
		break
	}
	if idx < 0 {
		w.WriteHeader(404)
		json.NewEncoder(w).Encode(map[string]string{"error": "Account not found"})
		return
	}
	acct := &accounts[idx]

	expiringSoon := acct.ExpiresAt > 0 && time.Now().Unix() > acct.ExpiresAt-60
	if expiringSoon {
		access, refresh, expires, refreshErr := auth.RefreshToken(acct)
		if refreshErr != nil {
			w.WriteHeader(500)
			json.NewEncoder(w).Encode(map[string]string{"error": "Token refresh failed: " + refreshErr.Error()})
			return
		}
		acct.AccessToken = access
		if refresh != "" {
			acct.RefreshToken = refresh
		}
		acct.ExpiresAt = expires
		config.UpdateAccountToken(id, access, refresh, expires)
		h.pool.UpdateToken(id, access, refresh, expires)
	}

	info, fetchErr := RefreshAccountInfo(acct)
	if fetchErr != nil {
		w.WriteHeader(500)
		json.NewEncoder(w).Encode(map[string]string{"error": fetchErr.Error()})
		return
	}

	if saveErr := config.UpdateAccountInfo(id, *info); saveErr != nil {
		w.WriteHeader(500)
		json.NewEncoder(w).Encode(map[string]string{"error": saveErr.Error()})
		return
	}

	json.NewEncoder(w).Encode(map[string]interface{}{
		"success": true,
		"info":    info,
	})
}
// apiGetAccountModels lists the models available to the account with the given id.
// It responds 404 for an unknown id, 500 if ListAvailableModels fails, and
// otherwise {"success": true, "models": ...}.
func (h *Handler) apiGetAccountModels(w http.ResponseWriter, r *http.Request, id string) {
	accounts := config.GetAccounts()
	idx := -1
	for i := 0; i < len(accounts) && idx < 0; i++ {
		if accounts[i].ID == id {
			idx = i
		}
	}
	if idx < 0 {
		w.WriteHeader(404)
		json.NewEncoder(w).Encode(map[string]string{"error": "Account not found"})
		return
	}

	models, listErr := ListAvailableModels(&accounts[idx])
	if listErr != nil {
		w.WriteHeader(500)
		json.NewEncoder(w).Encode(map[string]string{"error": listErr.Error()})
		return
	}

	json.NewEncoder(w).Encode(map[string]interface{}{
		"success": true,
		"models":  models,
	})
}
// ==================== 静态文件服务 ====================
// serveAdminPage serves the admin UI entry page from web/index.html
// (resolved relative to the process working directory).
func (h *Handler) serveAdminPage(w http.ResponseWriter, r *http.Request) {
	const indexFile = "web/index.html"
	http.ServeFile(w, r, indexFile)
}
// serveStaticFile serves an admin UI asset from the local web/ directory.
//
// The "/admin/" prefix is stripped from the URL path and the remainder is
// resolved under web/. http.ServeFile rejects requests whose URL path contains
// a ".." element, which guards against directory traversal through the URL.
func (h *Handler) serveStaticFile(w http.ResponseWriter, r *http.Request) {
	rel := strings.TrimPrefix(r.URL.Path, "/admin/")
	target := "web/" + rel
	http.ServeFile(w, r, target)
}