Fix invalid path handling in langsegmenter

This commit is contained in:
KamioRinn 2025-02-26 16:17:54 +08:00
parent e061e9d38e
commit bb8f8c292c

View File

@ -1,14 +1,74 @@
import logging
import re

# Silence jieba's logging output.
import jieba
jieba.setLogLevel(logging.CRITICAL)

# Relocate fast_langdetect's large-model cache into the project's
# pretrained_models directory (instead of the library's default location).
from pathlib import Path
import fast_langdetect
fast_langdetect.ft_detect.infer.CACHE_DIRECTORY = Path(__file__).parent.parent.parent / "pretrained_models" / "fast_langdetect"
import sys
sys.modules["fast_langdetect"] = fast_langdetect

# Work around fastText being unable to read the model on some Windows paths.
import os
from typing import Optional
def load_fasttext_model(
    model_path: Path,
    download_url: Optional[str] = None,
    proxy: Optional[str] = None,
):
    """
    Load a FastText model, downloading it if necessary.

    Drop-in replacement for ``fast_langdetect.ft_detect.infer.load_fasttext_model``
    that tolerates paths the fastText native loader cannot open directly
    (e.g. Windows paths containing non-ASCII characters).

    :param model_path: Path to the FastText model file
    :param download_url: URL to download the model from
    :param proxy: Proxy URL for downloading the model
    :return: FastText model
    :raises DetectError: If model loading fails
    """
    infer = fast_langdetect.ft_detect.infer
    # Warn (but do not fail) when the large model's MD5 does not match.
    if all([
        infer.VERIFY_FASTTEXT_LARGE_MODEL,
        model_path.exists(),
        model_path.name == infer.FASTTEXT_LARGE_MODEL_NAME,
    ]):
        if not infer.verify_md5(model_path, infer.VERIFY_FASTTEXT_LARGE_MODEL):
            infer.logger.warning(
                f"fast-langdetect: MD5 hash verification failed for {model_path}, "
                f"please check the integrity of the downloaded file from {infer.FASTTEXT_LARGE_MODEL_URL}. "
                "\n This may seriously reduce the prediction accuracy. "
                "If you want to ignore this, please set `fast_langdetect.ft_detect.infer.VERIFY_FASTTEXT_LARGE_MODEL = None` "
            )
    if not model_path.exists():
        if download_url:
            infer.download_model(download_url, model_path, proxy)
        if not model_path.exists():
            raise infer.DetectError(f"FastText model file not found at {model_path}")

    try:
        if re.match(r'^[A-Za-z0-9_/\\:.]*$', str(model_path)):
            # ASCII-safe path: fastText can open it directly.
            model = infer.fasttext.load_model(str(model_path))
        else:
            python_dir = os.getcwd()
            if str(model_path)[:len(python_dir)] == python_dir:
                # Path lies under the working directory: a relative path may
                # drop the problematic characters in the absolute prefix.
                model = infer.fasttext.load_model(os.path.relpath(model_path, python_dir))
            else:
                # Last resort: copy the model to an ASCII-safe temp file.
                import tempfile
                import shutil
                with tempfile.NamedTemporaryFile(delete=False) as tmpfile:
                    shutil.copyfile(model_path, tmpfile.name)
                try:
                    model = infer.fasttext.load_model(tmpfile.name)
                finally:
                    # Remove the copy even when loading raises (the original
                    # code leaked the temp file on failure).
                    os.unlink(tmpfile.name)
        return model
    except Exception as e:
        infer.logger.warning(f"fast-langdetect:Failed to load FastText model from {model_path}: {e}")
        raise infer.DetectError(f"Failed to load FastText model: {e}") from e
# On Windows, replace fast_langdetect's model loader with the patched version
# above, which handles paths the fastText native loader cannot open.
if os.name == 'nt':
    fast_langdetect.ft_detect.infer.load_fasttext_model = load_fasttext_model
