diff --git a/GPT_SoVITS/text/LangSegmenter/langsegmenter.py b/GPT_SoVITS/text/LangSegmenter/langsegmenter.py
index 1c754b8..48b9955 100644
--- a/GPT_SoVITS/text/LangSegmenter/langsegmenter.py
+++ b/GPT_SoVITS/text/LangSegmenter/langsegmenter.py
@@ -1,14 +1,74 @@
 import logging
-import jieba
 import re
+
+# Silence jieba's logging
+import jieba
 jieba.setLogLevel(logging.CRITICAL)
 
 # Relocate the fast_langdetect large model
 from pathlib import Path
 import fast_langdetect
 fast_langdetect.ft_detect.infer.CACHE_DIRECTORY = Path(__file__).parent.parent.parent / "pretrained_models" / "fast_langdetect"
-import sys
-sys.modules["fast_langdetect"] = fast_langdetect
+
+# Work around the model being unreadable on Windows
+import os
+from typing import Optional
+def load_fasttext_model(
+    model_path: Path,
+    download_url: Optional[str] = None,
+    proxy: Optional[str] = None,
+):
+    """
+    Load a FastText model, downloading it if necessary.
+    :param model_path: Path to the FastText model file
+    :param download_url: URL to download the model from
+    :param proxy: Proxy URL for downloading the model
+    :return: FastText model
+    :raises DetectError: If model loading fails
+    """
+    if all([
+        fast_langdetect.ft_detect.infer.VERIFY_FASTTEXT_LARGE_MODEL,
+        model_path.exists(),
+        model_path.name == fast_langdetect.ft_detect.infer.FASTTEXT_LARGE_MODEL_NAME,
+    ]):
+        if not fast_langdetect.ft_detect.infer.verify_md5(model_path, fast_langdetect.ft_detect.infer.VERIFY_FASTTEXT_LARGE_MODEL):
+            fast_langdetect.ft_detect.infer.logger.warning(
+                f"fast-langdetect: MD5 hash verification failed for {model_path}, "
+                f"please check the integrity of the downloaded file from {fast_langdetect.ft_detect.infer.FASTTEXT_LARGE_MODEL_URL}. "
+                "\n This may seriously reduce the prediction accuracy. "
+                "If you want to ignore this, please set `fast_langdetect.ft_detect.infer.VERIFY_FASTTEXT_LARGE_MODEL = None` "
+            )
+    if not model_path.exists():
+        if download_url:
+            fast_langdetect.ft_detect.infer.download_model(download_url, model_path, proxy)
+        if not model_path.exists():
+            raise fast_langdetect.ft_detect.infer.DetectError(f"FastText model file not found at {model_path}")
+
+    try:
+        # Load FastText model; ASCII-only paths can go to the native loader directly
+        if re.match(r'^[A-Za-z0-9_/\\:.]*$', str(model_path)):
+            model = fast_langdetect.ft_detect.infer.fasttext.load_model(str(model_path))
+        else:
+            python_dir = os.getcwd()
+            # Try a relative path if the model sits under the working directory
+            if str(model_path)[:len(python_dir)] == python_dir:
+                model = fast_langdetect.ft_detect.infer.fasttext.load_model(os.path.relpath(model_path, python_dir))
+            else:
+                # Last resort: copy the model to a temporary file with a safe name
+                import tempfile
+                import shutil
+                with tempfile.NamedTemporaryFile(delete=False) as tmpfile:
+                    shutil.copyfile(model_path, tmpfile.name)
+
+                model = fast_langdetect.ft_detect.infer.fasttext.load_model(tmpfile.name)
+                os.unlink(tmpfile.name)
+        return model
+
+    except Exception as e:
+        fast_langdetect.ft_detect.infer.logger.warning(f"fast-langdetect: Failed to load FastText model from {model_path}: {e}")
+        raise fast_langdetect.ft_detect.infer.DetectError(f"Failed to load FastText model: {e}")
+
+if os.name == 'nt':
+    fast_langdetect.ft_detect.infer.load_fasttext_model = load_fasttext_model
+
 from split_lang import LangSplitter
@@ -46,99 +106,6 @@ def merge_lang(lang_list, item):
         lang_list.append(item)
     return lang_list
 
-from typing import List
-from split_lang import SubString
-def _special_merge_for_zh_ja(
-    self,
-    substrings: List[SubString],
-) -> List[SubString]:
-    new_substrings: List[SubString] = []
-
-    if len(substrings) == 1:
-        return substrings
-    # NOTE: tally the text length of each language
-    substring_text_len_by_lang = {
-        "zh": 0,
-        "ja": 0,
-        "x": 0,
-        "digit": 0,
-        "punctuation": 0,
-        "newline": 0,
-    }
-    index = 0
-    while index < len(substrings):
-        current_block = substrings[index]
-        substring_text_len_by_lang[current_block.lang] += current_block.length
-        if index == 0:
-            if (
-                substrings[index + 1].lang in ["zh", "ja"]
-                and substrings[index].lang in ["zh", "ja", "x"]
-                and substrings[index].length * 10 < substrings[index + 1].length
-            ):
-                right_block = substrings[index + 1]
-                new_substrings.append(
-                    SubString(
-                        is_digit=False,
-                        is_punctuation=False,
-                        lang=right_block.lang,
-                        text=current_block.text + right_block.text,
-                        length=current_block.length + right_block.length,
-                        index=current_block.index,
-                    )
-                )
-                index += 1
-            else:
-                new_substrings.append(substrings[index])
-        elif index == len(substrings) - 1:
-            left_block = new_substrings[-1]
-            if (
-                left_block.lang in ["zh", "ja"]
-                and current_block.lang in ["zh", "ja", "x"]
-                and current_block.length * 10 < left_block.length
-            ):
-                new_substrings[-1].text += current_block.text
-                new_substrings[-1].length += current_block.length
-
-                index += 1
-            else:
-                new_substrings.append(substrings[index])
-        else:
-            if (
-                new_substrings[-1].lang == substrings[index + 1].lang
-                and new_substrings[-1].lang in ["zh", "ja"]
-                # and substrings[index].lang in ["zh", "ja", "x"]
-                and substrings[index].lang != "en"
-                and substrings[index].length * 10
-                < new_substrings[-1].length + substrings[index + 1].length
-            ):
-                left_block = new_substrings[-1]
-                right_block = substrings[index + 1]
-                current_block = substrings[index]
-                new_substrings[-1].text += current_block.text + right_block.text
-                new_substrings[-1].length += (
-                    current_block.length + right_block.length
-                )
-                index += 1
-            else:
-                new_substrings.append(substrings[index])
-        index += 1
-    # NOTE: if any substring is tagged "x", relabel it with the dominant lang
-    if substring_text_len_by_lang["x"] > 0:
-        max_lang = max(
-            substring_text_len_by_lang, key=substring_text_len_by_lang.get
-        )
-        for index, substr in enumerate(new_substrings):
-            if substr.lang == "x":
-                new_substrings[index].lang = max_lang
-    # NOTE: if the ja text is at least 10x the zh text, relabel zh as ja
-    if substring_text_len_by_lang["ja"] >= substring_text_len_by_lang["zh"] * 10:
-        for index, substr in enumerate(new_substrings):
-            if substr.lang == "zh":
-                new_substrings[index].lang = "ja"
-    new_substrings = self._merge_substrings(substrings=new_substrings)
-    return new_substrings
-
-
 class LangSegmenter():
     # Default filter, based on the four languages GSV currently supports