First commit

This commit is contained in:
shaw
2025-12-18 13:50:39 +08:00
parent 569f4882e5
commit 642842c29e
218 changed files with 86902 additions and 0 deletions

View File

@@ -0,0 +1,572 @@
package service
import (
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"os"
"path/filepath"
"strings"
"sync"
"time"
"sub2api/internal/config"
)
// LiteLLMModelPricing LiteLLM价格数据结构
// 只保留我们需要的字段,使用指针来处理可能缺失的值
type LiteLLMModelPricing struct {
InputCostPerToken float64 `json:"input_cost_per_token"`
OutputCostPerToken float64 `json:"output_cost_per_token"`
CacheCreationInputTokenCost float64 `json:"cache_creation_input_token_cost"`
CacheReadInputTokenCost float64 `json:"cache_read_input_token_cost"`
LiteLLMProvider string `json:"litellm_provider"`
Mode string `json:"mode"`
SupportsPromptCaching bool `json:"supports_prompt_caching"`
}
// LiteLLMRawEntry 用于解析原始JSON数据
type LiteLLMRawEntry struct {
InputCostPerToken *float64 `json:"input_cost_per_token"`
OutputCostPerToken *float64 `json:"output_cost_per_token"`
CacheCreationInputTokenCost *float64 `json:"cache_creation_input_token_cost"`
CacheReadInputTokenCost *float64 `json:"cache_read_input_token_cost"`
LiteLLMProvider string `json:"litellm_provider"`
Mode string `json:"mode"`
SupportsPromptCaching bool `json:"supports_prompt_caching"`
}
// PricingService 动态价格服务
type PricingService struct {
cfg *config.Config
mu sync.RWMutex
pricingData map[string]*LiteLLMModelPricing
lastUpdated time.Time
localHash string
// 停止信号
stopCh chan struct{}
wg sync.WaitGroup
}
// NewPricingService 创建价格服务
func NewPricingService(cfg *config.Config) *PricingService {
s := &PricingService{
cfg: cfg,
pricingData: make(map[string]*LiteLLMModelPricing),
stopCh: make(chan struct{}),
}
return s
}
// Initialize 初始化价格服务
func (s *PricingService) Initialize() error {
// 确保数据目录存在
if err := os.MkdirAll(s.cfg.Pricing.DataDir, 0755); err != nil {
log.Printf("[Pricing] Failed to create data directory: %v", err)
}
// 首次加载价格数据
if err := s.checkAndUpdatePricing(); err != nil {
log.Printf("[Pricing] Initial load failed, using fallback: %v", err)
if err := s.useFallbackPricing(); err != nil {
return fmt.Errorf("failed to load pricing data: %w", err)
}
}
// 启动定时更新
s.startUpdateScheduler()
log.Printf("[Pricing] Service initialized with %d models", len(s.pricingData))
return nil
}
// Stop 停止价格服务
func (s *PricingService) Stop() {
close(s.stopCh)
s.wg.Wait()
log.Println("[Pricing] Service stopped")
}
// startUpdateScheduler 启动定时更新调度器
func (s *PricingService) startUpdateScheduler() {
// 定期检查哈希更新
hashInterval := time.Duration(s.cfg.Pricing.HashCheckIntervalMinutes) * time.Minute
if hashInterval < time.Minute {
hashInterval = 10 * time.Minute
}
s.wg.Add(1)
go func() {
defer s.wg.Done()
ticker := time.NewTicker(hashInterval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
if err := s.syncWithRemote(); err != nil {
log.Printf("[Pricing] Sync failed: %v", err)
}
case <-s.stopCh:
return
}
}
}()
log.Printf("[Pricing] Update scheduler started (check every %v)", hashInterval)
}
// checkAndUpdatePricing 检查并更新价格数据
func (s *PricingService) checkAndUpdatePricing() error {
pricingFile := s.getPricingFilePath()
// 检查本地文件是否存在
if _, err := os.Stat(pricingFile); os.IsNotExist(err) {
log.Println("[Pricing] Local pricing file not found, downloading...")
return s.downloadPricingData()
}
// 检查文件是否过期
info, err := os.Stat(pricingFile)
if err != nil {
return s.downloadPricingData()
}
fileAge := time.Since(info.ModTime())
maxAge := time.Duration(s.cfg.Pricing.UpdateIntervalHours) * time.Hour
if fileAge > maxAge {
log.Printf("[Pricing] Local file is %v old, updating...", fileAge.Round(time.Hour))
if err := s.downloadPricingData(); err != nil {
log.Printf("[Pricing] Download failed, using existing file: %v", err)
}
}
// 加载本地文件
return s.loadPricingData(pricingFile)
}
// syncWithRemote 与远程同步(基于哈希校验)
func (s *PricingService) syncWithRemote() error {
pricingFile := s.getPricingFilePath()
// 计算本地文件哈希
localHash, err := s.computeFileHash(pricingFile)
if err != nil {
log.Printf("[Pricing] Failed to compute local hash: %v", err)
return s.downloadPricingData()
}
// 如果配置了哈希URL从远程获取哈希进行比对
if s.cfg.Pricing.HashURL != "" {
remoteHash, err := s.fetchRemoteHash()
if err != nil {
log.Printf("[Pricing] Failed to fetch remote hash: %v", err)
return nil // 哈希获取失败不影响正常使用
}
if remoteHash != localHash {
log.Println("[Pricing] Remote hash differs, downloading new version...")
return s.downloadPricingData()
}
log.Println("[Pricing] Hash check passed, no update needed")
return nil
}
// 没有哈希URL时基于时间检查
info, err := os.Stat(pricingFile)
if err != nil {
return s.downloadPricingData()
}
fileAge := time.Since(info.ModTime())
maxAge := time.Duration(s.cfg.Pricing.UpdateIntervalHours) * time.Hour
if fileAge > maxAge {
log.Printf("[Pricing] File is %v old, downloading...", fileAge.Round(time.Hour))
return s.downloadPricingData()
}
return nil
}
// downloadPricingData 从远程下载价格数据
func (s *PricingService) downloadPricingData() error {
log.Printf("[Pricing] Downloading from %s", s.cfg.Pricing.RemoteURL)
client := &http.Client{Timeout: 30 * time.Second}
resp, err := client.Get(s.cfg.Pricing.RemoteURL)
if err != nil {
return fmt.Errorf("download failed: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("download failed: HTTP %d", resp.StatusCode)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("read response failed: %w", err)
}
// 解析JSON数据使用灵活的解析方式
data, err := s.parsePricingData(body)
if err != nil {
return fmt.Errorf("parse pricing data: %w", err)
}
// 保存到本地文件
pricingFile := s.getPricingFilePath()
if err := os.WriteFile(pricingFile, body, 0644); err != nil {
log.Printf("[Pricing] Failed to save file: %v", err)
}
// 保存哈希
hash := sha256.Sum256(body)
hashStr := hex.EncodeToString(hash[:])
hashFile := s.getHashFilePath()
if err := os.WriteFile(hashFile, []byte(hashStr+"\n"), 0644); err != nil {
log.Printf("[Pricing] Failed to save hash: %v", err)
}
// 更新内存数据
s.mu.Lock()
s.pricingData = data
s.lastUpdated = time.Now()
s.localHash = hashStr
s.mu.Unlock()
log.Printf("[Pricing] Downloaded %d models successfully", len(data))
return nil
}
// parsePricingData 解析价格数据(处理各种格式)
func (s *PricingService) parsePricingData(body []byte) (map[string]*LiteLLMModelPricing, error) {
// 首先解析为 map[string]json.RawMessage
var rawData map[string]json.RawMessage
if err := json.Unmarshal(body, &rawData); err != nil {
return nil, fmt.Errorf("parse raw JSON: %w", err)
}
result := make(map[string]*LiteLLMModelPricing)
skipped := 0
for modelName, rawEntry := range rawData {
// 跳过 sample_spec 等文档条目
if modelName == "sample_spec" {
continue
}
// 尝试解析每个条目
var entry LiteLLMRawEntry
if err := json.Unmarshal(rawEntry, &entry); err != nil {
skipped++
continue
}
// 只保留有有效价格的条目
if entry.InputCostPerToken == nil && entry.OutputCostPerToken == nil {
continue
}
pricing := &LiteLLMModelPricing{
LiteLLMProvider: entry.LiteLLMProvider,
Mode: entry.Mode,
SupportsPromptCaching: entry.SupportsPromptCaching,
}
if entry.InputCostPerToken != nil {
pricing.InputCostPerToken = *entry.InputCostPerToken
}
if entry.OutputCostPerToken != nil {
pricing.OutputCostPerToken = *entry.OutputCostPerToken
}
if entry.CacheCreationInputTokenCost != nil {
pricing.CacheCreationInputTokenCost = *entry.CacheCreationInputTokenCost
}
if entry.CacheReadInputTokenCost != nil {
pricing.CacheReadInputTokenCost = *entry.CacheReadInputTokenCost
}
result[modelName] = pricing
}
if skipped > 0 {
log.Printf("[Pricing] Skipped %d invalid entries", skipped)
}
if len(result) == 0 {
return nil, fmt.Errorf("no valid pricing entries found")
}
return result, nil
}
// loadPricingData 从本地文件加载价格数据
func (s *PricingService) loadPricingData(filePath string) error {
data, err := os.ReadFile(filePath)
if err != nil {
return fmt.Errorf("read file failed: %w", err)
}
// 使用灵活的解析方式
pricingData, err := s.parsePricingData(data)
if err != nil {
return fmt.Errorf("parse pricing data: %w", err)
}
// 计算哈希
hash := sha256.Sum256(data)
hashStr := hex.EncodeToString(hash[:])
s.mu.Lock()
s.pricingData = pricingData
s.localHash = hashStr
info, _ := os.Stat(filePath)
if info != nil {
s.lastUpdated = info.ModTime()
} else {
s.lastUpdated = time.Now()
}
s.mu.Unlock()
log.Printf("[Pricing] Loaded %d models from %s", len(pricingData), filePath)
return nil
}
// useFallbackPricing 使用回退价格文件
func (s *PricingService) useFallbackPricing() error {
fallbackFile := s.cfg.Pricing.FallbackFile
if _, err := os.Stat(fallbackFile); os.IsNotExist(err) {
return fmt.Errorf("fallback file not found: %s", fallbackFile)
}
log.Printf("[Pricing] Using fallback file: %s", fallbackFile)
// 复制到数据目录
data, err := os.ReadFile(fallbackFile)
if err != nil {
return fmt.Errorf("read fallback failed: %w", err)
}
pricingFile := s.getPricingFilePath()
if err := os.WriteFile(pricingFile, data, 0644); err != nil {
log.Printf("[Pricing] Failed to copy fallback: %v", err)
}
return s.loadPricingData(fallbackFile)
}
// fetchRemoteHash 从远程获取哈希值
func (s *PricingService) fetchRemoteHash() (string, error) {
client := &http.Client{Timeout: 10 * time.Second}
resp, err := client.Get(s.cfg.Pricing.HashURL)
if err != nil {
return "", err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("HTTP %d", resp.StatusCode)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
return "", err
}
// 哈希文件格式hash filename 或者纯 hash
hash := strings.TrimSpace(string(body))
parts := strings.Fields(hash)
if len(parts) > 0 {
return parts[0], nil
}
return hash, nil
}
// computeFileHash 计算文件哈希
func (s *PricingService) computeFileHash(filePath string) (string, error) {
data, err := os.ReadFile(filePath)
if err != nil {
return "", err
}
hash := sha256.Sum256(data)
return hex.EncodeToString(hash[:]), nil
}
// GetModelPricing 获取模型价格(带模糊匹配)
func (s *PricingService) GetModelPricing(modelName string) *LiteLLMModelPricing {
s.mu.RLock()
defer s.mu.RUnlock()
if modelName == "" {
return nil
}
// 标准化模型名称
modelLower := strings.ToLower(modelName)
// 1. 精确匹配
if pricing, ok := s.pricingData[modelLower]; ok {
return pricing
}
if pricing, ok := s.pricingData[modelName]; ok {
return pricing
}
// 2. 处理常见的模型名称变体
// claude-opus-4-5-20251101 -> claude-opus-4.5-20251101
normalized := strings.ReplaceAll(modelLower, "-4-5-", "-4.5-")
if pricing, ok := s.pricingData[normalized]; ok {
return pricing
}
// 3. 尝试模糊匹配(去掉版本号后缀)
// claude-opus-4-5-20251101 -> claude-opus-4.5
baseName := s.extractBaseName(modelLower)
for key, pricing := range s.pricingData {
keyBase := s.extractBaseName(strings.ToLower(key))
if keyBase == baseName {
return pricing
}
}
// 4. 基于模型系列匹配
return s.matchByModelFamily(modelLower)
}
// extractBaseName 提取基础模型名称(去掉日期版本号)
func (s *PricingService) extractBaseName(model string) string {
// 移除日期后缀 (如 -20251101, -20241022)
parts := strings.Split(model, "-")
result := make([]string, 0, len(parts))
for _, part := range parts {
// 跳过看起来像日期的部分8位数字
if len(part) == 8 && isNumeric(part) {
continue
}
// 跳过版本号(如 v1:0
if strings.Contains(part, ":") {
continue
}
result = append(result, part)
}
return strings.Join(result, "-")
}
// matchByModelFamily 基于模型系列匹配
func (s *PricingService) matchByModelFamily(model string) *LiteLLMModelPricing {
// Claude模型系列匹配规则
familyPatterns := map[string][]string{
"opus-4.5": {"claude-opus-4.5", "claude-opus-4-5"},
"opus-4": {"claude-opus-4", "claude-3-opus"},
"sonnet-4.5": {"claude-sonnet-4.5", "claude-sonnet-4-5"},
"sonnet-4": {"claude-sonnet-4", "claude-3-5-sonnet"},
"sonnet-3.5": {"claude-3-5-sonnet", "claude-3.5-sonnet"},
"sonnet-3": {"claude-3-sonnet"},
"haiku-3.5": {"claude-3-5-haiku", "claude-3.5-haiku"},
"haiku-3": {"claude-3-haiku"},
}
// 确定模型属于哪个系列
var matchedFamily string
for family, patterns := range familyPatterns {
for _, pattern := range patterns {
if strings.Contains(model, pattern) || strings.Contains(model, strings.ReplaceAll(pattern, "-", "")) {
matchedFamily = family
break
}
}
if matchedFamily != "" {
break
}
}
if matchedFamily == "" {
// 简单的系列匹配
if strings.Contains(model, "opus") {
if strings.Contains(model, "4.5") || strings.Contains(model, "4-5") {
matchedFamily = "opus-4.5"
} else {
matchedFamily = "opus-4"
}
} else if strings.Contains(model, "sonnet") {
if strings.Contains(model, "4.5") || strings.Contains(model, "4-5") {
matchedFamily = "sonnet-4.5"
} else if strings.Contains(model, "3-5") || strings.Contains(model, "3.5") {
matchedFamily = "sonnet-3.5"
} else {
matchedFamily = "sonnet-4"
}
} else if strings.Contains(model, "haiku") {
if strings.Contains(model, "3-5") || strings.Contains(model, "3.5") {
matchedFamily = "haiku-3.5"
} else {
matchedFamily = "haiku-3"
}
}
}
if matchedFamily == "" {
return nil
}
// 在价格数据中查找该系列的模型
patterns := familyPatterns[matchedFamily]
for _, pattern := range patterns {
for key, pricing := range s.pricingData {
keyLower := strings.ToLower(key)
if strings.Contains(keyLower, pattern) {
log.Printf("[Pricing] Fuzzy matched %s -> %s", model, key)
return pricing
}
}
}
return nil
}
// GetStatus 获取服务状态
func (s *PricingService) GetStatus() map[string]interface{} {
s.mu.RLock()
defer s.mu.RUnlock()
return map[string]interface{}{
"model_count": len(s.pricingData),
"last_updated": s.lastUpdated,
"local_hash": s.localHash[:min(8, len(s.localHash))],
}
}
// ForceUpdate 强制更新
func (s *PricingService) ForceUpdate() error {
return s.downloadPricingData()
}
// getPricingFilePath 获取价格文件路径
func (s *PricingService) getPricingFilePath() string {
return filepath.Join(s.cfg.Pricing.DataDir, "model_pricing.json")
}
// getHashFilePath 获取哈希文件路径
func (s *PricingService) getHashFilePath() string {
return filepath.Join(s.cfg.Pricing.DataDir, "model_pricing.sha256")
}
// isNumeric 检查字符串是否为纯数字
func isNumeric(s string) bool {
for _, c := range s {
if c < '0' || c > '9' {
return false
}
}
return true
}