diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index d1cb76db..3ee5d6cd 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -1281,8 +1281,8 @@ func setDefaults() { viper.SetDefault("rate_limit.oauth_401_cooldown_minutes", 10) // Pricing - 从 model-price-repo 同步模型定价和上下文窗口数据(固定到 commit,避免分支漂移) - viper.SetDefault("pricing.remote_url", "https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/c7947e9871687e664180bc971d4837f1fc2784a9/model_prices_and_context_window.json") - viper.SetDefault("pricing.hash_url", "https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/c7947e9871687e664180bc971d4837f1fc2784a9/model_prices_and_context_window.sha256") + viper.SetDefault("pricing.remote_url", "https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/main/model_prices_and_context_window.json") + viper.SetDefault("pricing.hash_url", "https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/main/model_prices_and_context_window.sha256") viper.SetDefault("pricing.data_dir", "./data") viper.SetDefault("pricing.fallback_file", "./resources/model-pricing/model_prices_and_context_window.json") viper.SetDefault("pricing.update_interval_hours", 24) diff --git a/backend/internal/service/pricing_service.go b/backend/internal/service/pricing_service.go index 10440c60..5623d4b7 100644 --- a/backend/internal/service/pricing_service.go +++ b/backend/internal/service/pricing_service.go @@ -189,10 +189,38 @@ func (s *PricingService) checkAndUpdatePricing() error { return s.downloadPricingData() } - // 检查文件是否过期 + // 先加载本地文件(确保服务可用),再检查是否需要更新 + if err := s.loadPricingData(pricingFile); err != nil { + logger.LegacyPrintf("service.pricing", "[Pricing] Failed to load local file, downloading: %v", err) + return s.downloadPricingData() + } + + // 如果配置了哈希URL,通过远程哈希检查是否有更新 + if s.cfg.Pricing.HashURL != "" { + remoteHash, err := s.fetchRemoteHash() + if err != nil { + logger.LegacyPrintf("service.pricing", "[Pricing] Failed to fetch remote hash on startup: %v", err) + return nil // 已加载本地文件,哈希获取失败不影响启动 + } + + s.mu.RLock() + localHash := s.localHash + s.mu.RUnlock() + + if localHash == "" || remoteHash != localHash { + logger.LegacyPrintf("service.pricing", "[Pricing] Remote hash differs on startup (local=%s remote=%s), downloading...", + localHash[:min(8, len(localHash))], remoteHash[:min(8, len(remoteHash))]) + if err := s.downloadPricingData(); err != nil { + logger.LegacyPrintf("service.pricing", "[Pricing] Download failed, using existing file: %v", err) + } + } + return nil + } + + // 没有哈希URL时,基于文件年龄检查 info, err := os.Stat(pricingFile) if err != nil { - return s.downloadPricingData() + return nil // 已加载本地文件 } fileAge := time.Since(info.ModTime()) @@ -205,21 +233,11 @@ func (s *PricingService) checkAndUpdatePricing() error { } } - // 加载本地文件 - return s.loadPricingData(pricingFile) + return nil } // syncWithRemote 与远程同步(基于哈希校验) func (s *PricingService) syncWithRemote() error { - pricingFile := s.getPricingFilePath() - - // 计算本地文件哈希 - localHash, err := s.computeFileHash(pricingFile) - if err != nil { - logger.LegacyPrintf("service.pricing", "[Pricing] Failed to compute local hash: %v", err) - return s.downloadPricingData() - } - // 如果配置了哈希URL,从远程获取哈希进行比对 if s.cfg.Pricing.HashURL != "" { remoteHash, err := s.fetchRemoteHash() @@ -228,8 +246,13 @@ func (s *PricingService) syncWithRemote() error { return nil // 哈希获取失败不影响正常使用 } - if remoteHash != localHash { - logger.LegacyPrintf("service.pricing", "%s", "[Pricing] Remote hash differs, downloading new version...") + s.mu.RLock() + localHash := s.localHash + s.mu.RUnlock() + + if localHash == "" || remoteHash != localHash { + logger.LegacyPrintf("service.pricing", "[Pricing] Remote hash differs (local=%s remote=%s), downloading new version...", + localHash[:min(8, len(localHash))], remoteHash[:min(8, len(remoteHash))]) return s.downloadPricingData() } logger.LegacyPrintf("service.pricing", "%s", "[Pricing] Hash check passed, no update needed") @@ -237,6 +260,7 @@ func (s *PricingService) syncWithRemote() error { } // 没有哈希URL时,基于时间检查 + pricingFile := s.getPricingFilePath() info, err := os.Stat(pricingFile) if err != nil { return s.downloadPricingData() @@ -264,11 +288,12 @@ func (s *PricingService) downloadPricingData() error { ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() - var expectedHash string + // 获取远程哈希(用于同步锚点,不作为完整性校验) + var remoteHash string if strings.TrimSpace(s.cfg.Pricing.HashURL) != "" { - expectedHash, err = s.fetchRemoteHash() + remoteHash, err = s.fetchRemoteHash() if err != nil { - return fmt.Errorf("fetch remote hash: %w", err) + logger.LegacyPrintf("service.pricing", "[Pricing] Failed to fetch remote hash (continuing): %v", err) } } @@ -277,11 +302,13 @@ func (s *PricingService) downloadPricingData() error { return fmt.Errorf("download failed: %w", err) } - if expectedHash != "" { - actualHash := sha256.Sum256(body) - if !strings.EqualFold(expectedHash, hex.EncodeToString(actualHash[:])) { - return fmt.Errorf("pricing hash mismatch") - } + // 哈希校验:不匹配时仅告警,不阻止更新 + // 远程哈希文件可能与数据文件不同步(如维护者更新了数据但未更新哈希文件) + dataHash := sha256.Sum256(body) + dataHashStr := hex.EncodeToString(dataHash[:]) + if remoteHash != "" && !strings.EqualFold(remoteHash, dataHashStr) { + logger.LegacyPrintf("service.pricing", "[Pricing] Hash mismatch warning: remote=%s data=%s (hash file may be out of sync)", + remoteHash[:min(8, len(remoteHash))], dataHashStr[:8]) } // 解析JSON数据(使用灵活的解析方式) @@ -296,11 +323,14 @@ func (s *PricingService) downloadPricingData() error { logger.LegacyPrintf("service.pricing", "[Pricing] Failed to save file: %v", err) } - // 保存哈希 - hash := sha256.Sum256(body) - hashStr := hex.EncodeToString(hash[:]) + // 使用远程哈希作为同步锚点,防止重复下载 + // 当远程哈希不可用时,回退到数据本身的哈希 + syncHash := dataHashStr + if remoteHash != "" { + syncHash = remoteHash + } hashFile := s.getHashFilePath() - if err := os.WriteFile(hashFile, []byte(hashStr+"\n"), 0644); err != nil { + if err := os.WriteFile(hashFile, []byte(syncHash+"\n"), 0644); err != nil { logger.LegacyPrintf("service.pricing", "[Pricing] Failed to save hash: %v", err) } @@ -308,7 +338,7 @@ func (s *PricingService) downloadPricingData() error { s.mu.Lock() s.pricingData = data s.lastUpdated = time.Now() - s.localHash = hashStr + s.localHash = syncHash s.mu.Unlock() logger.LegacyPrintf("service.pricing", "[Pricing] Downloaded %d models successfully", len(data)) @@ -486,16 +516,6 @@ func (s *PricingService) validatePricingURL(raw string) (string, error) { return normalized, nil } -// computeFileHash 计算文件哈希 -func (s *PricingService) computeFileHash(filePath string) (string, error) { - data, err := os.ReadFile(filePath) - if err != nil { - return "", err - } - hash := sha256.Sum256(data) - return hex.EncodeToString(hash[:]), nil -} - // GetModelPricing 获取模型价格(带模糊匹配) func (s *PricingService) GetModelPricing(modelName string) *LiteLLMModelPricing { s.mu.RLock() diff --git a/deploy/config.example.yaml b/deploy/config.example.yaml index 2058ced1..8f60acd5 100644 --- a/deploy/config.example.yaml +++ b/deploy/config.example.yaml @@ -865,10 +865,10 @@ rate_limit: pricing: # URL to fetch model pricing data (default: pinned model-price-repo commit) # 获取模型定价数据的 URL(默认:固定 commit 的 model-price-repo) - remote_url: "https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/c7947e9871687e664180bc971d4837f1fc2784a9/model_prices_and_context_window.json" + remote_url: "https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/refs/heads/main//model_prices_and_context_window.json" # Hash verification URL (optional) # 哈希校验 URL(可选) - hash_url: "https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/c7947e9871687e664180bc971d4837f1fc2784a9/model_prices_and_context_window.sha256" + hash_url: "https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/refs/heads/main//model_prices_and_context_window.sha256" # Local data directory for caching # 本地数据缓存目录 data_dir: "./data"