GPT-SoVITS/api_model_manager.py
2025-04-25 20:36:33 +08:00

212 lines
7.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import json
import glob
import re
from typing import Dict, List, Tuple, Optional
import logging
# Module-wide logger; logging.getLogger hands back the same object for every
# lookup of this name, so other modules using it share handlers and level.
logger = logging.getLogger(name="gpt-sovits-api")
class ModelManager:
    """
    GPT-SoVITS model manager.

    Scans the GPT weight folders (``*.ckpt``) and the SoVITS weight folders
    (``*.pth``), keeps the highest-ranked checkpoint per model name, and pairs
    GPT and SoVITS checkpoints that share the same name. Each complete pair is
    exposed as a "voice" addressable by that name.
    """

    def __init__(self, base_dir: Optional[str] = None):
        """
        Create the manager and perform an initial scan.

        Args:
            base_dir: Directory that contains the weight folders. ``None``
                (the default) resolves the folders relative to the current
                working directory.
        """
        self.base_dir = base_dir
        # Not used by this class itself; kept for backward compatibility.
        self.gpt_weights_dir = "GPT_weights"
        self.sovits_weights_dir = "SoVITS_weights"
        # Weight folders for each GPT-SoVITS version that gets scanned
        self.gpt_dirs = [
            "GPT_weights",
            "GPT_weights_v2",
            "GPT_weights_v3",
            "GPT_weights_v4",
        ]
        self.sovits_dirs = [
            "SoVITS_weights",
            "SoVITS_weights_v2",
            "SoVITS_weights_v3",
            "SoVITS_weights_v4",
        ]
        # voice name -> {"gpt_path", "sovits_path", "iteration", "batch"}
        self.model_mapping: Dict[str, Dict] = {}
        # voice name -> {"id", "name", "iteration", "batch"}
        self.voice_info: Dict[str, Dict] = {}
        self.load_model_mapping()

    def _extract_model_info(self, filename: str) -> Dict:
        """
        Parse a checkpoint filename into model name, iteration and batch.

        Supported naming schemes:
            1. ``<name>_e<iteration>_s<batch>.pth``
            2. ``<name>-e<iteration>.ckpt``

        Args:
            filename: Checkpoint path or bare filename.

        Returns:
            Dict with ``name``, ``iteration`` (0 when absent), ``batch``
            (0 when absent) and the unmodified ``filename``.
        """
        basename = os.path.basename(filename)
        # Strip only the final extension so dots inside the name survive
        base_name = os.path.splitext(basename)[0]

        model_name = base_name
        iteration = 0
        search_from = 0
        # Iteration (epoch) marker: "-e<N>" or "_e<N>"
        e_match = re.search(r"[-_]e(\d+)", base_name)
        if e_match:
            iteration = int(e_match.group(1))
            # The name is everything before the marker. Cut at the match
            # position itself; splitting on the first "_e"/"-e" would also
            # cut names such as "Xiao_en".
            if e_match.start() > 0:
                model_name = base_name[:e_match.start()]
            # The batch marker follows the iteration marker, so "_s<N>" text
            # that is part of the model name must not be picked up.
            search_from = e_match.end()

        # Batch (step) marker, mainly used by SoVITS files: "-s<N>" or "_s<N>"
        s_match = re.search(r"[-_]s(\d+)", base_name[search_from:])
        batch = int(s_match.group(1)) if s_match else 0

        logger.debug("解析模型: %s -> 名称=%s, 迭代=%d, 批次=%d",
                     basename, model_name, iteration, batch)
        return {
            "name": model_name,
            "iteration": iteration,
            "batch": batch,
            "filename": filename,
        }

    def _resolve_dir(self, dir_name: str) -> str:
        """Return *dir_name* joined onto ``base_dir`` when one was given."""
        if self.base_dir is None:
            return dir_name
        return os.path.join(self.base_dir, dir_name)

    def _scan_models(self, dirs: List[str], pattern: str) -> Dict[str, Dict]:
        """
        Collect the best checkpoint per model name from *dirs*.

        Args:
            dirs: Weight folder names to scan (missing folders are skipped).
            pattern: Glob pattern for checkpoint files, e.g. ``"*.ckpt"``.

        Returns:
            Mapping of model name -> info dict from ``_extract_model_info``.
            A higher iteration wins, then a higher batch. On a full tie the
            first file seen wins (folders in list order, files sorted by path).
        """
        models: Dict[str, Dict] = {}
        for dir_name in dirs:
            dir_path = self._resolve_dir(dir_name)
            if not os.path.isdir(dir_path):
                continue
            # Escape the directory part so glob metacharacters in base_dir are
            # taken literally. Sort the matches so tie-breaking is deterministic.
            matches = sorted(glob.glob(os.path.join(glob.escape(dir_path), pattern)))
            for file_path in matches:
                info = self._extract_model_info(file_path)
                current = models.get(info["name"])
                if current is None or (
                    (info["iteration"], info["batch"])
                    > (current["iteration"], current["batch"])
                ):
                    models[info["name"]] = info
        return models

    def load_model_mapping(self):
        """
        Scan the weight folders and rebuild the voice mappings.

        A voice exists only when a GPT and a SoVITS checkpoint share the same
        model name. Both mappings are rebuilt from scratch, so a re-scan also
        drops models that were removed from disk.
        """
        gpt_models = self._scan_models(self.gpt_dirs, "*.ckpt")
        sovits_models = self._scan_models(self.sovits_dirs, "*.pth")

        model_mapping: Dict[str, Dict] = {}
        voice_info: Dict[str, Dict] = {}
        # Sorted so that get_all_voices() returns a stable order.
        for name in sorted(gpt_models.keys() & sovits_models.keys()):
            gpt_model = gpt_models[name]
            sovits_model = sovits_models[name]
            iteration = min(gpt_model["iteration"], sovits_model["iteration"])
            batch = min(gpt_model["batch"], sovits_model["batch"])
            model_mapping[name] = {
                "gpt_path": gpt_model["filename"],
                "sovits_path": sovits_model["filename"],
                "iteration": iteration,
                "batch": batch,
            }
            voice_info[name] = {
                "id": name,
                "name": name,
                "iteration": iteration,
                "batch": batch,
            }

        self.model_mapping = model_mapping
        self.voice_info = voice_info
        logger.info("已加载 %d 个模型映射", len(self.model_mapping))

    def get_model_paths(self, voice_name: str) -> Tuple[Optional[str], Optional[str]]:
        """
        Look up the checkpoint paths of a voice.

        Args:
            voice_name: Voice (model) name.

        Returns:
            ``(gpt_path, sovits_path)``, or ``(None, None)`` for an unknown voice.
        """
        mapping = self.model_mapping.get(voice_name)
        if mapping is None:
            return None, None
        return mapping["gpt_path"], mapping["sovits_path"]

    def get_all_voices(self) -> List[Dict]:
        """
        List every available voice.

        Returns:
            Copies of the voice info dicts, so callers cannot mutate the
            manager's internal state.
        """
        return [dict(info) for info in self.voice_info.values()]

    def get_voice_details(self, voice_name: str) -> Optional[Dict]:
        """
        Return the voice info plus its checkpoint paths.

        Args:
            voice_name: Voice (model) name.

        Returns:
            A new dict with the voice info, ``gpt_path`` and ``sovits_path``,
            or ``None`` for an unknown voice.
        """
        info = self.voice_info.get(voice_name)
        mapping = self.model_mapping.get(voice_name)
        if info is None or mapping is None:
            return None
        details = dict(info)
        details["gpt_path"] = mapping["gpt_path"]
        details["sovits_path"] = mapping["sovits_path"]
        return details
# Shared module-level instance ("singleton" by convention only; nothing stops
# callers from constructing more ModelManager objects). It is built at import
# time, so the weight folders are scanned when this module is first imported.
model_manager = ModelManager()
if __name__ == "__main__":
    # Manual smoke test: list every paired voice together with its checkpoints.
    # Logging is configured before a new manager is constructed, so the INFO
    # summary emitted by load_model_mapping() is printed.
    logging.basicConfig(level=logging.INFO)
    manager = ModelManager()
    all_voices = manager.get_all_voices()
    print(f"发现 {len(all_voices)} 个声音模型:")
    for entry in all_voices:
        voice_name = entry['name']
        print(f"- {voice_name}, 迭代次数: {entry['iteration']}, 批次: {entry['batch']}")
        gpt_ckpt, sovits_ckpt = manager.get_model_paths(voice_name)
        print(f" GPT: {gpt_ckpt}")
        print(f" SoVITS: {sovits_ckpt}")