Mirror of https://github.com/RVC-Boss/GPT-SoVITS.git, synced 2025-04-06 03:57:44 +08:00
Fix langsegmenter invalid path
commit bb8f8c292c
parent e061e9d38e
@@ -1,14 +1,74 @@
 import logging
-import jieba
 import re
 
+# Silence jieba
+import jieba
 jieba.setLogLevel(logging.CRITICAL)
 
 # Change where fast_langdetect keeps its large model
 from pathlib import Path
 import fast_langdetect
 fast_langdetect.ft_detect.infer.CACHE_DIRECTORY = Path(__file__).parent.parent.parent / "pretrained_models" / "fast_langdetect"
-import sys
-sys.modules["fast_langdetect"] = fast_langdetect
+
+# Prevent the model from failing to load on Windows
+import os
+from typing import Optional
+def load_fasttext_model(
+    model_path: Path,
+    download_url: Optional[str] = None,
+    proxy: Optional[str] = None,
+):
+    """
+    Load a FastText model, downloading it if necessary.
+
+    :param model_path: Path to the FastText model file
+    :param download_url: URL to download the model from
+    :param proxy: Proxy URL for downloading the model
+    :return: FastText model
+    :raises DetectError: If model loading fails
+    """
+    if all([
+        fast_langdetect.ft_detect.infer.VERIFY_FASTTEXT_LARGE_MODEL,
+        model_path.exists(),
+        model_path.name == fast_langdetect.ft_detect.infer.FASTTEXT_LARGE_MODEL_NAME,
+    ]):
+        if not fast_langdetect.ft_detect.infer.verify_md5(model_path, fast_langdetect.ft_detect.infer.VERIFY_FASTTEXT_LARGE_MODEL):
+            fast_langdetect.ft_detect.infer.logger.warning(
+                f"fast-langdetect: MD5 hash verification failed for {model_path}, "
+                f"please check the integrity of the downloaded file from {fast_langdetect.ft_detect.infer.FASTTEXT_LARGE_MODEL_URL}. "
+                "\n This may seriously reduce the prediction accuracy. "
+                "If you want to ignore this, please set `fast_langdetect.ft_detect.infer.VERIFY_FASTTEXT_LARGE_MODEL = None` "
+            )
+    if not model_path.exists():
+        if download_url:
+            fast_langdetect.ft_detect.infer.download_model(download_url, model_path, proxy)
+        if not model_path.exists():
+            raise fast_langdetect.ft_detect.infer.DetectError(f"FastText model file not found at {model_path}")
+
+    try:
+        # Load FastText model
+        if (re.match(r'^[A-Za-z0-9_/\\:.]*$', str(model_path))):
+            model = fast_langdetect.ft_detect.infer.fasttext.load_model(str(model_path))
+        else:
+            python_dir = os.getcwd()
+            if (str(model_path)[:len(python_dir)] == python_dir):
+                model = fast_langdetect.ft_detect.infer.fasttext.load_model(os.path.relpath(model_path, python_dir))
+            else:
+                import tempfile
+                import shutil
+                with tempfile.NamedTemporaryFile(delete=False) as tmpfile:
+                    shutil.copyfile(model_path, tmpfile.name)
+
+                model = fast_langdetect.ft_detect.infer.fasttext.load_model(tmpfile.name)
+                os.unlink(tmpfile.name)
+        return model
+
+    except Exception as e:
+        fast_langdetect.ft_detect.infer.logger.warning(f"fast-langdetect:Failed to load FastText model from {model_path}: {e}")
+        raise fast_langdetect.ft_detect.infer.DetectError(f"Failed to load FastText model: {e}")
+
+if os.name == 'nt':
+    fast_langdetect.ft_detect.infer.load_fasttext_model = load_fasttext_model
+
 
 from split_lang import LangSplitter
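The substance of this hunk is the fallback chain in the patched load_fasttext_model: fasttext's native loader apparently cannot open model paths containing non-ASCII characters on Windows (e.g. a Chinese user directory), so the patch passes the path verbatim only when it is pure ASCII, then tries a path relative to the working directory, and finally copies the model into an ASCII-safe temporary file. Below is a minimal standalone sketch of that last technique; load_model_safe is a hypothetical helper name, assuming only that the fasttext package is installed (the commit itself replaces fast_langdetect.ft_detect.infer.load_fasttext_model rather than defining a new helper):

import os
import re
import shutil
import tempfile

import fasttext


def load_model_safe(model_path: str):
    # Pure-ASCII paths can be handed to fasttext directly.
    if re.match(r'^[A-Za-z0-9_/\\:.]*$', model_path):
        return fasttext.load_model(model_path)
    # Otherwise copy the model to an ASCII-safe temporary file and load that.
    # delete=False keeps the file on disk after the with-block closes the
    # handle, which matters on Windows where an open file cannot be reopened.
    with tempfile.NamedTemporaryFile(delete=False) as tmp:
        shutil.copyfile(model_path, tmp.name)
    try:
        return fasttext.load_model(tmp.name)
    finally:
        os.unlink(tmp.name)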
@@ -46,99 +106,6 @@ def merge_lang(lang_list, item):
     lang_list.append(item)
     return lang_list
 
-
-from typing import List
-from split_lang import SubString
-def _special_merge_for_zh_ja(
-    self,
-    substrings: List[SubString],
-) -> List[SubString]:
-    new_substrings: List[SubString] = []
-
-    if len(substrings) == 1:
-        return substrings
-    # NOTE: tally the text length of each language
-    substring_text_len_by_lang = {
-        "zh": 0,
-        "ja": 0,
-        "x": 0,
-        "digit": 0,
-        "punctuation": 0,
-        "newline": 0,
-    }
-    index = 0
-    while index < len(substrings):
-        current_block = substrings[index]
-        substring_text_len_by_lang[current_block.lang] += current_block.length
-        if index == 0:
-            if (
-                substrings[index + 1].lang in ["zh", "ja"]
-                and substrings[index].lang in ["zh", "ja", "x"]
-                and substrings[index].length * 10 < substrings[index + 1].length
-            ):
-                right_block = substrings[index + 1]
-                new_substrings.append(
-                    SubString(
-                        is_digit=False,
-                        is_punctuation=False,
-                        lang=right_block.lang,
-                        text=current_block.text + right_block.text,
-                        length=current_block.length + right_block.length,
-                        index=current_block.index,
-                    )
-                )
-                index += 1
-            else:
-                new_substrings.append(substrings[index])
-        elif index == len(substrings) - 1:
-            left_block = new_substrings[-1]
-            if (
-                left_block.lang in ["zh", "ja"]
-                and current_block.lang in ["zh", "ja", "x"]
-                and current_block.length * 10 < left_block.length
-            ):
-                new_substrings[-1].text += current_block.text
-                new_substrings[-1].length += current_block.length
-
-                index += 1
-            else:
-                new_substrings.append(substrings[index])
-        else:
-            if (
-                new_substrings[-1].lang == substrings[index + 1].lang
-                and new_substrings[-1].lang in ["zh", "ja"]
-                # and substrings[index].lang in ["zh", "ja", "x"]
-                and substrings[index].lang != "en"
-                and substrings[index].length * 10
-                < new_substrings[-1].length + substrings[index + 1].length
-            ):
-                left_block = new_substrings[-1]
-                right_block = substrings[index + 1]
-                current_block = substrings[index]
-                new_substrings[-1].text += current_block.text + right_block.text
-                new_substrings[-1].length += (
-                    current_block.length + right_block.length
-                )
-                index += 1
-            else:
-                new_substrings.append(substrings[index])
-        index += 1
-    # NOTE: if any "x" text was seen, relabel "x" substrings as the most frequent lang
-    if substring_text_len_by_lang["x"] > 0:
-        max_lang = max(
-            substring_text_len_by_lang, key=substring_text_len_by_lang.get
-        )
-        for index, substr in enumerate(new_substrings):
-            if substr.lang == "x":
-                new_substrings[index].lang = max_lang
-    # NOTE: if the ja total is at least 10x the zh total, relabel zh substrings as ja
-    if substring_text_len_by_lang["ja"] >= substring_text_len_by_lang["zh"] * 10:
-        for index, substr in enumerate(new_substrings):
-            if substr.lang == "zh":
-                new_substrings[index].lang = "ja"
-    new_substrings = self._merge_substrings(substrings=new_substrings)
-    return new_substrings
-
-
 
 class LangSegmenter():
     # Default filter, based on the four languages gsv currently supports
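For context on what this hunk deletes: _special_merge_for_zh_ja post-processed the list of SubString objects that split-lang produces, folding short zh/ja/"x" fragments into their larger neighbours and relabeling ambiguous runs. A short sketch of the data it operated on, assuming split-lang's documented split_by_lang entry point (the lang and text attributes are the same ones the removed code inspected):

from split_lang import LangSplitter

splitter = LangSplitter()
# Mixed Chinese / Japanese / English input.
for sub in splitter.split_by_lang("你好,世界。こんにちは。Hello, world."):
    # Each SubString carries lang, text, length and index fields.
    print(sub.lang, repr(sub.text))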