Merge branch 'main' into main

This commit is contained in:
C3EZ 2025-03-09 10:36:02 +11:00 committed by GitHub
commit f72237d668
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 354 additions and 170 deletions

View File

@ -5,7 +5,7 @@ from typing import List, Optional
import torch import torch
from tqdm import tqdm from tqdm import tqdm
from AR.models.utils import make_pad_mask from AR.models.utils import make_pad_mask, make_pad_mask_left
from AR.models.utils import ( from AR.models.utils import (
topk_sampling, topk_sampling,
sample, sample,
@ -162,7 +162,7 @@ class T2SBlock:
) )
return x, k_cache, v_cache return x, k_cache, v_cache
def decode_next_token(self, x:torch.Tensor, k_cache:torch.Tensor, v_cache:torch.Tensor, attn_mask:Optional[torch.Tensor]=None, torch_sdpa:bool=True): def decode_next_token(self, x:torch.Tensor, k_cache:torch.Tensor, v_cache:torch.Tensor, attn_mask:torch.Tensor=None, torch_sdpa:bool=True):
q, k, v = F.linear(x, self.qkv_w, self.qkv_b).chunk(3, dim=-1) q, k, v = F.linear(x, self.qkv_w, self.qkv_b).chunk(3, dim=-1)
k_cache = torch.cat([k_cache, k], dim=1) k_cache = torch.cat([k_cache, k], dim=1)
@ -178,7 +178,7 @@ class T2SBlock:
if torch_sdpa: if torch_sdpa:
attn = F.scaled_dot_product_attention(q, k, v) attn = F.scaled_dot_product_attention(q, k, v, (~attn_mask) if attn_mask is not None else None)
else: else:
attn = scaled_dot_product_attention(q, k, v, attn_mask) attn = scaled_dot_product_attention(q, k, v, attn_mask)
@ -223,7 +223,7 @@ class T2STransformer:
self, x:torch.Tensor, self, x:torch.Tensor,
k_cache: List[torch.Tensor], k_cache: List[torch.Tensor],
v_cache: List[torch.Tensor], v_cache: List[torch.Tensor],
attn_mask : Optional[torch.Tensor]=None, attn_mask : torch.Tensor=None,
torch_sdpa:bool=True torch_sdpa:bool=True
): ):
for i in range(self.num_blocks): for i in range(self.num_blocks):
@ -573,71 +573,88 @@ class Text2SemanticDecoder(nn.Module):
x_item = self.ar_text_embedding(x_item.unsqueeze(0)) x_item = self.ar_text_embedding(x_item.unsqueeze(0))
x_item = x_item + self.bert_proj(bert_item.transpose(0, 1).unsqueeze(0)) x_item = x_item + self.bert_proj(bert_item.transpose(0, 1).unsqueeze(0))
x_item = self.ar_text_position(x_item).squeeze(0) x_item = self.ar_text_position(x_item).squeeze(0)
x_item = F.pad(x_item,(0,0,0,max_len-x_item.shape[0]),value=0) if x_item.shape[0]<max_len else x_item # x_item = F.pad(x_item,(0,0,0,max_len-x_item.shape[0]),value=0) if x_item.shape[0]<max_len else x_item ### padding right
x_item = F.pad(x_item,(0,0,max_len-x_item.shape[0],0),value=0) if x_item.shape[0]<max_len else x_item ### padding left
x_list.append(x_item) x_list.append(x_item)
x = torch.stack(x_list, dim=0) x:torch.Tensor = torch.stack(x_list, dim=0)
# AR Decoder # AR Decoder
y = prompts y = prompts
x_len = x.shape[1] x_len = x.shape[1]
x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
stop = False stop = False
k_cache = None k_cache = None
v_cache = None v_cache = None
################### first step ########################## ################### first step ##########################
if y is not None: assert y is not None, "Error: Prompt free is not supported batch_infer!"
y_emb = self.ar_audio_embedding(y) ref_free = False
y_len = y_emb.shape[1]
prefix_len = y.shape[1] y_emb = self.ar_audio_embedding(y)
y_lens = torch.LongTensor([y_emb.shape[1]]*y_emb.shape[0]).to(x.device) y_len = y_emb.shape[1]
y_pos = self.ar_audio_position(y_emb) prefix_len = y.shape[1]
xy_pos = torch.concat([x, y_pos], dim=1) y_lens = torch.LongTensor([y_emb.shape[1]]*y_emb.shape[0]).to(x.device)
ref_free = False y_pos = self.ar_audio_position(y_emb)
else: xy_pos = torch.concat([x, y_pos], dim=1)
y_emb = None
y_len = 0
prefix_len = 0
y_lens = torch.LongTensor([y_len]*x.shape[0]).to(x.device)
y_pos = None
xy_pos = x
y = torch.zeros(x.shape[0], 0, dtype=torch.int, device=x.device)
ref_free = True
##### create mask ##### ##### create mask #####
bsz = x.shape[0] bsz = x.shape[0]
src_len = x_len + y_len src_len = x_len + y_len
y_paddind_mask = make_pad_mask(y_lens, y_len) y_paddind_mask = make_pad_mask_left(y_lens, y_len)
x_paddind_mask = make_pad_mask(x_lens, max_len) x_paddind_mask = make_pad_mask_left(x_lens, max_len)
# (bsz, x_len + y_len) # (bsz, x_len + y_len)
xy_padding_mask = torch.concat([x_paddind_mask, y_paddind_mask], dim=1) padding_mask = torch.concat([x_paddind_mask, y_paddind_mask], dim=1)
x_mask = F.pad(
torch.zeros(x_len, x_len, dtype=torch.bool, device=x.device),
(0, y_len),
value=True,
)
x_mask = F.pad(
x_attn_mask,
(0, y_len), ###xx的纯0扩展到xx纯0+xy纯1(x,x+y)
value=True,
)
y_mask = F.pad( ###yy的右上1扩展到左边xy的0,(y,x+y) y_mask = F.pad( ###yy的右上1扩展到左边xy的0,(y,x+y)
torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1), torch.triu(torch.ones(y_len, y_len, dtype=torch.bool, device=x.device), diagonal=1),
(x_len, 0), (x_len, 0),
value=False, value=False,
) )
xy_mask = torch.concat([x_mask, y_mask], dim=0).view(1 , src_len, src_len).repeat(bsz, 1, 1).to(x.device) causal_mask = torch.concat([x_mask, y_mask], dim=0).view(1 , src_len, src_len).repeat(bsz, 1, 1).to(x.device)
_xy_padding_mask = xy_padding_mask.view(bsz, 1, src_len).repeat(1, src_len, 1) # padding_mask = padding_mask.unsqueeze(1) * padding_mask.unsqueeze(2) ### [b, x+y, x+y]
### 上面是错误的会导致padding的token被"看见"
for i in range(bsz):
l = x_lens[i] # 正确的padding_mask应该是
_xy_padding_mask[i,l:max_len,:]=True # | pad_len | x_len | y_len |
# [[PAD, PAD, PAD, 1, 2, 3, 4, 5, 6],
xy_attn_mask = xy_mask.logical_or(_xy_padding_mask) # [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6],
xy_attn_mask = xy_attn_mask.unsqueeze(1).expand(-1, self.num_head, -1, -1) # [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6], 前3行按理说也应该被mask掉但是为了防止计算attention时不出现nan还是保留了不影响结果
xy_attn_mask = xy_attn_mask.bool() # [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6],
xy_padding_mask = xy_padding_mask.view(bsz, src_len, 1) # [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6],
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6],
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6],
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6],
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6]]
padding_mask = padding_mask.view(bsz, 1, src_len).repeat(1, src_len, 1)
attn_mask:torch.Tensor = causal_mask.logical_or(padding_mask)
attn_mask = attn_mask.unsqueeze(1).expand(-1, self.num_head, -1, -1).bool()
# 正确的attn_mask应该是这样的
# | pad_len | x_len | y_len |
# [[PAD, PAD, PAD, 1, 2, 3, EOS, EOS, EOS],
# [PAD, PAD, PAD, 1, 2, 3, EOS, EOS, EOS],
# [PAD, PAD, PAD, 1, 2, 3, EOS, EOS, EOS], 前3行按理说也应该被mask掉但是为了防止计算attention时不出现nan还是保留了不影响结果
# [PAD, PAD, PAD, 1, 2, 3, EOS, EOS, EOS],
# [PAD, PAD, PAD, 1, 2, 3, EOS, EOS, EOS],
# [PAD, PAD, PAD, 1, 2, 3, EOS, EOS, EOS],
# [PAD, PAD, PAD, 1, 2, 3, 4, EOS, EOS],
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, EOS],
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6]]
###### decode ##### ###### decode #####
y_list = [None]*y.shape[0] y_list = [None]*y.shape[0]
@ -645,18 +662,18 @@ class Text2SemanticDecoder(nn.Module):
idx_list = [None]*y.shape[0] idx_list = [None]*y.shape[0]
for idx in tqdm(range(1500)): for idx in tqdm(range(1500)):
if idx == 0: if idx == 0:
xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, xy_attn_mask, xy_padding_mask, False) xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, attn_mask, None)
else: else:
xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache, xy_attn_mask, False) xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache, attn_mask)
logits = self.ar_predict_layer( logits = self.ar_predict_layer(
xy_dec[:, -1] xy_dec[:, -1]
) )
if idx == 0: if idx == 0:
xy_attn_mask = F.pad(xy_attn_mask[:,:,-1].unsqueeze(-2),(0,1),value=False) attn_mask = F.pad(attn_mask[:,:,-1].unsqueeze(-2),(0,1),value=False)
logits = logits[:, :-1] logits = logits[:, :-1]
else: else:
xy_attn_mask = F.pad(xy_attn_mask,(0,1),value=False) attn_mask = F.pad(attn_mask,(0,1),value=False)
samples = sample( samples = sample(
logits, y, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature logits, y, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature
@ -686,7 +703,7 @@ class Text2SemanticDecoder(nn.Module):
if reserved_idx_of_batch_for_y is not None: if reserved_idx_of_batch_for_y is not None:
# index = torch.LongTensor(batch_idx_map).to(y.device) # index = torch.LongTensor(batch_idx_map).to(y.device)
y = torch.index_select(y, dim=0, index=reserved_idx_of_batch_for_y) y = torch.index_select(y, dim=0, index=reserved_idx_of_batch_for_y)
xy_attn_mask = torch.index_select(xy_attn_mask, dim=0, index=reserved_idx_of_batch_for_y) attn_mask = torch.index_select(attn_mask, dim=0, index=reserved_idx_of_batch_for_y)
if k_cache is not None : if k_cache is not None :
for i in range(len(k_cache)): for i in range(len(k_cache)):
k_cache[i] = torch.index_select(k_cache[i], dim=0, index=reserved_idx_of_batch_for_y) k_cache[i] = torch.index_select(k_cache[i], dim=0, index=reserved_idx_of_batch_for_y)

View File

@ -39,6 +39,39 @@ def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
return expaned_lengths >= lengths.unsqueeze(-1) return expaned_lengths >= lengths.unsqueeze(-1)
def make_pad_mask_left(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
"""
Args:
lengths:
A 1-D tensor containing sentence lengths.
max_len:
The length of masks.
Returns:
Return a 2-D bool tensor, where masked positions
are filled with `True` and non-masked positions are
filled with `False`.
#>>> lengths = torch.tensor([1, 3, 2, 5])
#>>> make_pad_mask(lengths)
tensor(
[
[True, True, False],
[True, False, False],
[True, True, False],
...
]
)
"""
assert lengths.ndim == 1, lengths.ndim
max_len = max(max_len, lengths.max())
n = lengths.size(0)
seq_range = torch.arange(0, max_len, device=lengths.device)
expaned_lengths = seq_range.unsqueeze(0).repeat(n, 1)
expaned_lengths -= (max_len-lengths).unsqueeze(-1)
return expaned_lengths<0
# https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py # https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py
def top_k_top_p_filtering( def top_k_top_p_filtering(
logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1 logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1

View File

@ -145,7 +145,15 @@ class TTS_Config:
self.device = self.configs.get("device", torch.device("cpu")) self.device = self.configs.get("device", torch.device("cpu"))
if "cuda" in str(self.device) and not torch.cuda.is_available():
print(f"Warning: CUDA is not available, set device to CPU.")
self.device = torch.device("cpu")
self.is_half = self.configs.get("is_half", False) self.is_half = self.configs.get("is_half", False)
# if str(self.device) == "cpu" and self.is_half:
# print(f"Warning: Half precision is not supported on CPU, set is_half to False.")
# self.is_half = False
self.version = version self.version = version
self.t2s_weights_path = self.configs.get("t2s_weights_path", None) self.t2s_weights_path = self.configs.get("t2s_weights_path", None)
self.vits_weights_path = self.configs.get("vits_weights_path", None) self.vits_weights_path = self.configs.get("vits_weights_path", None)

View File

@ -691,7 +691,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
wav_gen = bigvgan_model(cmf_res) wav_gen = bigvgan_model(cmf_res)
audio=wav_gen[0][0]#.cpu().detach().numpy() audio=wav_gen[0][0]#.cpu().detach().numpy()
max_audio=torch.abs(audio).max()#简单防止16bit爆音 max_audio=torch.abs(audio).max()#简单防止16bit爆音
if max_audio>1:audio/=max_audio if max_audio>1:audio=audio/max_audio
audio_opt.append(audio) audio_opt.append(audio)
audio_opt.append(zero_wav_torch)#zero_wav audio_opt.append(zero_wav_torch)#zero_wav
t4 = ttime() t4 = ttime()

View File

@ -1162,6 +1162,7 @@ class SynthesizerTrnV3(nn.Module):
use_sdp=True, use_sdp=True,
semantic_frame_rate=None, semantic_frame_rate=None,
freeze_quantizer=None, freeze_quantizer=None,
version="v3",
**kwargs): **kwargs):
super().__init__() super().__init__()
@ -1182,6 +1183,7 @@ class SynthesizerTrnV3(nn.Module):
self.segment_size = segment_size self.segment_size = segment_size
self.n_speakers = n_speakers self.n_speakers = n_speakers
self.gin_channels = gin_channels self.gin_channels = gin_channels
self.version = version
self.model_dim=512 self.model_dim=512
self.use_sdp = use_sdp self.use_sdp = use_sdp

View File

@ -8,66 +8,7 @@ jieba.setLogLevel(logging.CRITICAL)
# 更改fast_langdetect大模型位置 # 更改fast_langdetect大模型位置
from pathlib import Path from pathlib import Path
import fast_langdetect import fast_langdetect
fast_langdetect.ft_detect.infer.CACHE_DIRECTORY = Path(__file__).parent.parent.parent / "pretrained_models" / "fast_langdetect" fast_langdetect.infer._default_detector = fast_langdetect.infer.LangDetector(fast_langdetect.infer.LangDetectConfig(cache_dir=Path(__file__).parent.parent.parent / "pretrained_models" / "fast_langdetect"))
# 防止win下无法读取模型
import os
from typing import Optional
def load_fasttext_model(
model_path: Path,
download_url: Optional[str] = None,
proxy: Optional[str] = None,
):
"""
Load a FastText model, downloading it if necessary.
:param model_path: Path to the FastText model file
:param download_url: URL to download the model from
:param proxy: Proxy URL for downloading the model
:return: FastText model
:raises DetectError: If model loading fails
"""
if all([
fast_langdetect.ft_detect.infer.VERIFY_FASTTEXT_LARGE_MODEL,
model_path.exists(),
model_path.name == fast_langdetect.ft_detect.infer.FASTTEXT_LARGE_MODEL_NAME,
]):
if not fast_langdetect.ft_detect.infer.verify_md5(model_path, fast_langdetect.ft_detect.infer.VERIFY_FASTTEXT_LARGE_MODEL):
fast_langdetect.ft_detect.infer.logger.warning(
f"fast-langdetect: MD5 hash verification failed for {model_path}, "
f"please check the integrity of the downloaded file from {fast_langdetect.ft_detect.infer.FASTTEXT_LARGE_MODEL_URL}. "
"\n This may seriously reduce the prediction accuracy. "
"If you want to ignore this, please set `fast_langdetect.ft_detect.infer.VERIFY_FASTTEXT_LARGE_MODEL = None` "
)
if not model_path.exists():
if download_url:
fast_langdetect.ft_detect.infer.download_model(download_url, model_path, proxy)
if not model_path.exists():
raise fast_langdetect.ft_detect.infer.DetectError(f"FastText model file not found at {model_path}")
try:
# Load FastText model
if (re.match(r'^[A-Za-z0-9_/\\:.]*$', str(model_path))):
model = fast_langdetect.ft_detect.infer.fasttext.load_model(str(model_path))
else:
python_dir = os.getcwd()
if (str(model_path)[:len(python_dir)].upper() == python_dir.upper()):
model = fast_langdetect.ft_detect.infer.fasttext.load_model(os.path.relpath(model_path, python_dir))
else:
import tempfile
import shutil
with tempfile.NamedTemporaryFile(delete=False) as tmpfile:
shutil.copyfile(model_path, tmpfile.name)
model = fast_langdetect.ft_detect.infer.fasttext.load_model(tmpfile.name)
os.unlink(tmpfile.name)
return model
except Exception as e:
fast_langdetect.ft_detect.infer.logger.warning(f"fast-langdetect:Failed to load FastText model from {model_path}: {e}")
raise fast_langdetect.ft_detect.infer.DetectError(f"Failed to load FastText model: {e}")
if os.name == 'nt':
fast_langdetect.ft_detect.infer.load_fasttext_model = load_fasttext_model
from split_lang import LangSplitter from split_lang import LangSplitter

View File

@ -17,6 +17,8 @@ pinyin_to_symbol_map = {
for line in open(os.path.join(current_file_path, "opencpop-strict.txt")).readlines() for line in open(os.path.join(current_file_path, "opencpop-strict.txt")).readlines()
} }
import jieba_fast, logging
jieba_fast.setLogLevel(logging.CRITICAL)
import jieba_fast.posseg as psg import jieba_fast.posseg as psg

View File

@ -18,13 +18,15 @@ pinyin_to_symbol_map = {
for line in open(os.path.join(current_file_path, "opencpop-strict.txt")).readlines() for line in open(os.path.join(current_file_path, "opencpop-strict.txt")).readlines()
} }
import jieba_fast, logging
jieba_fast.setLogLevel(logging.CRITICAL)
import jieba_fast.posseg as psg import jieba_fast.posseg as psg
# is_g2pw_str = os.environ.get("is_g2pw", "True")##默认开启 # is_g2pw_str = os.environ.get("is_g2pw", "True")##默认开启
# is_g2pw = False#True if is_g2pw_str.lower() == 'true' else False # is_g2pw = False#True if is_g2pw_str.lower() == 'true' else False
is_g2pw = True#True if is_g2pw_str.lower() == 'true' else False is_g2pw = True#True if is_g2pw_str.lower() == 'true' else False
if is_g2pw: if is_g2pw:
print("当前使用g2pw进行拼音推理") # print("当前使用g2pw进行拼音推理")
from text.g2pw import G2PWPinyin, correct_pronunciation from text.g2pw import G2PWPinyin, correct_pronunciation
parent_directory = os.path.dirname(current_file_path) parent_directory = os.path.dirname(current_file_path)
g2pw = G2PWPinyin(model_dir="GPT_SoVITS/text/G2PWModel",model_source=os.environ.get("bert_path","GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"),v_to_u=False, neutral_tone_with_five=True) g2pw = G2PWPinyin(model_dir="GPT_SoVITS/text/G2PWModel",model_source=os.environ.get("bert_path","GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"),v_to_u=False, neutral_tone_with_five=True)

View File

@ -10,7 +10,7 @@ try:
if os.name == 'nt': if os.name == 'nt':
python_dir = os.getcwd() python_dir = os.getcwd()
OPEN_JTALK_DICT_DIR = pyopenjtalk.OPEN_JTALK_DICT_DIR.decode("utf-8") OPEN_JTALK_DICT_DIR = pyopenjtalk.OPEN_JTALK_DICT_DIR.decode("utf-8")
if not (re.match(r'^[A-Za-z0-9_/\\:.]*$', OPEN_JTALK_DICT_DIR)): if not (re.match(r'^[A-Za-z0-9_/\\:.\-]*$', OPEN_JTALK_DICT_DIR)):
if (OPEN_JTALK_DICT_DIR[:len(python_dir)].upper() == python_dir.upper()): if (OPEN_JTALK_DICT_DIR[:len(python_dir)].upper() == python_dir.upper()):
OPEN_JTALK_DICT_DIR = os.path.join(os.path.relpath(OPEN_JTALK_DICT_DIR,python_dir)) OPEN_JTALK_DICT_DIR = os.path.join(os.path.relpath(OPEN_JTALK_DICT_DIR,python_dir))
else: else:
@ -25,7 +25,7 @@ try:
OPEN_JTALK_DICT_DIR = os.path.join("TEMP", "ja", "open_jtalk_dic") OPEN_JTALK_DICT_DIR = os.path.join("TEMP", "ja", "open_jtalk_dic")
pyopenjtalk.OPEN_JTALK_DICT_DIR = OPEN_JTALK_DICT_DIR.encode("utf-8") pyopenjtalk.OPEN_JTALK_DICT_DIR = OPEN_JTALK_DICT_DIR.encode("utf-8")
if not (re.match(r'^[A-Za-z0-9_/\\:.]*$', current_file_path)): if not (re.match(r'^[A-Za-z0-9_/\\:.\-]*$', current_file_path)):
if (current_file_path[:len(python_dir)].upper() == python_dir.upper()): if (current_file_path[:len(python_dir)].upper() == python_dir.upper()):
current_file_path = os.path.join(os.path.relpath(current_file_path,python_dir)) current_file_path = os.path.join(os.path.relpath(current_file_path,python_dir))
else: else:

View File

@ -19,13 +19,13 @@ if os.name == 'nt':
print(f'you have to install eunjeon. install it...') print(f'you have to install eunjeon. install it...')
else: else:
installpath = spam_spec.submodule_search_locations[0] installpath = spam_spec.submodule_search_locations[0]
if not (re.match(r'^[A-Za-z0-9_/\\:.]*$', installpath)): if not (re.match(r'^[A-Za-z0-9_/\\:.\-]*$', installpath)):
import sys import sys
from eunjeon import Mecab as _Mecab from eunjeon import Mecab as _Mecab
class Mecab(_Mecab): class Mecab(_Mecab):
def get_dicpath(installpath): def get_dicpath(installpath):
if not (re.match(r'^[A-Za-z0-9_/\\:.]*$', installpath)): if not (re.match(r'^[A-Za-z0-9_/\\:.\-]*$', installpath)):
import shutil import shutil
python_dir = os.getcwd() python_dir = os.getcwd()
if (installpath[:len(python_dir)].upper() == python_dir.upper()): if (installpath[:len(python_dir)].upper() == python_dir.upper()):

273
api.py
View File

@ -150,9 +150,9 @@ sys.path.append(now_dir)
sys.path.append("%s/GPT_SoVITS" % (now_dir)) sys.path.append("%s/GPT_SoVITS" % (now_dir))
import signal import signal
import LangSegment from text.LangSegmenter import LangSegmenter
from time import time as ttime from time import time as ttime
import torch import torch, torchaudio
import librosa import librosa
import soundfile as sf import soundfile as sf
from fastapi import FastAPI, Request, Query, HTTPException from fastapi import FastAPI, Request, Query, HTTPException
@ -162,7 +162,8 @@ from transformers import AutoModelForMaskedLM, AutoTokenizer
import numpy as np import numpy as np
from feature_extractor import cnhubert from feature_extractor import cnhubert
from io import BytesIO from io import BytesIO
from module.models import SynthesizerTrn from module.models import SynthesizerTrn, SynthesizerTrnV3
from peft import LoraConfig, PeftModel, get_peft_model
from AR.models.t2s_lightning_module import Text2SemanticLightningModule from AR.models.t2s_lightning_module import Text2SemanticLightningModule
from text import cleaned_text_to_sequence from text import cleaned_text_to_sequence
from text.cleaner import clean_text from text.cleaner import clean_text
@ -197,6 +198,61 @@ def is_full(*items): # 任意一项为空返回False
return True return True
def init_bigvgan():
global bigvgan_model
from BigVGAN import bigvgan
bigvgan_model = bigvgan.BigVGAN.from_pretrained("%s/GPT_SoVITS/pretrained_models/models--nvidia--bigvgan_v2_24khz_100band_256x" % (now_dir,), use_cuda_kernel=False) # if True, RuntimeError: Ninja is required to load C++ extensions
# remove weight norm in the model and set to eval mode
bigvgan_model.remove_weight_norm()
bigvgan_model = bigvgan_model.eval()
if is_half == True:
bigvgan_model = bigvgan_model.half().to(device)
else:
bigvgan_model = bigvgan_model.to(device)
resample_transform_dict={}
def resample(audio_tensor, sr0):
global resample_transform_dict
if sr0 not in resample_transform_dict:
resample_transform_dict[sr0] = torchaudio.transforms.Resample(
sr0, 24000
).to(device)
return resample_transform_dict[sr0](audio_tensor)
from module.mel_processing import spectrogram_torch,mel_spectrogram_torch
spec_min = -12
spec_max = 2
def norm_spec(x):
return (x - spec_min) / (spec_max - spec_min) * 2 - 1
def denorm_spec(x):
return (x + 1) / 2 * (spec_max - spec_min) + spec_min
mel_fn=lambda x: mel_spectrogram_torch(x, **{
"n_fft": 1024,
"win_size": 1024,
"hop_size": 256,
"num_mels": 100,
"sampling_rate": 24000,
"fmin": 0,
"fmax": None,
"center": False
})
sr_model=None
def audio_sr(audio,sr):
global sr_model
if sr_model==None:
from tools.audio_sr import AP_BWE
try:
sr_model=AP_BWE(device,DictToAttrRecursive)
except FileNotFoundError:
logger.info("你没有下载超分模型的参数,因此不进行超分。如想超分请先参照教程把文件下载")
return audio.cpu().detach().numpy(),sr
return sr_model(audio,sr)
class Speaker: class Speaker:
def __init__(self, name, gpt, sovits, phones = None, bert = None, prompt = None): def __init__(self, name, gpt, sovits, phones = None, bert = None, prompt = None):
self.name = name self.name = name
@ -214,31 +270,72 @@ class Sovits:
self.vq_model = vq_model self.vq_model = vq_model
self.hps = hps self.hps = hps
from process_ckpt import get_sovits_version_from_path_fast,load_sovits_new
def get_sovits_weights(sovits_path): def get_sovits_weights(sovits_path):
dict_s2 = torch.load(sovits_path, map_location="cpu") path_sovits_v3="GPT_SoVITS/pretrained_models/s2Gv3.pth"
is_exist_s2gv3=os.path.exists(path_sovits_v3)
version, model_version, if_lora_v3=get_sovits_version_from_path_fast(sovits_path)
if if_lora_v3==True and is_exist_s2gv3==False:
logger.info("SoVITS V3 底模缺失,无法加载相应 LoRA 权重")
dict_s2 = load_sovits_new(sovits_path)
hps = dict_s2["config"] hps = dict_s2["config"]
hps = DictToAttrRecursive(hps) hps = DictToAttrRecursive(hps)
hps.model.semantic_frame_rate = "25hz" hps.model.semantic_frame_rate = "25hz"
if dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322: if 'enc_p.text_embedding.weight' not in dict_s2['weight']:
hps.model.version = "v2"#v3model,v2sybomls
elif dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322:
hps.model.version = "v1" hps.model.version = "v1"
else: else:
hps.model.version = "v2" hps.model.version = "v2"
logger.info(f"模型版本: {hps.model.version}")
if model_version == "v3":
hps.model.version = "v3"
model_params_dict = vars(hps.model) model_params_dict = vars(hps.model)
vq_model = SynthesizerTrn( if model_version!="v3":
hps.data.filter_length // 2 + 1, vq_model = SynthesizerTrn(
hps.train.segment_size // hps.data.hop_length, hps.data.filter_length // 2 + 1,
n_speakers=hps.data.n_speakers, hps.train.segment_size // hps.data.hop_length,
**model_params_dict n_speakers=hps.data.n_speakers,
) **model_params_dict
)
else:
vq_model = SynthesizerTrnV3(
hps.data.filter_length // 2 + 1,
hps.train.segment_size // hps.data.hop_length,
n_speakers=hps.data.n_speakers,
**model_params_dict
)
init_bigvgan()
model_version=hps.model.version
logger.info(f"模型版本: {model_version}")
if ("pretrained" not in sovits_path): if ("pretrained" not in sovits_path):
del vq_model.enc_q try:
del vq_model.enc_q
except:pass
if is_half == True: if is_half == True:
vq_model = vq_model.half().to(device) vq_model = vq_model.half().to(device)
else: else:
vq_model = vq_model.to(device) vq_model = vq_model.to(device)
vq_model.eval() vq_model.eval()
vq_model.load_state_dict(dict_s2["weight"], strict=False) if if_lora_v3 == False:
vq_model.load_state_dict(dict_s2["weight"], strict=False)
else:
vq_model.load_state_dict(load_sovits_new(path_sovits_v3)["weight"], strict=False)
lora_rank=dict_s2["lora_rank"]
lora_config = LoraConfig(
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
r=lora_rank,
lora_alpha=lora_rank,
init_lora_weights=True,
)
vq_model.cfm = get_peft_model(vq_model.cfm, lora_config)
vq_model.load_state_dict(dict_s2["weight"], strict=False)
vq_model.cfm = vq_model.cfm.merge_and_unload()
# torch.save(vq_model.state_dict(),"merge_win.pth")
vq_model.eval()
sovits = Sovits(vq_model, hps) sovits = Sovits(vq_model, hps)
return sovits return sovits
@ -260,8 +357,8 @@ def get_gpt_weights(gpt_path):
t2s_model = t2s_model.half() t2s_model = t2s_model.half()
t2s_model = t2s_model.to(device) t2s_model = t2s_model.to(device)
t2s_model.eval() t2s_model.eval()
total = sum([param.nelement() for param in t2s_model.parameters()]) # total = sum([param.nelement() for param in t2s_model.parameters()])
logger.info("Number of parameter: %.2fM" % (total / 1e6)) # logger.info("Number of parameter: %.2fM" % (total / 1e6))
gpt = Gpt(max_sec, t2s_model) gpt = Gpt(max_sec, t2s_model)
return gpt return gpt
@ -295,6 +392,7 @@ def get_bert_feature(text, word2ph):
def clean_text_inf(text, language, version): def clean_text_inf(text, language, version):
language = language.replace("all_","")
phones, word2ph, norm_text = clean_text(text, language, version) phones, word2ph, norm_text = clean_text(text, language, version)
phones = cleaned_text_to_sequence(phones, version) phones = cleaned_text_to_sequence(phones, version)
return phones, word2ph, norm_text return phones, word2ph, norm_text
@ -315,16 +413,10 @@ def get_bert_inf(phones, word2ph, norm_text, language):
from text import chinese from text import chinese
def get_phones_and_bert(text,language,version,final=False): def get_phones_and_bert(text,language,version,final=False):
if language in {"en", "all_zh", "all_ja", "all_ko", "all_yue"}: if language in {"en", "all_zh", "all_ja", "all_ko", "all_yue"}:
language = language.replace("all_","") formattext = text
if language == "en":
LangSegment.setfilters(["en"])
formattext = " ".join(tmp["text"] for tmp in LangSegment.getTexts(text))
else:
# 因无法区别中日韩文汉字,以用户输入为准
formattext = text
while " " in formattext: while " " in formattext:
formattext = formattext.replace(" ", " ") formattext = formattext.replace(" ", " ")
if language == "zh": if language == "all_zh":
if re.search(r'[A-Za-z]', formattext): if re.search(r'[A-Za-z]', formattext):
formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext) formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
formattext = chinese.mix_text_normalize(formattext) formattext = chinese.mix_text_normalize(formattext)
@ -332,7 +424,7 @@ def get_phones_and_bert(text,language,version,final=False):
else: else:
phones, word2ph, norm_text = clean_text_inf(formattext, language, version) phones, word2ph, norm_text = clean_text_inf(formattext, language, version)
bert = get_bert_feature(norm_text, word2ph).to(device) bert = get_bert_feature(norm_text, word2ph).to(device)
elif language == "yue" and re.search(r'[A-Za-z]', formattext): elif language == "all_yue" and re.search(r'[A-Za-z]', formattext):
formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext) formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
formattext = chinese.mix_text_normalize(formattext) formattext = chinese.mix_text_normalize(formattext)
return get_phones_and_bert(formattext,"yue",version) return get_phones_and_bert(formattext,"yue",version)
@ -345,19 +437,18 @@ def get_phones_and_bert(text,language,version,final=False):
elif language in {"zh", "ja", "ko", "yue", "auto", "auto_yue"}: elif language in {"zh", "ja", "ko", "yue", "auto", "auto_yue"}:
textlist=[] textlist=[]
langlist=[] langlist=[]
LangSegment.setfilters(["zh","ja","en","ko"])
if language == "auto": if language == "auto":
for tmp in LangSegment.getTexts(text): for tmp in LangSegmenter.getTexts(text):
langlist.append(tmp["lang"]) langlist.append(tmp["lang"])
textlist.append(tmp["text"]) textlist.append(tmp["text"])
elif language == "auto_yue": elif language == "auto_yue":
for tmp in LangSegment.getTexts(text): for tmp in LangSegmenter.getTexts(text):
if tmp["lang"] == "zh": if tmp["lang"] == "zh":
tmp["lang"] = "yue" tmp["lang"] = "yue"
langlist.append(tmp["lang"]) langlist.append(tmp["lang"])
textlist.append(tmp["text"]) textlist.append(tmp["text"])
else: else:
for tmp in LangSegment.getTexts(text): for tmp in LangSegmenter.getTexts(text):
if tmp["lang"] == "en": if tmp["lang"] == "en":
langlist.append(tmp["lang"]) langlist.append(tmp["lang"])
else: else:
@ -556,10 +647,11 @@ def only_punc(text):
splits = {"", "", "", "", ",", ".", "?", "!", "~", ":", "", "", "", } splits = {"", "", "", "", ",", ".", "?", "!", "~", ":", "", "", "", }
def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, top_k= 15, top_p = 0.6, temperature = 0.6, speed = 1, inp_refs = None, spk = "default"): def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, top_k= 15, top_p = 0.6, temperature = 0.6, speed = 1, inp_refs = None, sample_steps = 32, if_sr = False, spk = "default"):
infer_sovits = speaker_list[spk].sovits infer_sovits = speaker_list[spk].sovits
vq_model = infer_sovits.vq_model vq_model = infer_sovits.vq_model
hps = infer_sovits.hps hps = infer_sovits.hps
version = vq_model.version
infer_gpt = speaker_list[spk].gpt infer_gpt = speaker_list[spk].gpt
t2s_model = infer_gpt.t2s_model t2s_model = infer_gpt.t2s_model
@ -587,20 +679,22 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
prompt_semantic = codes[0, 0] prompt_semantic = codes[0, 0]
prompt = prompt_semantic.unsqueeze(0).to(device) prompt = prompt_semantic.unsqueeze(0).to(device)
refers=[] if version != "v3":
if(inp_refs): refers=[]
for path in inp_refs: if(inp_refs):
try: for path in inp_refs:
refer = get_spepc(hps, path).to(dtype).to(device) try:
refers.append(refer) refer = get_spepc(hps, path).to(dtype).to(device)
except Exception as e: refers.append(refer)
logger.error(e) except Exception as e:
if(len(refers)==0): logger.error(e)
refers = [get_spepc(hps, ref_wav_path).to(dtype).to(device)] if(len(refers)==0):
refers = [get_spepc(hps, ref_wav_path).to(dtype).to(device)]
else:
refer = get_spepc(hps, ref_wav_path).to(device).to(dtype)
t1 = ttime() t1 = ttime()
version = vq_model.version # os.environ['version'] = version
os.environ['version'] = version
prompt_language = dict_language[prompt_language.lower()] prompt_language = dict_language[prompt_language.lower()]
text_language = dict_language[text_language.lower()] text_language = dict_language[text_language.lower()]
phones1, bert1, norm_text1 = get_phones_and_bert(prompt_text, prompt_language, version) phones1, bert1, norm_text1 = get_phones_and_bert(prompt_text, prompt_language, version)
@ -634,20 +728,82 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
early_stop_num=hz * max_sec) early_stop_num=hz * max_sec)
pred_semantic = pred_semantic[:, -idx:].unsqueeze(0) pred_semantic = pred_semantic[:, -idx:].unsqueeze(0)
t3 = ttime() t3 = ttime()
audio = \
vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), if version != "v3":
refers,speed=speed).detach().cpu().numpy()[ audio = \
0, 0] ###试试重建不带上prompt部分 vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0),
refers,speed=speed).detach().cpu().numpy()[
0, 0] ###试试重建不带上prompt部分
else:
phoneme_ids0=torch.LongTensor(phones1).to(device).unsqueeze(0)
phoneme_ids1=torch.LongTensor(phones2).to(device).unsqueeze(0)
# print(11111111, phoneme_ids0, phoneme_ids1)
fea_ref,ge = vq_model.decode_encp(prompt.unsqueeze(0), phoneme_ids0, refer)
ref_audio, sr = torchaudio.load(ref_wav_path)
ref_audio=ref_audio.to(device).float()
if (ref_audio.shape[0] == 2):
ref_audio = ref_audio.mean(0).unsqueeze(0)
if sr!=24000:
ref_audio=resample(ref_audio,sr)
# print("ref_audio",ref_audio.abs().mean())
mel2 = mel_fn(ref_audio)
mel2 = norm_spec(mel2)
T_min = min(mel2.shape[2], fea_ref.shape[2])
mel2 = mel2[:, :, :T_min]
fea_ref = fea_ref[:, :, :T_min]
if (T_min > 468):
mel2 = mel2[:, :, -468:]
fea_ref = fea_ref[:, :, -468:]
T_min = 468
chunk_len = 934 - T_min
# print("fea_ref",fea_ref,fea_ref.shape)
# print("mel2",mel2)
mel2=mel2.to(dtype)
fea_todo, ge = vq_model.decode_encp(pred_semantic, phoneme_ids1, refer, ge,speed)
# print("fea_todo",fea_todo)
# print("ge",ge.abs().mean())
cfm_resss = []
idx = 0
while (1):
fea_todo_chunk = fea_todo[:, :, idx:idx + chunk_len]
if (fea_todo_chunk.shape[-1] == 0): break
idx += chunk_len
fea = torch.cat([fea_ref, fea_todo_chunk], 2).transpose(2, 1)
# set_seed(123)
cfm_res = vq_model.cfm.inference(fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps, inference_cfg_rate=0)
cfm_res = cfm_res[:, :, mel2.shape[2]:]
mel2 = cfm_res[:, :, -T_min:]
# print("fea", fea)
# print("mel2in", mel2)
fea_ref = fea_todo_chunk[:, :, -T_min:]
cfm_resss.append(cfm_res)
cmf_res = torch.cat(cfm_resss, 2)
cmf_res = denorm_spec(cmf_res)
if bigvgan_model==None:init_bigvgan()
with torch.inference_mode():
wav_gen = bigvgan_model(cmf_res)
audio=wav_gen[0][0].cpu().detach().numpy()
max_audio=np.abs(audio).max() max_audio=np.abs(audio).max()
if max_audio>1: if max_audio>1:
audio/=max_audio audio/=max_audio
audio_opt.append(audio) audio_opt.append(audio)
audio_opt.append(zero_wav) audio_opt.append(zero_wav)
audio_opt = np.concatenate(audio_opt, 0)
t4 = ttime() t4 = ttime()
sr = hps.data.sampling_rate if version != "v3" else 24000
if if_sr and sr == 24000:
audio_opt = torch.from_numpy(audio_opt).float().to(device)
audio_opt,sr=audio_sr(audio_opt.unsqueeze(0),sr)
max_audio=np.abs(audio_opt).max()
if max_audio > 1: audio_opt /= max_audio
sr = 48000
if is_int32: if is_int32:
audio_bytes = pack_audio(audio_bytes,(np.concatenate(audio_opt, 0) * 2147483647).astype(np.int32),hps.data.sampling_rate) audio_bytes = pack_audio(audio_bytes,(audio_opt * 2147483647).astype(np.int32),sr)
else: else:
audio_bytes = pack_audio(audio_bytes,(np.concatenate(audio_opt, 0) * 32768).astype(np.int16),hps.data.sampling_rate) audio_bytes = pack_audio(audio_bytes,(audio_opt * 32768).astype(np.int16),sr)
# logger.info("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3)) # logger.info("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
if stream_mode == "normal": if stream_mode == "normal":
audio_bytes, audio_chunk = read_clean_buffer(audio_bytes) audio_bytes, audio_chunk = read_clean_buffer(audio_bytes)
@ -655,7 +811,9 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
if not stream_mode == "normal": if not stream_mode == "normal":
if media_type == "wav": if media_type == "wav":
audio_bytes = pack_wav(audio_bytes,hps.data.sampling_rate) sr = 48000 if if_sr else 24000
sr = hps.data.sampling_rate if version != "v3" else sr
audio_bytes = pack_wav(audio_bytes,sr)
yield audio_bytes.getvalue() yield audio_bytes.getvalue()
@ -688,7 +846,7 @@ def handle_change(path, text, language):
return JSONResponse({"code": 0, "message": "Success"}, status_code=200) return JSONResponse({"code": 0, "message": "Success"}, status_code=200)
def handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cut_punc, top_k, top_p, temperature, speed, inp_refs): def handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cut_punc, top_k, top_p, temperature, speed, inp_refs, sample_steps, if_sr):
if ( if (
refer_wav_path == "" or refer_wav_path is None refer_wav_path == "" or refer_wav_path is None
or prompt_text == "" or prompt_text is None or prompt_text == "" or prompt_text is None
@ -702,12 +860,15 @@ def handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cu
if not default_refer.is_ready(): if not default_refer.is_ready():
return JSONResponse({"code": 400, "message": "未指定参考音频且接口无预设"}, status_code=400) return JSONResponse({"code": 400, "message": "未指定参考音频且接口无预设"}, status_code=400)
if not sample_steps in [4,8,16,32]:
sample_steps = 32
if cut_punc == None: if cut_punc == None:
text = cut_text(text,default_cut_punc) text = cut_text(text,default_cut_punc)
else: else:
text = cut_text(text,cut_punc) text = cut_text(text,cut_punc)
return StreamingResponse(get_tts_wav(refer_wav_path, prompt_text, prompt_language, text, text_language, top_k, top_p, temperature, speed, inp_refs), media_type="audio/"+media_type) return StreamingResponse(get_tts_wav(refer_wav_path, prompt_text, prompt_language, text, text_language, top_k, top_p, temperature, speed, inp_refs, sample_steps, if_sr), media_type="audio/"+media_type)
@ -915,7 +1076,9 @@ async def tts_endpoint(request: Request):
json_post_raw.get("top_p", 1.0), json_post_raw.get("top_p", 1.0),
json_post_raw.get("temperature", 1.0), json_post_raw.get("temperature", 1.0),
json_post_raw.get("speed", 1.0), json_post_raw.get("speed", 1.0),
json_post_raw.get("inp_refs", []) json_post_raw.get("inp_refs", []),
json_post_raw.get("sample_steps", 32),
json_post_raw.get("if_sr", False)
) )
@ -931,9 +1094,11 @@ async def tts_endpoint(
top_p: float = 1.0, top_p: float = 1.0,
temperature: float = 1.0, temperature: float = 1.0,
speed: float = 1.0, speed: float = 1.0,
inp_refs: list = Query(default=[]) inp_refs: list = Query(default=[]),
sample_steps: int = 32,
if_sr: bool = False
): ):
return handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cut_punc, top_k, top_p, temperature, speed, inp_refs) return handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cut_punc, top_k, top_p, temperature, speed, inp_refs, sample_steps, if_sr)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -1,4 +1,7 @@
#!/bin/bash #!/bin/bash
# 安装构建工具
# Install build tools
echo "Installing GCC..." echo "Installing GCC..."
conda install -c conda-forge gcc=14 conda install -c conda-forge gcc=14
@ -8,6 +11,12 @@ conda install -c conda-forge gxx
echo "Installing ffmpeg and cmake..." echo "Installing ffmpeg and cmake..."
conda install ffmpeg cmake conda install ffmpeg cmake
# 设置编译环境
# Set up build environment
export CMAKE_MAKE_PROGRAM="$CONDA_PREFIX/bin/cmake"
export CC="$CONDA_PREFIX/bin/gcc"
export CXX="$CONDA_PREFIX/bin/g++"
echo "Checking for CUDA installation..." echo "Checking for CUDA installation..."
if command -v nvidia-smi &> /dev/null; then if command -v nvidia-smi &> /dev/null; then
USE_CUDA=true USE_CUDA=true
@ -49,6 +58,10 @@ fi
echo "Installing Python dependencies from requirements.txt..." echo "Installing Python dependencies from requirements.txt..."
# 刷新环境
# Refresh environment
hash -r
pip install -r requirements.txt pip install -r requirements.txt
if [ "$USE_ROCM" = true ] && [ "$IS_WSL" = true ] ; then if [ "$USE_ROCM" = true ] && [ "$IS_WSL" = true ] ; then
@ -60,3 +73,4 @@ if [ "$USE_ROCM" = true ] && [ "$IS_WSL" = true ] ; then
fi fi
echo "Installation completed successfully!" echo "Installation completed successfully!"

View File

@ -25,7 +25,7 @@ psutil
jieba_fast jieba_fast
jieba jieba
split-lang split-lang
fast_langdetect fast_langdetect>=0.3.0
Faster_Whisper Faster_Whisper
wordsegment wordsegment
rotary_embedding_torch rotary_embedding_torch