package service import ( "context" "crypto/sha256" "encoding/hex" "encoding/json" "fmt" "log" "os" "path/filepath" "regexp" "strings" "sync" "time" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/Wei-Shaw/sub2api/internal/util/urlvalidator" ) var ( openAIModelDatePattern = regexp.MustCompile(`-\d{8}$`) openAIModelBasePattern = regexp.MustCompile(`^(gpt-\d+(?:\.\d+)?)(?:-|$)`) ) // LiteLLMModelPricing LiteLLM价格数据结构 // 只保留我们需要的字段,使用指针来处理可能缺失的值 type LiteLLMModelPricing struct { InputCostPerToken float64 `json:"input_cost_per_token"` OutputCostPerToken float64 `json:"output_cost_per_token"` CacheCreationInputTokenCost float64 `json:"cache_creation_input_token_cost"` CacheReadInputTokenCost float64 `json:"cache_read_input_token_cost"` LiteLLMProvider string `json:"litellm_provider"` Mode string `json:"mode"` SupportsPromptCaching bool `json:"supports_prompt_caching"` OutputCostPerImage float64 `json:"output_cost_per_image"` // 图片生成模型每张图片价格 } // PricingRemoteClient 远程价格数据获取接口 type PricingRemoteClient interface { FetchPricingJSON(ctx context.Context, url string) ([]byte, error) FetchHashText(ctx context.Context, url string) (string, error) } // LiteLLMRawEntry 用于解析原始JSON数据 type LiteLLMRawEntry struct { InputCostPerToken *float64 `json:"input_cost_per_token"` OutputCostPerToken *float64 `json:"output_cost_per_token"` CacheCreationInputTokenCost *float64 `json:"cache_creation_input_token_cost"` CacheReadInputTokenCost *float64 `json:"cache_read_input_token_cost"` LiteLLMProvider string `json:"litellm_provider"` Mode string `json:"mode"` SupportsPromptCaching bool `json:"supports_prompt_caching"` OutputCostPerImage *float64 `json:"output_cost_per_image"` } // PricingService 动态价格服务 type PricingService struct { cfg *config.Config remoteClient PricingRemoteClient mu sync.RWMutex pricingData map[string]*LiteLLMModelPricing lastUpdated time.Time localHash string // 停止信号 stopCh chan struct{} wg sync.WaitGroup } // NewPricingService 创建价格服务 func NewPricingService(cfg *config.Config, remoteClient PricingRemoteClient) *PricingService { s := &PricingService{ cfg: cfg, remoteClient: remoteClient, pricingData: make(map[string]*LiteLLMModelPricing), stopCh: make(chan struct{}), } return s } // Initialize 初始化价格服务 func (s *PricingService) Initialize() error { // 确保数据目录存在 if err := os.MkdirAll(s.cfg.Pricing.DataDir, 0755); err != nil { log.Printf("[Pricing] Failed to create data directory: %v", err) } // 首次加载价格数据 if err := s.checkAndUpdatePricing(); err != nil { log.Printf("[Pricing] Initial load failed, using fallback: %v", err) if err := s.useFallbackPricing(); err != nil { return fmt.Errorf("failed to load pricing data: %w", err) } } // 启动定时更新 s.startUpdateScheduler() log.Printf("[Pricing] Service initialized with %d models", len(s.pricingData)) return nil } // Stop 停止价格服务 func (s *PricingService) Stop() { close(s.stopCh) s.wg.Wait() log.Println("[Pricing] Service stopped") } // startUpdateScheduler 启动定时更新调度器 func (s *PricingService) startUpdateScheduler() { // 定期检查哈希更新 hashInterval := time.Duration(s.cfg.Pricing.HashCheckIntervalMinutes) * time.Minute if hashInterval < time.Minute { hashInterval = 10 * time.Minute } s.wg.Add(1) go func() { defer s.wg.Done() ticker := time.NewTicker(hashInterval) defer ticker.Stop() for { select { case <-ticker.C: if err := s.syncWithRemote(); err != nil { log.Printf("[Pricing] Sync failed: %v", err) } case <-s.stopCh: return } } }() log.Printf("[Pricing] Update scheduler started (check every %v)", hashInterval) } // checkAndUpdatePricing 检查并更新价格数据 func (s *PricingService) checkAndUpdatePricing() error { pricingFile := s.getPricingFilePath() // 检查本地文件是否存在 if _, err := os.Stat(pricingFile); os.IsNotExist(err) { log.Println("[Pricing] Local pricing file not found, downloading...") return s.downloadPricingData() } // 检查文件是否过期 info, err := os.Stat(pricingFile) if err != nil { return s.downloadPricingData() } fileAge := time.Since(info.ModTime()) maxAge := time.Duration(s.cfg.Pricing.UpdateIntervalHours) * time.Hour if fileAge > maxAge { log.Printf("[Pricing] Local file is %v old, updating...", fileAge.Round(time.Hour)) if err := s.downloadPricingData(); err != nil { log.Printf("[Pricing] Download failed, using existing file: %v", err) } } // 加载本地文件 return s.loadPricingData(pricingFile) } // syncWithRemote 与远程同步(基于哈希校验) func (s *PricingService) syncWithRemote() error { pricingFile := s.getPricingFilePath() // 计算本地文件哈希 localHash, err := s.computeFileHash(pricingFile) if err != nil { log.Printf("[Pricing] Failed to compute local hash: %v", err) return s.downloadPricingData() } // 如果配置了哈希URL,从远程获取哈希进行比对 if s.cfg.Pricing.HashURL != "" { remoteHash, err := s.fetchRemoteHash() if err != nil { log.Printf("[Pricing] Failed to fetch remote hash: %v", err) return nil // 哈希获取失败不影响正常使用 } if remoteHash != localHash { log.Println("[Pricing] Remote hash differs, downloading new version...") return s.downloadPricingData() } log.Println("[Pricing] Hash check passed, no update needed") return nil } // 没有哈希URL时,基于时间检查 info, err := os.Stat(pricingFile) if err != nil { return s.downloadPricingData() } fileAge := time.Since(info.ModTime()) maxAge := time.Duration(s.cfg.Pricing.UpdateIntervalHours) * time.Hour if fileAge > maxAge { log.Printf("[Pricing] File is %v old, downloading...", fileAge.Round(time.Hour)) return s.downloadPricingData() } return nil } // downloadPricingData 从远程下载价格数据 func (s *PricingService) downloadPricingData() error { remoteURL, err := s.validatePricingURL(s.cfg.Pricing.RemoteURL) if err != nil { return err } log.Printf("[Pricing] Downloading from %s", remoteURL) ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() var expectedHash string if strings.TrimSpace(s.cfg.Pricing.HashURL) != "" { expectedHash, err = s.fetchRemoteHash() if err != nil { return fmt.Errorf("fetch remote hash: %w", err) } } body, err := s.remoteClient.FetchPricingJSON(ctx, remoteURL) if err != nil { return fmt.Errorf("download failed: %w", err) } if expectedHash != "" { actualHash := sha256.Sum256(body) if !strings.EqualFold(expectedHash, hex.EncodeToString(actualHash[:])) { return fmt.Errorf("pricing hash mismatch") } } // 解析JSON数据(使用灵活的解析方式) data, err := s.parsePricingData(body) if err != nil { return fmt.Errorf("parse pricing data: %w", err) } // 保存到本地文件 pricingFile := s.getPricingFilePath() if err := os.WriteFile(pricingFile, body, 0644); err != nil { log.Printf("[Pricing] Failed to save file: %v", err) } // 保存哈希 hash := sha256.Sum256(body) hashStr := hex.EncodeToString(hash[:]) hashFile := s.getHashFilePath() if err := os.WriteFile(hashFile, []byte(hashStr+"\n"), 0644); err != nil { log.Printf("[Pricing] Failed to save hash: %v", err) } // 更新内存数据 s.mu.Lock() s.pricingData = data s.lastUpdated = time.Now() s.localHash = hashStr s.mu.Unlock() log.Printf("[Pricing] Downloaded %d models successfully", len(data)) return nil } // parsePricingData 解析价格数据(处理各种格式) func (s *PricingService) parsePricingData(body []byte) (map[string]*LiteLLMModelPricing, error) { // 首先解析为 map[string]json.RawMessage var rawData map[string]json.RawMessage if err := json.Unmarshal(body, &rawData); err != nil { return nil, fmt.Errorf("parse raw JSON: %w", err) } result := make(map[string]*LiteLLMModelPricing) skipped := 0 for modelName, rawEntry := range rawData { // 跳过 sample_spec 等文档条目 if modelName == "sample_spec" { continue } // 尝试解析每个条目 var entry LiteLLMRawEntry if err := json.Unmarshal(rawEntry, &entry); err != nil { skipped++ continue } // 只保留有有效价格的条目 if entry.InputCostPerToken == nil && entry.OutputCostPerToken == nil { continue } pricing := &LiteLLMModelPricing{ LiteLLMProvider: entry.LiteLLMProvider, Mode: entry.Mode, SupportsPromptCaching: entry.SupportsPromptCaching, } if entry.InputCostPerToken != nil { pricing.InputCostPerToken = *entry.InputCostPerToken } if entry.OutputCostPerToken != nil { pricing.OutputCostPerToken = *entry.OutputCostPerToken } if entry.CacheCreationInputTokenCost != nil { pricing.CacheCreationInputTokenCost = *entry.CacheCreationInputTokenCost } if entry.CacheReadInputTokenCost != nil { pricing.CacheReadInputTokenCost = *entry.CacheReadInputTokenCost } if entry.OutputCostPerImage != nil { pricing.OutputCostPerImage = *entry.OutputCostPerImage } result[modelName] = pricing } if skipped > 0 { log.Printf("[Pricing] Skipped %d invalid entries", skipped) } if len(result) == 0 { return nil, fmt.Errorf("no valid pricing entries found") } return result, nil } // loadPricingData 从本地文件加载价格数据 func (s *PricingService) loadPricingData(filePath string) error { data, err := os.ReadFile(filePath) if err != nil { return fmt.Errorf("read file failed: %w", err) } // 使用灵活的解析方式 pricingData, err := s.parsePricingData(data) if err != nil { return fmt.Errorf("parse pricing data: %w", err) } // 计算哈希 hash := sha256.Sum256(data) hashStr := hex.EncodeToString(hash[:]) s.mu.Lock() s.pricingData = pricingData s.localHash = hashStr info, _ := os.Stat(filePath) if info != nil { s.lastUpdated = info.ModTime() } else { s.lastUpdated = time.Now() } s.mu.Unlock() log.Printf("[Pricing] Loaded %d models from %s", len(pricingData), filePath) return nil } // useFallbackPricing 使用回退价格文件 func (s *PricingService) useFallbackPricing() error { fallbackFile := s.cfg.Pricing.FallbackFile if _, err := os.Stat(fallbackFile); os.IsNotExist(err) { return fmt.Errorf("fallback file not found: %s", fallbackFile) } log.Printf("[Pricing] Using fallback file: %s", fallbackFile) // 复制到数据目录 data, err := os.ReadFile(fallbackFile) if err != nil { return fmt.Errorf("read fallback failed: %w", err) } pricingFile := s.getPricingFilePath() if err := os.WriteFile(pricingFile, data, 0644); err != nil { log.Printf("[Pricing] Failed to copy fallback: %v", err) } return s.loadPricingData(fallbackFile) } // fetchRemoteHash 从远程获取哈希值 func (s *PricingService) fetchRemoteHash() (string, error) { hashURL, err := s.validatePricingURL(s.cfg.Pricing.HashURL) if err != nil { return "", err } ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() hash, err := s.remoteClient.FetchHashText(ctx, hashURL) if err != nil { return "", err } return strings.TrimSpace(hash), nil } func (s *PricingService) validatePricingURL(raw string) (string, error) { if s.cfg != nil && !s.cfg.Security.URLAllowlist.Enabled { normalized, err := urlvalidator.ValidateURLFormat(raw, s.cfg.Security.URLAllowlist.AllowInsecureHTTP) if err != nil { return "", fmt.Errorf("invalid pricing url: %w", err) } return normalized, nil } normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{ AllowedHosts: s.cfg.Security.URLAllowlist.PricingHosts, RequireAllowlist: true, AllowPrivate: s.cfg.Security.URLAllowlist.AllowPrivateHosts, }) if err != nil { return "", fmt.Errorf("invalid pricing url: %w", err) } return normalized, nil } // computeFileHash 计算文件哈希 func (s *PricingService) computeFileHash(filePath string) (string, error) { data, err := os.ReadFile(filePath) if err != nil { return "", err } hash := sha256.Sum256(data) return hex.EncodeToString(hash[:]), nil } // GetModelPricing 获取模型价格(带模糊匹配) func (s *PricingService) GetModelPricing(modelName string) *LiteLLMModelPricing { s.mu.RLock() defer s.mu.RUnlock() if modelName == "" { return nil } // 标准化模型名称(同时兼容 "models/xxx"、VertexAI 资源名等前缀) modelLower := strings.ToLower(strings.TrimSpace(modelName)) lookupCandidates := s.buildModelLookupCandidates(modelLower) // 1. 精确匹配 for _, candidate := range lookupCandidates { if candidate == "" { continue } if pricing, ok := s.pricingData[candidate]; ok { return pricing } } // 2. 处理常见的模型名称变体 // claude-opus-4-5-20251101 -> claude-opus-4.5-20251101 for _, candidate := range lookupCandidates { normalized := strings.ReplaceAll(candidate, "-4-5-", "-4.5-") if pricing, ok := s.pricingData[normalized]; ok { return pricing } } // 3. 尝试模糊匹配(去掉版本号后缀) // claude-opus-4-5-20251101 -> claude-opus-4.5 baseName := s.extractBaseName(lookupCandidates[0]) for key, pricing := range s.pricingData { keyBase := s.extractBaseName(strings.ToLower(key)) if keyBase == baseName { return pricing } } // 4. 基于模型系列匹配(Claude) if pricing := s.matchByModelFamily(lookupCandidates[0]); pricing != nil { return pricing } // 5. OpenAI 模型回退策略 if strings.HasPrefix(lookupCandidates[0], "gpt-") { return s.matchOpenAIModel(lookupCandidates[0]) } return nil } func (s *PricingService) buildModelLookupCandidates(modelLower string) []string { // Prefer canonical model name first (this also improves billing compatibility with "models/xxx"). candidates := []string{ normalizeModelNameForPricing(modelLower), modelLower, } candidates = append(candidates, strings.TrimPrefix(modelLower, "models/"), lastSegment(modelLower), lastSegment(strings.TrimPrefix(modelLower, "models/")), ) seen := make(map[string]struct{}, len(candidates)) out := make([]string, 0, len(candidates)) for _, c := range candidates { c = strings.TrimSpace(c) if c == "" { continue } if _, ok := seen[c]; ok { continue } seen[c] = struct{}{} out = append(out, c) } if len(out) == 0 { return []string{modelLower} } return out } func normalizeModelNameForPricing(model string) string { // Common Gemini/VertexAI forms: // - models/gemini-2.0-flash-exp // - publishers/google/models/gemini-2.5-pro // - projects/.../locations/.../publishers/google/models/gemini-2.5-pro model = strings.TrimSpace(model) model = strings.TrimLeft(model, "/") model = strings.TrimPrefix(model, "models/") model = strings.TrimPrefix(model, "publishers/google/models/") if idx := strings.LastIndex(model, "/publishers/google/models/"); idx != -1 { model = model[idx+len("/publishers/google/models/"):] } if idx := strings.LastIndex(model, "/models/"); idx != -1 { model = model[idx+len("/models/"):] } model = strings.TrimLeft(model, "/") return model } func lastSegment(model string) string { if idx := strings.LastIndex(model, "/"); idx != -1 { return model[idx+1:] } return model } // extractBaseName 提取基础模型名称(去掉日期版本号) func (s *PricingService) extractBaseName(model string) string { // 移除日期后缀 (如 -20251101, -20241022) parts := strings.Split(model, "-") result := make([]string, 0, len(parts)) for _, part := range parts { // 跳过看起来像日期的部分(8位数字) if len(part) == 8 && isNumeric(part) { continue } // 跳过版本号(如 v1:0) if strings.Contains(part, ":") { continue } result = append(result, part) } return strings.Join(result, "-") } // matchByModelFamily 基于模型系列匹配 func (s *PricingService) matchByModelFamily(model string) *LiteLLMModelPricing { // Claude模型系列匹配规则 familyPatterns := map[string][]string{ "opus-4.6": {"claude-opus-4.6", "claude-opus-4-6"}, "opus-4.5": {"claude-opus-4.5", "claude-opus-4-5"}, "opus-4": {"claude-opus-4", "claude-3-opus"}, "sonnet-4.5": {"claude-sonnet-4.5", "claude-sonnet-4-5"}, "sonnet-4": {"claude-sonnet-4", "claude-3-5-sonnet"}, "sonnet-3.5": {"claude-3-5-sonnet", "claude-3.5-sonnet"}, "sonnet-3": {"claude-3-sonnet"}, "haiku-3.5": {"claude-3-5-haiku", "claude-3.5-haiku"}, "haiku-3": {"claude-3-haiku"}, } // 确定模型属于哪个系列 var matchedFamily string for family, patterns := range familyPatterns { for _, pattern := range patterns { if strings.Contains(model, pattern) || strings.Contains(model, strings.ReplaceAll(pattern, "-", "")) { matchedFamily = family break } } if matchedFamily != "" { break } } if matchedFamily == "" { // 简单的系列匹配 if strings.Contains(model, "opus") { if strings.Contains(model, "4.5") || strings.Contains(model, "4-5") { matchedFamily = "opus-4.5" } else { matchedFamily = "opus-4" } } else if strings.Contains(model, "sonnet") { if strings.Contains(model, "4.5") || strings.Contains(model, "4-5") { matchedFamily = "sonnet-4.5" } else if strings.Contains(model, "3-5") || strings.Contains(model, "3.5") { matchedFamily = "sonnet-3.5" } else { matchedFamily = "sonnet-4" } } else if strings.Contains(model, "haiku") { if strings.Contains(model, "3-5") || strings.Contains(model, "3.5") { matchedFamily = "haiku-3.5" } else { matchedFamily = "haiku-3" } } } if matchedFamily == "" { return nil } // 在价格数据中查找该系列的模型 patterns := familyPatterns[matchedFamily] for _, pattern := range patterns { for key, pricing := range s.pricingData { keyLower := strings.ToLower(key) if strings.Contains(keyLower, pattern) { log.Printf("[Pricing] Fuzzy matched %s -> %s", model, key) return pricing } } } return nil } // matchOpenAIModel OpenAI 模型回退匹配策略 // 回退顺序: // 1. gpt-5.2-codex -> gpt-5.2(去掉后缀如 -codex, -mini, -max 等) // 2. gpt-5.2-20251222 -> gpt-5.2(去掉日期版本号) // 3. gpt-5.3-codex -> gpt-5.2-codex // 4. 最终回退到 DefaultTestModel (gpt-5.1-codex) func (s *PricingService) matchOpenAIModel(model string) *LiteLLMModelPricing { // 尝试的回退变体 variants := s.generateOpenAIModelVariants(model, openAIModelDatePattern) for _, variant := range variants { if pricing, ok := s.pricingData[variant]; ok { log.Printf("[Pricing] OpenAI fallback matched %s -> %s", model, variant) return pricing } } if strings.HasPrefix(model, "gpt-5.3-codex") { if pricing, ok := s.pricingData["gpt-5.2-codex"]; ok { log.Printf("[Pricing] OpenAI fallback matched %s -> %s", model, "gpt-5.2-codex") return pricing } } // 最终回退到 DefaultTestModel defaultModel := strings.ToLower(openai.DefaultTestModel) if pricing, ok := s.pricingData[defaultModel]; ok { log.Printf("[Pricing] OpenAI fallback to default model %s -> %s", model, defaultModel) return pricing } return nil } // generateOpenAIModelVariants 生成 OpenAI 模型的回退变体列表 func (s *PricingService) generateOpenAIModelVariants(model string, datePattern *regexp.Regexp) []string { seen := make(map[string]bool) var variants []string addVariant := func(v string) { if v != model && !seen[v] { seen[v] = true variants = append(variants, v) } } // 1. 去掉日期版本号: gpt-5.2-20251222 -> gpt-5.2 withoutDate := datePattern.ReplaceAllString(model, "") if withoutDate != model { addVariant(withoutDate) } // 2. 提取基础版本号: gpt-5.2-codex -> gpt-5.2 // 只匹配纯数字版本号格式 gpt-X 或 gpt-X.Y,不匹配 gpt-4o 这种带字母后缀的 if matches := openAIModelBasePattern.FindStringSubmatch(model); len(matches) > 1 { addVariant(matches[1]) } // 3. 同时去掉日期后再提取基础版本号 if withoutDate != model { if matches := openAIModelBasePattern.FindStringSubmatch(withoutDate); len(matches) > 1 { addVariant(matches[1]) } } return variants } // GetStatus 获取服务状态 func (s *PricingService) GetStatus() map[string]any { s.mu.RLock() defer s.mu.RUnlock() return map[string]any{ "model_count": len(s.pricingData), "last_updated": s.lastUpdated, "local_hash": s.localHash[:min(8, len(s.localHash))], } } // ForceUpdate 强制更新 func (s *PricingService) ForceUpdate() error { return s.downloadPricingData() } // getPricingFilePath 获取价格文件路径 func (s *PricingService) getPricingFilePath() string { return filepath.Join(s.cfg.Pricing.DataDir, "model_pricing.json") } // getHashFilePath 获取哈希文件路径 func (s *PricingService) getHashFilePath() string { return filepath.Join(s.cfg.Pricing.DataDir, "model_pricing.sha256") } // isNumeric 检查字符串是否为纯数字 func isNumeric(s string) bool { for _, c := range s { if c < '0' || c > '9' { return false } } return true }