from split_lang import LangSplitter
@ -46,99 +106,6 @@ def merge_lang(lang_list, item):
lang_list.append(item)
return lang_list
from typing import List
from split_lang import SubString
def _special_merge_for_zh_ja(
    self,
    substrings: List[SubString],
) -> List[SubString]:
    """
    Absorb tiny zh/ja/x fragments into much larger neighbouring zh/ja blocks.

    A short fragment next to (or between) Chinese/Japanese blocks more than
    10x its length is treated as a mis-detection and merged into the large
    neighbour. Afterwards, "x" (unknown) blocks are relabelled with the
    dominant language, and zh blocks flip to ja when Japanese text dominates
    Chinese by 10x or more.

    :param substrings: Detected language runs, in document order.
    :return: Merged and relabelled substrings.
    """
    new_substrings: List[SubString] = []
    if len(substrings) == 1:
        return substrings
    # NOTE: total text length per detected language.
    substring_text_len_by_lang = {
        "zh": 0,
        "ja": 0,
        "x": 0,
        "digit": 0,
        "punctuation": 0,
        "newline": 0,
    }
    index = 0
    while index < len(substrings):
        current_block = substrings[index]
        # Use .get so langs outside the pre-seeded keys (e.g. "en", which is
        # explicitly tested for below) don't raise KeyError.
        substring_text_len_by_lang[current_block.lang] = (
            substring_text_len_by_lang.get(current_block.lang, 0)
            + current_block.length
        )
        if index == 0:
            # First block: may be absorbed into a much larger right neighbour.
            if (
                substrings[index + 1].lang in ["zh", "ja"]
                and substrings[index].lang in ["zh", "ja", "x"]
                and substrings[index].length * 10 < substrings[index + 1].length
            ):
                right_block = substrings[index + 1]
                new_substrings.append(
                    SubString(
                        is_digit=False,
                        is_punctuation=False,
                        lang=right_block.lang,
                        text=current_block.text + right_block.text,
                        length=current_block.length + right_block.length,
                        index=current_block.index,
                    )
                )
                index += 1
            else:
                new_substrings.append(substrings[index])
        elif index == len(substrings) - 1:
            # Last block: may be absorbed into a much larger left neighbour.
            left_block = new_substrings[-1]
            if (
                left_block.lang in ["zh", "ja"]
                and current_block.lang in ["zh", "ja", "x"]
                and current_block.length * 10 < left_block.length
            ):
                new_substrings[-1].text += current_block.text
                new_substrings[-1].length += current_block.length
                index += 1
            else:
                new_substrings.append(substrings[index])
        else:
            # Middle block: absorbed when both neighbours share a zh/ja lang
            # and together dwarf it by 10x.
            if (
                new_substrings[-1].lang == substrings[index + 1].lang
                and new_substrings[-1].lang in ["zh", "ja"]
                # and substrings[index].lang in ["zh", "ja", "x"]
                and substrings[index].lang != "en"
                and substrings[index].length * 10
                < new_substrings[-1].length + substrings[index + 1].length
            ):
                left_block = new_substrings[-1]
                right_block = substrings[index + 1]
                current_block = substrings[index]
                new_substrings[-1].text += current_block.text + right_block.text
                new_substrings[-1].length += (
                    current_block.length + right_block.length
                )
                index += 1
            else:
                new_substrings.append(substrings[index])
        index += 1
    # NOTE: if any "x" (unknown) blocks remain, relabel them with the lang
    # that has the most text overall.
    if substring_text_len_by_lang["x"] > 0:
        max_lang = max(
            substring_text_len_by_lang, key=substring_text_len_by_lang.get
        )
        for index, substr in enumerate(new_substrings):
            if substr.lang == "x":
                new_substrings[index].lang = max_lang
    # NOTE: if ja text is >= 10x the zh text, relabel zh blocks as ja.
    if substring_text_len_by_lang["ja"] >= substring_text_len_by_lang["zh"] * 10:
        for index, substr in enumerate(new_substrings):
            if substr.lang == "zh":
                new_substrings[index].lang = "ja"
    new_substrings = self._merge_substrings(substrings=new_substrings)
    return new_substrings
class LangSegmenter():
# 默认过滤器, 基于gsv目前四种语言