mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-04-05 19:41:56 +08:00
Merge branch 'main' into main
This commit is contained in:
commit
f72237d668
@ -5,7 +5,7 @@ from typing import List, Optional
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from AR.models.utils import make_pad_mask
|
||||
from AR.models.utils import make_pad_mask, make_pad_mask_left
|
||||
from AR.models.utils import (
|
||||
topk_sampling,
|
||||
sample,
|
||||
@ -162,7 +162,7 @@ class T2SBlock:
|
||||
)
|
||||
return x, k_cache, v_cache
|
||||
|
||||
def decode_next_token(self, x:torch.Tensor, k_cache:torch.Tensor, v_cache:torch.Tensor, attn_mask:Optional[torch.Tensor]=None, torch_sdpa:bool=True):
|
||||
def decode_next_token(self, x:torch.Tensor, k_cache:torch.Tensor, v_cache:torch.Tensor, attn_mask:torch.Tensor=None, torch_sdpa:bool=True):
|
||||
q, k, v = F.linear(x, self.qkv_w, self.qkv_b).chunk(3, dim=-1)
|
||||
|
||||
k_cache = torch.cat([k_cache, k], dim=1)
|
||||
@ -178,7 +178,7 @@ class T2SBlock:
|
||||
|
||||
|
||||
if torch_sdpa:
|
||||
attn = F.scaled_dot_product_attention(q, k, v)
|
||||
attn = F.scaled_dot_product_attention(q, k, v, (~attn_mask) if attn_mask is not None else None)
|
||||
else:
|
||||
attn = scaled_dot_product_attention(q, k, v, attn_mask)
|
||||
|
||||
@ -223,7 +223,7 @@ class T2STransformer:
|
||||
self, x:torch.Tensor,
|
||||
k_cache: List[torch.Tensor],
|
||||
v_cache: List[torch.Tensor],
|
||||
attn_mask : Optional[torch.Tensor]=None,
|
||||
attn_mask : torch.Tensor=None,
|
||||
torch_sdpa:bool=True
|
||||
):
|
||||
for i in range(self.num_blocks):
|
||||
@ -573,71 +573,88 @@ class Text2SemanticDecoder(nn.Module):
|
||||
x_item = self.ar_text_embedding(x_item.unsqueeze(0))
|
||||
x_item = x_item + self.bert_proj(bert_item.transpose(0, 1).unsqueeze(0))
|
||||
x_item = self.ar_text_position(x_item).squeeze(0)
|
||||
x_item = F.pad(x_item,(0,0,0,max_len-x_item.shape[0]),value=0) if x_item.shape[0]<max_len else x_item
|
||||
# x_item = F.pad(x_item,(0,0,0,max_len-x_item.shape[0]),value=0) if x_item.shape[0]<max_len else x_item ### padding right
|
||||
x_item = F.pad(x_item,(0,0,max_len-x_item.shape[0],0),value=0) if x_item.shape[0]<max_len else x_item ### padding left
|
||||
x_list.append(x_item)
|
||||
x = torch.stack(x_list, dim=0)
|
||||
x:torch.Tensor = torch.stack(x_list, dim=0)
|
||||
|
||||
|
||||
# AR Decoder
|
||||
y = prompts
|
||||
|
||||
x_len = x.shape[1]
|
||||
x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
|
||||
stop = False
|
||||
|
||||
k_cache = None
|
||||
v_cache = None
|
||||
################### first step ##########################
|
||||
if y is not None:
|
||||
y_emb = self.ar_audio_embedding(y)
|
||||
y_len = y_emb.shape[1]
|
||||
prefix_len = y.shape[1]
|
||||
y_lens = torch.LongTensor([y_emb.shape[1]]*y_emb.shape[0]).to(x.device)
|
||||
y_pos = self.ar_audio_position(y_emb)
|
||||
xy_pos = torch.concat([x, y_pos], dim=1)
|
||||
ref_free = False
|
||||
else:
|
||||
y_emb = None
|
||||
y_len = 0
|
||||
prefix_len = 0
|
||||
y_lens = torch.LongTensor([y_len]*x.shape[0]).to(x.device)
|
||||
y_pos = None
|
||||
xy_pos = x
|
||||
y = torch.zeros(x.shape[0], 0, dtype=torch.int, device=x.device)
|
||||
ref_free = True
|
||||
assert y is not None, "Error: Prompt free is not supported batch_infer!"
|
||||
ref_free = False
|
||||
|
||||
y_emb = self.ar_audio_embedding(y)
|
||||
y_len = y_emb.shape[1]
|
||||
prefix_len = y.shape[1]
|
||||
y_lens = torch.LongTensor([y_emb.shape[1]]*y_emb.shape[0]).to(x.device)
|
||||
y_pos = self.ar_audio_position(y_emb)
|
||||
xy_pos = torch.concat([x, y_pos], dim=1)
|
||||
|
||||
|
||||
|
||||
##### create mask #####
|
||||
bsz = x.shape[0]
|
||||
src_len = x_len + y_len
|
||||
y_paddind_mask = make_pad_mask(y_lens, y_len)
|
||||
x_paddind_mask = make_pad_mask(x_lens, max_len)
|
||||
y_paddind_mask = make_pad_mask_left(y_lens, y_len)
|
||||
x_paddind_mask = make_pad_mask_left(x_lens, max_len)
|
||||
|
||||
# (bsz, x_len + y_len)
|
||||
xy_padding_mask = torch.concat([x_paddind_mask, y_paddind_mask], dim=1)
|
||||
padding_mask = torch.concat([x_paddind_mask, y_paddind_mask], dim=1)
|
||||
|
||||
x_mask = F.pad(
|
||||
torch.zeros(x_len, x_len, dtype=torch.bool, device=x.device),
|
||||
(0, y_len),
|
||||
value=True,
|
||||
)
|
||||
|
||||
x_mask = F.pad(
|
||||
x_attn_mask,
|
||||
(0, y_len), ###xx的纯0扩展到xx纯0+xy纯1,(x,x+y)
|
||||
value=True,
|
||||
)
|
||||
y_mask = F.pad( ###yy的右上1扩展到左边xy的0,(y,x+y)
|
||||
torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
|
||||
torch.triu(torch.ones(y_len, y_len, dtype=torch.bool, device=x.device), diagonal=1),
|
||||
(x_len, 0),
|
||||
value=False,
|
||||
)
|
||||
|
||||
xy_mask = torch.concat([x_mask, y_mask], dim=0).view(1 , src_len, src_len).repeat(bsz, 1, 1).to(x.device)
|
||||
_xy_padding_mask = xy_padding_mask.view(bsz, 1, src_len).repeat(1, src_len, 1)
|
||||
|
||||
for i in range(bsz):
|
||||
l = x_lens[i]
|
||||
_xy_padding_mask[i,l:max_len,:]=True
|
||||
|
||||
xy_attn_mask = xy_mask.logical_or(_xy_padding_mask)
|
||||
xy_attn_mask = xy_attn_mask.unsqueeze(1).expand(-1, self.num_head, -1, -1)
|
||||
xy_attn_mask = xy_attn_mask.bool()
|
||||
xy_padding_mask = xy_padding_mask.view(bsz, src_len, 1)
|
||||
causal_mask = torch.concat([x_mask, y_mask], dim=0).view(1 , src_len, src_len).repeat(bsz, 1, 1).to(x.device)
|
||||
# padding_mask = padding_mask.unsqueeze(1) * padding_mask.unsqueeze(2) ### [b, x+y, x+y]
|
||||
### 上面是错误的,会导致padding的token被"看见"
|
||||
|
||||
# 正确的padding_mask应该是:
|
||||
# | pad_len | x_len | y_len |
|
||||
# [[PAD, PAD, PAD, 1, 2, 3, 4, 5, 6],
|
||||
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6],
|
||||
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6], 前3行按理说也应该被mask掉,但是为了防止计算attention时不出现nan,还是保留了,不影响结果
|
||||
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6],
|
||||
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6],
|
||||
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6],
|
||||
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6],
|
||||
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6],
|
||||
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6]]
|
||||
|
||||
padding_mask = padding_mask.view(bsz, 1, src_len).repeat(1, src_len, 1)
|
||||
|
||||
attn_mask:torch.Tensor = causal_mask.logical_or(padding_mask)
|
||||
attn_mask = attn_mask.unsqueeze(1).expand(-1, self.num_head, -1, -1).bool()
|
||||
|
||||
|
||||
# 正确的attn_mask应该是这样的:
|
||||
# | pad_len | x_len | y_len |
|
||||
# [[PAD, PAD, PAD, 1, 2, 3, EOS, EOS, EOS],
|
||||
# [PAD, PAD, PAD, 1, 2, 3, EOS, EOS, EOS],
|
||||
# [PAD, PAD, PAD, 1, 2, 3, EOS, EOS, EOS], 前3行按理说也应该被mask掉,但是为了防止计算attention时不出现nan,还是保留了,不影响结果
|
||||
# [PAD, PAD, PAD, 1, 2, 3, EOS, EOS, EOS],
|
||||
# [PAD, PAD, PAD, 1, 2, 3, EOS, EOS, EOS],
|
||||
# [PAD, PAD, PAD, 1, 2, 3, EOS, EOS, EOS],
|
||||
# [PAD, PAD, PAD, 1, 2, 3, 4, EOS, EOS],
|
||||
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, EOS],
|
||||
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6]]
|
||||
|
||||
|
||||
###### decode #####
|
||||
y_list = [None]*y.shape[0]
|
||||
@ -645,18 +662,18 @@ class Text2SemanticDecoder(nn.Module):
|
||||
idx_list = [None]*y.shape[0]
|
||||
for idx in tqdm(range(1500)):
|
||||
if idx == 0:
|
||||
xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, xy_attn_mask, xy_padding_mask, False)
|
||||
xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, attn_mask, None)
|
||||
else:
|
||||
xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache, xy_attn_mask, False)
|
||||
xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache, attn_mask)
|
||||
logits = self.ar_predict_layer(
|
||||
xy_dec[:, -1]
|
||||
)
|
||||
|
||||
if idx == 0:
|
||||
xy_attn_mask = F.pad(xy_attn_mask[:,:,-1].unsqueeze(-2),(0,1),value=False)
|
||||
attn_mask = F.pad(attn_mask[:,:,-1].unsqueeze(-2),(0,1),value=False)
|
||||
logits = logits[:, :-1]
|
||||
else:
|
||||
xy_attn_mask = F.pad(xy_attn_mask,(0,1),value=False)
|
||||
attn_mask = F.pad(attn_mask,(0,1),value=False)
|
||||
|
||||
samples = sample(
|
||||
logits, y, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature
|
||||
@ -686,7 +703,7 @@ class Text2SemanticDecoder(nn.Module):
|
||||
if reserved_idx_of_batch_for_y is not None:
|
||||
# index = torch.LongTensor(batch_idx_map).to(y.device)
|
||||
y = torch.index_select(y, dim=0, index=reserved_idx_of_batch_for_y)
|
||||
xy_attn_mask = torch.index_select(xy_attn_mask, dim=0, index=reserved_idx_of_batch_for_y)
|
||||
attn_mask = torch.index_select(attn_mask, dim=0, index=reserved_idx_of_batch_for_y)
|
||||
if k_cache is not None :
|
||||
for i in range(len(k_cache)):
|
||||
k_cache[i] = torch.index_select(k_cache[i], dim=0, index=reserved_idx_of_batch_for_y)
|
||||
|
@ -39,6 +39,39 @@ def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
|
||||
return expaned_lengths >= lengths.unsqueeze(-1)
|
||||
|
||||
|
||||
def make_pad_mask_left(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
lengths:
|
||||
A 1-D tensor containing sentence lengths.
|
||||
max_len:
|
||||
The length of masks.
|
||||
Returns:
|
||||
Return a 2-D bool tensor, where masked positions
|
||||
are filled with `True` and non-masked positions are
|
||||
filled with `False`.
|
||||
|
||||
#>>> lengths = torch.tensor([1, 3, 2, 5])
|
||||
#>>> make_pad_mask(lengths)
|
||||
tensor(
|
||||
[
|
||||
[True, True, False],
|
||||
[True, False, False],
|
||||
[True, True, False],
|
||||
...
|
||||
]
|
||||
)
|
||||
"""
|
||||
assert lengths.ndim == 1, lengths.ndim
|
||||
max_len = max(max_len, lengths.max())
|
||||
n = lengths.size(0)
|
||||
seq_range = torch.arange(0, max_len, device=lengths.device)
|
||||
expaned_lengths = seq_range.unsqueeze(0).repeat(n, 1)
|
||||
expaned_lengths -= (max_len-lengths).unsqueeze(-1)
|
||||
|
||||
return expaned_lengths<0
|
||||
|
||||
|
||||
# https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py
|
||||
def top_k_top_p_filtering(
|
||||
logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1
|
||||
|
@ -145,7 +145,15 @@ class TTS_Config:
|
||||
|
||||
|
||||
self.device = self.configs.get("device", torch.device("cpu"))
|
||||
if "cuda" in str(self.device) and not torch.cuda.is_available():
|
||||
print(f"Warning: CUDA is not available, set device to CPU.")
|
||||
self.device = torch.device("cpu")
|
||||
|
||||
self.is_half = self.configs.get("is_half", False)
|
||||
# if str(self.device) == "cpu" and self.is_half:
|
||||
# print(f"Warning: Half precision is not supported on CPU, set is_half to False.")
|
||||
# self.is_half = False
|
||||
|
||||
self.version = version
|
||||
self.t2s_weights_path = self.configs.get("t2s_weights_path", None)
|
||||
self.vits_weights_path = self.configs.get("vits_weights_path", None)
|
||||
|
@ -691,7 +691,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
|
||||
wav_gen = bigvgan_model(cmf_res)
|
||||
audio=wav_gen[0][0]#.cpu().detach().numpy()
|
||||
max_audio=torch.abs(audio).max()#简单防止16bit爆音
|
||||
if max_audio>1:audio/=max_audio
|
||||
if max_audio>1:audio=audio/max_audio
|
||||
audio_opt.append(audio)
|
||||
audio_opt.append(zero_wav_torch)#zero_wav
|
||||
t4 = ttime()
|
||||
|
@ -1162,6 +1162,7 @@ class SynthesizerTrnV3(nn.Module):
|
||||
use_sdp=True,
|
||||
semantic_frame_rate=None,
|
||||
freeze_quantizer=None,
|
||||
version="v3",
|
||||
**kwargs):
|
||||
|
||||
super().__init__()
|
||||
@ -1182,6 +1183,7 @@ class SynthesizerTrnV3(nn.Module):
|
||||
self.segment_size = segment_size
|
||||
self.n_speakers = n_speakers
|
||||
self.gin_channels = gin_channels
|
||||
self.version = version
|
||||
|
||||
self.model_dim=512
|
||||
self.use_sdp = use_sdp
|
||||
|
@ -8,66 +8,7 @@ jieba.setLogLevel(logging.CRITICAL)
|
||||
# 更改fast_langdetect大模型位置
|
||||
from pathlib import Path
|
||||
import fast_langdetect
|
||||
fast_langdetect.ft_detect.infer.CACHE_DIRECTORY = Path(__file__).parent.parent.parent / "pretrained_models" / "fast_langdetect"
|
||||
|
||||
# 防止win下无法读取模型
|
||||
import os
|
||||
from typing import Optional
|
||||
def load_fasttext_model(
|
||||
model_path: Path,
|
||||
download_url: Optional[str] = None,
|
||||
proxy: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Load a FastText model, downloading it if necessary.
|
||||
:param model_path: Path to the FastText model file
|
||||
:param download_url: URL to download the model from
|
||||
:param proxy: Proxy URL for downloading the model
|
||||
:return: FastText model
|
||||
:raises DetectError: If model loading fails
|
||||
"""
|
||||
if all([
|
||||
fast_langdetect.ft_detect.infer.VERIFY_FASTTEXT_LARGE_MODEL,
|
||||
model_path.exists(),
|
||||
model_path.name == fast_langdetect.ft_detect.infer.FASTTEXT_LARGE_MODEL_NAME,
|
||||
]):
|
||||
if not fast_langdetect.ft_detect.infer.verify_md5(model_path, fast_langdetect.ft_detect.infer.VERIFY_FASTTEXT_LARGE_MODEL):
|
||||
fast_langdetect.ft_detect.infer.logger.warning(
|
||||
f"fast-langdetect: MD5 hash verification failed for {model_path}, "
|
||||
f"please check the integrity of the downloaded file from {fast_langdetect.ft_detect.infer.FASTTEXT_LARGE_MODEL_URL}. "
|
||||
"\n This may seriously reduce the prediction accuracy. "
|
||||
"If you want to ignore this, please set `fast_langdetect.ft_detect.infer.VERIFY_FASTTEXT_LARGE_MODEL = None` "
|
||||
)
|
||||
if not model_path.exists():
|
||||
if download_url:
|
||||
fast_langdetect.ft_detect.infer.download_model(download_url, model_path, proxy)
|
||||
if not model_path.exists():
|
||||
raise fast_langdetect.ft_detect.infer.DetectError(f"FastText model file not found at {model_path}")
|
||||
|
||||
try:
|
||||
# Load FastText model
|
||||
if (re.match(r'^[A-Za-z0-9_/\\:.]*$', str(model_path))):
|
||||
model = fast_langdetect.ft_detect.infer.fasttext.load_model(str(model_path))
|
||||
else:
|
||||
python_dir = os.getcwd()
|
||||
if (str(model_path)[:len(python_dir)].upper() == python_dir.upper()):
|
||||
model = fast_langdetect.ft_detect.infer.fasttext.load_model(os.path.relpath(model_path, python_dir))
|
||||
else:
|
||||
import tempfile
|
||||
import shutil
|
||||
with tempfile.NamedTemporaryFile(delete=False) as tmpfile:
|
||||
shutil.copyfile(model_path, tmpfile.name)
|
||||
|
||||
model = fast_langdetect.ft_detect.infer.fasttext.load_model(tmpfile.name)
|
||||
os.unlink(tmpfile.name)
|
||||
return model
|
||||
|
||||
except Exception as e:
|
||||
fast_langdetect.ft_detect.infer.logger.warning(f"fast-langdetect:Failed to load FastText model from {model_path}: {e}")
|
||||
raise fast_langdetect.ft_detect.infer.DetectError(f"Failed to load FastText model: {e}")
|
||||
|
||||
if os.name == 'nt':
|
||||
fast_langdetect.ft_detect.infer.load_fasttext_model = load_fasttext_model
|
||||
fast_langdetect.infer._default_detector = fast_langdetect.infer.LangDetector(fast_langdetect.infer.LangDetectConfig(cache_dir=Path(__file__).parent.parent.parent / "pretrained_models" / "fast_langdetect"))
|
||||
|
||||
|
||||
from split_lang import LangSplitter
|
||||
|
@ -17,6 +17,8 @@ pinyin_to_symbol_map = {
|
||||
for line in open(os.path.join(current_file_path, "opencpop-strict.txt")).readlines()
|
||||
}
|
||||
|
||||
import jieba_fast, logging
|
||||
jieba_fast.setLogLevel(logging.CRITICAL)
|
||||
import jieba_fast.posseg as psg
|
||||
|
||||
|
||||
|
@ -18,13 +18,15 @@ pinyin_to_symbol_map = {
|
||||
for line in open(os.path.join(current_file_path, "opencpop-strict.txt")).readlines()
|
||||
}
|
||||
|
||||
import jieba_fast, logging
|
||||
jieba_fast.setLogLevel(logging.CRITICAL)
|
||||
import jieba_fast.posseg as psg
|
||||
|
||||
# is_g2pw_str = os.environ.get("is_g2pw", "True")##默认开启
|
||||
# is_g2pw = False#True if is_g2pw_str.lower() == 'true' else False
|
||||
is_g2pw = True#True if is_g2pw_str.lower() == 'true' else False
|
||||
if is_g2pw:
|
||||
print("当前使用g2pw进行拼音推理")
|
||||
# print("当前使用g2pw进行拼音推理")
|
||||
from text.g2pw import G2PWPinyin, correct_pronunciation
|
||||
parent_directory = os.path.dirname(current_file_path)
|
||||
g2pw = G2PWPinyin(model_dir="GPT_SoVITS/text/G2PWModel",model_source=os.environ.get("bert_path","GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"),v_to_u=False, neutral_tone_with_five=True)
|
||||
|
@ -10,7 +10,7 @@ try:
|
||||
if os.name == 'nt':
|
||||
python_dir = os.getcwd()
|
||||
OPEN_JTALK_DICT_DIR = pyopenjtalk.OPEN_JTALK_DICT_DIR.decode("utf-8")
|
||||
if not (re.match(r'^[A-Za-z0-9_/\\:.]*$', OPEN_JTALK_DICT_DIR)):
|
||||
if not (re.match(r'^[A-Za-z0-9_/\\:.\-]*$', OPEN_JTALK_DICT_DIR)):
|
||||
if (OPEN_JTALK_DICT_DIR[:len(python_dir)].upper() == python_dir.upper()):
|
||||
OPEN_JTALK_DICT_DIR = os.path.join(os.path.relpath(OPEN_JTALK_DICT_DIR,python_dir))
|
||||
else:
|
||||
@ -25,7 +25,7 @@ try:
|
||||
OPEN_JTALK_DICT_DIR = os.path.join("TEMP", "ja", "open_jtalk_dic")
|
||||
pyopenjtalk.OPEN_JTALK_DICT_DIR = OPEN_JTALK_DICT_DIR.encode("utf-8")
|
||||
|
||||
if not (re.match(r'^[A-Za-z0-9_/\\:.]*$', current_file_path)):
|
||||
if not (re.match(r'^[A-Za-z0-9_/\\:.\-]*$', current_file_path)):
|
||||
if (current_file_path[:len(python_dir)].upper() == python_dir.upper()):
|
||||
current_file_path = os.path.join(os.path.relpath(current_file_path,python_dir))
|
||||
else:
|
||||
|
@ -19,13 +19,13 @@ if os.name == 'nt':
|
||||
print(f'you have to install eunjeon. install it...')
|
||||
else:
|
||||
installpath = spam_spec.submodule_search_locations[0]
|
||||
if not (re.match(r'^[A-Za-z0-9_/\\:.]*$', installpath)):
|
||||
if not (re.match(r'^[A-Za-z0-9_/\\:.\-]*$', installpath)):
|
||||
|
||||
import sys
|
||||
from eunjeon import Mecab as _Mecab
|
||||
class Mecab(_Mecab):
|
||||
def get_dicpath(installpath):
|
||||
if not (re.match(r'^[A-Za-z0-9_/\\:.]*$', installpath)):
|
||||
if not (re.match(r'^[A-Za-z0-9_/\\:.\-]*$', installpath)):
|
||||
import shutil
|
||||
python_dir = os.getcwd()
|
||||
if (installpath[:len(python_dir)].upper() == python_dir.upper()):
|
||||
|
273
api.py
273
api.py
@ -150,9 +150,9 @@ sys.path.append(now_dir)
|
||||
sys.path.append("%s/GPT_SoVITS" % (now_dir))
|
||||
|
||||
import signal
|
||||
import LangSegment
|
||||
from text.LangSegmenter import LangSegmenter
|
||||
from time import time as ttime
|
||||
import torch
|
||||
import torch, torchaudio
|
||||
import librosa
|
||||
import soundfile as sf
|
||||
from fastapi import FastAPI, Request, Query, HTTPException
|
||||
@ -162,7 +162,8 @@ from transformers import AutoModelForMaskedLM, AutoTokenizer
|
||||
import numpy as np
|
||||
from feature_extractor import cnhubert
|
||||
from io import BytesIO
|
||||
from module.models import SynthesizerTrn
|
||||
from module.models import SynthesizerTrn, SynthesizerTrnV3
|
||||
from peft import LoraConfig, PeftModel, get_peft_model
|
||||
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
|
||||
from text import cleaned_text_to_sequence
|
||||
from text.cleaner import clean_text
|
||||
@ -197,6 +198,61 @@ def is_full(*items): # 任意一项为空返回False
|
||||
return True
|
||||
|
||||
|
||||
def init_bigvgan():
|
||||
global bigvgan_model
|
||||
from BigVGAN import bigvgan
|
||||
bigvgan_model = bigvgan.BigVGAN.from_pretrained("%s/GPT_SoVITS/pretrained_models/models--nvidia--bigvgan_v2_24khz_100band_256x" % (now_dir,), use_cuda_kernel=False) # if True, RuntimeError: Ninja is required to load C++ extensions
|
||||
# remove weight norm in the model and set to eval mode
|
||||
bigvgan_model.remove_weight_norm()
|
||||
bigvgan_model = bigvgan_model.eval()
|
||||
if is_half == True:
|
||||
bigvgan_model = bigvgan_model.half().to(device)
|
||||
else:
|
||||
bigvgan_model = bigvgan_model.to(device)
|
||||
|
||||
|
||||
resample_transform_dict={}
|
||||
def resample(audio_tensor, sr0):
|
||||
global resample_transform_dict
|
||||
if sr0 not in resample_transform_dict:
|
||||
resample_transform_dict[sr0] = torchaudio.transforms.Resample(
|
||||
sr0, 24000
|
||||
).to(device)
|
||||
return resample_transform_dict[sr0](audio_tensor)
|
||||
|
||||
|
||||
from module.mel_processing import spectrogram_torch,mel_spectrogram_torch
|
||||
spec_min = -12
|
||||
spec_max = 2
|
||||
def norm_spec(x):
|
||||
return (x - spec_min) / (spec_max - spec_min) * 2 - 1
|
||||
def denorm_spec(x):
|
||||
return (x + 1) / 2 * (spec_max - spec_min) + spec_min
|
||||
mel_fn=lambda x: mel_spectrogram_torch(x, **{
|
||||
"n_fft": 1024,
|
||||
"win_size": 1024,
|
||||
"hop_size": 256,
|
||||
"num_mels": 100,
|
||||
"sampling_rate": 24000,
|
||||
"fmin": 0,
|
||||
"fmax": None,
|
||||
"center": False
|
||||
})
|
||||
|
||||
|
||||
sr_model=None
|
||||
def audio_sr(audio,sr):
|
||||
global sr_model
|
||||
if sr_model==None:
|
||||
from tools.audio_sr import AP_BWE
|
||||
try:
|
||||
sr_model=AP_BWE(device,DictToAttrRecursive)
|
||||
except FileNotFoundError:
|
||||
logger.info("你没有下载超分模型的参数,因此不进行超分。如想超分请先参照教程把文件下载")
|
||||
return audio.cpu().detach().numpy(),sr
|
||||
return sr_model(audio,sr)
|
||||
|
||||
|
||||
class Speaker:
|
||||
def __init__(self, name, gpt, sovits, phones = None, bert = None, prompt = None):
|
||||
self.name = name
|
||||
@ -214,31 +270,72 @@ class Sovits:
|
||||
self.vq_model = vq_model
|
||||
self.hps = hps
|
||||
|
||||
from process_ckpt import get_sovits_version_from_path_fast,load_sovits_new
|
||||
def get_sovits_weights(sovits_path):
|
||||
dict_s2 = torch.load(sovits_path, map_location="cpu")
|
||||
path_sovits_v3="GPT_SoVITS/pretrained_models/s2Gv3.pth"
|
||||
is_exist_s2gv3=os.path.exists(path_sovits_v3)
|
||||
|
||||
version, model_version, if_lora_v3=get_sovits_version_from_path_fast(sovits_path)
|
||||
if if_lora_v3==True and is_exist_s2gv3==False:
|
||||
logger.info("SoVITS V3 底模缺失,无法加载相应 LoRA 权重")
|
||||
|
||||
dict_s2 = load_sovits_new(sovits_path)
|
||||
hps = dict_s2["config"]
|
||||
hps = DictToAttrRecursive(hps)
|
||||
hps.model.semantic_frame_rate = "25hz"
|
||||
if dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322:
|
||||
if 'enc_p.text_embedding.weight' not in dict_s2['weight']:
|
||||
hps.model.version = "v2"#v3model,v2sybomls
|
||||
elif dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322:
|
||||
hps.model.version = "v1"
|
||||
else:
|
||||
hps.model.version = "v2"
|
||||
logger.info(f"模型版本: {hps.model.version}")
|
||||
|
||||
if model_version == "v3":
|
||||
hps.model.version = "v3"
|
||||
|
||||
model_params_dict = vars(hps.model)
|
||||
vq_model = SynthesizerTrn(
|
||||
hps.data.filter_length // 2 + 1,
|
||||
hps.train.segment_size // hps.data.hop_length,
|
||||
n_speakers=hps.data.n_speakers,
|
||||
**model_params_dict
|
||||
)
|
||||
if model_version!="v3":
|
||||
vq_model = SynthesizerTrn(
|
||||
hps.data.filter_length // 2 + 1,
|
||||
hps.train.segment_size // hps.data.hop_length,
|
||||
n_speakers=hps.data.n_speakers,
|
||||
**model_params_dict
|
||||
)
|
||||
else:
|
||||
vq_model = SynthesizerTrnV3(
|
||||
hps.data.filter_length // 2 + 1,
|
||||
hps.train.segment_size // hps.data.hop_length,
|
||||
n_speakers=hps.data.n_speakers,
|
||||
**model_params_dict
|
||||
)
|
||||
init_bigvgan()
|
||||
model_version=hps.model.version
|
||||
logger.info(f"模型版本: {model_version}")
|
||||
if ("pretrained" not in sovits_path):
|
||||
del vq_model.enc_q
|
||||
try:
|
||||
del vq_model.enc_q
|
||||
except:pass
|
||||
if is_half == True:
|
||||
vq_model = vq_model.half().to(device)
|
||||
else:
|
||||
vq_model = vq_model.to(device)
|
||||
vq_model.eval()
|
||||
vq_model.load_state_dict(dict_s2["weight"], strict=False)
|
||||
if if_lora_v3 == False:
|
||||
vq_model.load_state_dict(dict_s2["weight"], strict=False)
|
||||
else:
|
||||
vq_model.load_state_dict(load_sovits_new(path_sovits_v3)["weight"], strict=False)
|
||||
lora_rank=dict_s2["lora_rank"]
|
||||
lora_config = LoraConfig(
|
||||
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
|
||||
r=lora_rank,
|
||||
lora_alpha=lora_rank,
|
||||
init_lora_weights=True,
|
||||
)
|
||||
vq_model.cfm = get_peft_model(vq_model.cfm, lora_config)
|
||||
vq_model.load_state_dict(dict_s2["weight"], strict=False)
|
||||
vq_model.cfm = vq_model.cfm.merge_and_unload()
|
||||
# torch.save(vq_model.state_dict(),"merge_win.pth")
|
||||
vq_model.eval()
|
||||
|
||||
sovits = Sovits(vq_model, hps)
|
||||
return sovits
|
||||
@ -260,8 +357,8 @@ def get_gpt_weights(gpt_path):
|
||||
t2s_model = t2s_model.half()
|
||||
t2s_model = t2s_model.to(device)
|
||||
t2s_model.eval()
|
||||
total = sum([param.nelement() for param in t2s_model.parameters()])
|
||||
logger.info("Number of parameter: %.2fM" % (total / 1e6))
|
||||
# total = sum([param.nelement() for param in t2s_model.parameters()])
|
||||
# logger.info("Number of parameter: %.2fM" % (total / 1e6))
|
||||
|
||||
gpt = Gpt(max_sec, t2s_model)
|
||||
return gpt
|
||||
@ -295,6 +392,7 @@ def get_bert_feature(text, word2ph):
|
||||
|
||||
|
||||
def clean_text_inf(text, language, version):
|
||||
language = language.replace("all_","")
|
||||
phones, word2ph, norm_text = clean_text(text, language, version)
|
||||
phones = cleaned_text_to_sequence(phones, version)
|
||||
return phones, word2ph, norm_text
|
||||
@ -315,16 +413,10 @@ def get_bert_inf(phones, word2ph, norm_text, language):
|
||||
from text import chinese
|
||||
def get_phones_and_bert(text,language,version,final=False):
|
||||
if language in {"en", "all_zh", "all_ja", "all_ko", "all_yue"}:
|
||||
language = language.replace("all_","")
|
||||
if language == "en":
|
||||
LangSegment.setfilters(["en"])
|
||||
formattext = " ".join(tmp["text"] for tmp in LangSegment.getTexts(text))
|
||||
else:
|
||||
# 因无法区别中日韩文汉字,以用户输入为准
|
||||
formattext = text
|
||||
formattext = text
|
||||
while " " in formattext:
|
||||
formattext = formattext.replace(" ", " ")
|
||||
if language == "zh":
|
||||
if language == "all_zh":
|
||||
if re.search(r'[A-Za-z]', formattext):
|
||||
formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
|
||||
formattext = chinese.mix_text_normalize(formattext)
|
||||
@ -332,7 +424,7 @@ def get_phones_and_bert(text,language,version,final=False):
|
||||
else:
|
||||
phones, word2ph, norm_text = clean_text_inf(formattext, language, version)
|
||||
bert = get_bert_feature(norm_text, word2ph).to(device)
|
||||
elif language == "yue" and re.search(r'[A-Za-z]', formattext):
|
||||
elif language == "all_yue" and re.search(r'[A-Za-z]', formattext):
|
||||
formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
|
||||
formattext = chinese.mix_text_normalize(formattext)
|
||||
return get_phones_and_bert(formattext,"yue",version)
|
||||
@ -345,19 +437,18 @@ def get_phones_and_bert(text,language,version,final=False):
|
||||
elif language in {"zh", "ja", "ko", "yue", "auto", "auto_yue"}:
|
||||
textlist=[]
|
||||
langlist=[]
|
||||
LangSegment.setfilters(["zh","ja","en","ko"])
|
||||
if language == "auto":
|
||||
for tmp in LangSegment.getTexts(text):
|
||||
for tmp in LangSegmenter.getTexts(text):
|
||||
langlist.append(tmp["lang"])
|
||||
textlist.append(tmp["text"])
|
||||
elif language == "auto_yue":
|
||||
for tmp in LangSegment.getTexts(text):
|
||||
for tmp in LangSegmenter.getTexts(text):
|
||||
if tmp["lang"] == "zh":
|
||||
tmp["lang"] = "yue"
|
||||
langlist.append(tmp["lang"])
|
||||
textlist.append(tmp["text"])
|
||||
else:
|
||||
for tmp in LangSegment.getTexts(text):
|
||||
for tmp in LangSegmenter.getTexts(text):
|
||||
if tmp["lang"] == "en":
|
||||
langlist.append(tmp["lang"])
|
||||
else:
|
||||
@ -556,10 +647,11 @@ def only_punc(text):
|
||||
|
||||
|
||||
splits = {",", "。", "?", "!", ",", ".", "?", "!", "~", ":", ":", "—", "…", }
|
||||
def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, top_k= 15, top_p = 0.6, temperature = 0.6, speed = 1, inp_refs = None, spk = "default"):
|
||||
def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, top_k= 15, top_p = 0.6, temperature = 0.6, speed = 1, inp_refs = None, sample_steps = 32, if_sr = False, spk = "default"):
|
||||
infer_sovits = speaker_list[spk].sovits
|
||||
vq_model = infer_sovits.vq_model
|
||||
hps = infer_sovits.hps
|
||||
version = vq_model.version
|
||||
|
||||
infer_gpt = speaker_list[spk].gpt
|
||||
t2s_model = infer_gpt.t2s_model
|
||||
@ -587,20 +679,22 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
|
||||
prompt_semantic = codes[0, 0]
|
||||
prompt = prompt_semantic.unsqueeze(0).to(device)
|
||||
|
||||
refers=[]
|
||||
if(inp_refs):
|
||||
for path in inp_refs:
|
||||
try:
|
||||
refer = get_spepc(hps, path).to(dtype).to(device)
|
||||
refers.append(refer)
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
if(len(refers)==0):
|
||||
refers = [get_spepc(hps, ref_wav_path).to(dtype).to(device)]
|
||||
if version != "v3":
|
||||
refers=[]
|
||||
if(inp_refs):
|
||||
for path in inp_refs:
|
||||
try:
|
||||
refer = get_spepc(hps, path).to(dtype).to(device)
|
||||
refers.append(refer)
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
if(len(refers)==0):
|
||||
refers = [get_spepc(hps, ref_wav_path).to(dtype).to(device)]
|
||||
else:
|
||||
refer = get_spepc(hps, ref_wav_path).to(device).to(dtype)
|
||||
|
||||
t1 = ttime()
|
||||
version = vq_model.version
|
||||
os.environ['version'] = version
|
||||
# os.environ['version'] = version
|
||||
prompt_language = dict_language[prompt_language.lower()]
|
||||
text_language = dict_language[text_language.lower()]
|
||||
phones1, bert1, norm_text1 = get_phones_and_bert(prompt_text, prompt_language, version)
|
||||
@ -634,20 +728,82 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
|
||||
early_stop_num=hz * max_sec)
|
||||
pred_semantic = pred_semantic[:, -idx:].unsqueeze(0)
|
||||
t3 = ttime()
|
||||
audio = \
|
||||
vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0),
|
||||
refers,speed=speed).detach().cpu().numpy()[
|
||||
0, 0] ###试试重建不带上prompt部分
|
||||
|
||||
if version != "v3":
|
||||
audio = \
|
||||
vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0),
|
||||
refers,speed=speed).detach().cpu().numpy()[
|
||||
0, 0] ###试试重建不带上prompt部分
|
||||
else:
|
||||
phoneme_ids0=torch.LongTensor(phones1).to(device).unsqueeze(0)
|
||||
phoneme_ids1=torch.LongTensor(phones2).to(device).unsqueeze(0)
|
||||
# print(11111111, phoneme_ids0, phoneme_ids1)
|
||||
fea_ref,ge = vq_model.decode_encp(prompt.unsqueeze(0), phoneme_ids0, refer)
|
||||
ref_audio, sr = torchaudio.load(ref_wav_path)
|
||||
ref_audio=ref_audio.to(device).float()
|
||||
if (ref_audio.shape[0] == 2):
|
||||
ref_audio = ref_audio.mean(0).unsqueeze(0)
|
||||
if sr!=24000:
|
||||
ref_audio=resample(ref_audio,sr)
|
||||
# print("ref_audio",ref_audio.abs().mean())
|
||||
mel2 = mel_fn(ref_audio)
|
||||
mel2 = norm_spec(mel2)
|
||||
T_min = min(mel2.shape[2], fea_ref.shape[2])
|
||||
mel2 = mel2[:, :, :T_min]
|
||||
fea_ref = fea_ref[:, :, :T_min]
|
||||
if (T_min > 468):
|
||||
mel2 = mel2[:, :, -468:]
|
||||
fea_ref = fea_ref[:, :, -468:]
|
||||
T_min = 468
|
||||
chunk_len = 934 - T_min
|
||||
# print("fea_ref",fea_ref,fea_ref.shape)
|
||||
# print("mel2",mel2)
|
||||
mel2=mel2.to(dtype)
|
||||
fea_todo, ge = vq_model.decode_encp(pred_semantic, phoneme_ids1, refer, ge,speed)
|
||||
# print("fea_todo",fea_todo)
|
||||
# print("ge",ge.abs().mean())
|
||||
cfm_resss = []
|
||||
idx = 0
|
||||
while (1):
|
||||
fea_todo_chunk = fea_todo[:, :, idx:idx + chunk_len]
|
||||
if (fea_todo_chunk.shape[-1] == 0): break
|
||||
idx += chunk_len
|
||||
fea = torch.cat([fea_ref, fea_todo_chunk], 2).transpose(2, 1)
|
||||
# set_seed(123)
|
||||
cfm_res = vq_model.cfm.inference(fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps, inference_cfg_rate=0)
|
||||
cfm_res = cfm_res[:, :, mel2.shape[2]:]
|
||||
mel2 = cfm_res[:, :, -T_min:]
|
||||
# print("fea", fea)
|
||||
# print("mel2in", mel2)
|
||||
fea_ref = fea_todo_chunk[:, :, -T_min:]
|
||||
cfm_resss.append(cfm_res)
|
||||
cmf_res = torch.cat(cfm_resss, 2)
|
||||
cmf_res = denorm_spec(cmf_res)
|
||||
if bigvgan_model==None:init_bigvgan()
|
||||
with torch.inference_mode():
|
||||
wav_gen = bigvgan_model(cmf_res)
|
||||
audio=wav_gen[0][0].cpu().detach().numpy()
|
||||
|
||||
max_audio=np.abs(audio).max()
|
||||
if max_audio>1:
|
||||
audio/=max_audio
|
||||
audio_opt.append(audio)
|
||||
audio_opt.append(zero_wav)
|
||||
audio_opt = np.concatenate(audio_opt, 0)
|
||||
t4 = ttime()
|
||||
|
||||
sr = hps.data.sampling_rate if version != "v3" else 24000
|
||||
if if_sr and sr == 24000:
|
||||
audio_opt = torch.from_numpy(audio_opt).float().to(device)
|
||||
audio_opt,sr=audio_sr(audio_opt.unsqueeze(0),sr)
|
||||
max_audio=np.abs(audio_opt).max()
|
||||
if max_audio > 1: audio_opt /= max_audio
|
||||
sr = 48000
|
||||
|
||||
if is_int32:
|
||||
audio_bytes = pack_audio(audio_bytes,(np.concatenate(audio_opt, 0) * 2147483647).astype(np.int32),hps.data.sampling_rate)
|
||||
audio_bytes = pack_audio(audio_bytes,(audio_opt * 2147483647).astype(np.int32),sr)
|
||||
else:
|
||||
audio_bytes = pack_audio(audio_bytes,(np.concatenate(audio_opt, 0) * 32768).astype(np.int16),hps.data.sampling_rate)
|
||||
audio_bytes = pack_audio(audio_bytes,(audio_opt * 32768).astype(np.int16),sr)
|
||||
# logger.info("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
|
||||
if stream_mode == "normal":
|
||||
audio_bytes, audio_chunk = read_clean_buffer(audio_bytes)
|
||||
@ -655,7 +811,9 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
|
||||
|
||||
if not stream_mode == "normal":
|
||||
if media_type == "wav":
|
||||
audio_bytes = pack_wav(audio_bytes,hps.data.sampling_rate)
|
||||
sr = 48000 if if_sr else 24000
|
||||
sr = hps.data.sampling_rate if version != "v3" else sr
|
||||
audio_bytes = pack_wav(audio_bytes,sr)
|
||||
yield audio_bytes.getvalue()
|
||||
|
||||
|
||||
@ -688,7 +846,7 @@ def handle_change(path, text, language):
|
||||
return JSONResponse({"code": 0, "message": "Success"}, status_code=200)
|
||||
|
||||
|
||||
def handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cut_punc, top_k, top_p, temperature, speed, inp_refs):
|
||||
def handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cut_punc, top_k, top_p, temperature, speed, inp_refs, sample_steps, if_sr):
|
||||
if (
|
||||
refer_wav_path == "" or refer_wav_path is None
|
||||
or prompt_text == "" or prompt_text is None
|
||||
@ -702,12 +860,15 @@ def handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cu
|
||||
if not default_refer.is_ready():
|
||||
return JSONResponse({"code": 400, "message": "未指定参考音频且接口无预设"}, status_code=400)
|
||||
|
||||
if not sample_steps in [4,8,16,32]:
|
||||
sample_steps = 32
|
||||
|
||||
if cut_punc == None:
|
||||
text = cut_text(text,default_cut_punc)
|
||||
else:
|
||||
text = cut_text(text,cut_punc)
|
||||
|
||||
return StreamingResponse(get_tts_wav(refer_wav_path, prompt_text, prompt_language, text, text_language, top_k, top_p, temperature, speed, inp_refs), media_type="audio/"+media_type)
|
||||
return StreamingResponse(get_tts_wav(refer_wav_path, prompt_text, prompt_language, text, text_language, top_k, top_p, temperature, speed, inp_refs, sample_steps, if_sr), media_type="audio/"+media_type)
|
||||
|
||||
|
||||
|
||||
@ -915,7 +1076,9 @@ async def tts_endpoint(request: Request):
|
||||
json_post_raw.get("top_p", 1.0),
|
||||
json_post_raw.get("temperature", 1.0),
|
||||
json_post_raw.get("speed", 1.0),
|
||||
json_post_raw.get("inp_refs", [])
|
||||
json_post_raw.get("inp_refs", []),
|
||||
json_post_raw.get("sample_steps", 32),
|
||||
json_post_raw.get("if_sr", False)
|
||||
)
|
||||
|
||||
|
||||
@ -931,9 +1094,11 @@ async def tts_endpoint(
|
||||
top_p: float = 1.0,
|
||||
temperature: float = 1.0,
|
||||
speed: float = 1.0,
|
||||
inp_refs: list = Query(default=[])
|
||||
inp_refs: list = Query(default=[]),
|
||||
sample_steps: int = 32,
|
||||
if_sr: bool = False
|
||||
):
|
||||
return handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cut_punc, top_k, top_p, temperature, speed, inp_refs)
|
||||
return handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cut_punc, top_k, top_p, temperature, speed, inp_refs, sample_steps, if_sr)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
14
install.sh
14
install.sh
@ -1,4 +1,7 @@
|
||||
#!/bin/bash
|
||||
|
||||
# 安装构建工具
|
||||
# Install build tools
|
||||
echo "Installing GCC..."
|
||||
conda install -c conda-forge gcc=14
|
||||
|
||||
@ -8,6 +11,12 @@ conda install -c conda-forge gxx
|
||||
echo "Installing ffmpeg and cmake..."
|
||||
conda install ffmpeg cmake
|
||||
|
||||
# 设置编译环境
|
||||
# Set up build environment
|
||||
export CMAKE_MAKE_PROGRAM="$CONDA_PREFIX/bin/cmake"
|
||||
export CC="$CONDA_PREFIX/bin/gcc"
|
||||
export CXX="$CONDA_PREFIX/bin/g++"
|
||||
|
||||
echo "Checking for CUDA installation..."
|
||||
if command -v nvidia-smi &> /dev/null; then
|
||||
USE_CUDA=true
|
||||
@ -49,6 +58,10 @@ fi
|
||||
|
||||
|
||||
echo "Installing Python dependencies from requirements.txt..."
|
||||
|
||||
# 刷新环境
|
||||
# Refresh environment
|
||||
hash -r
|
||||
pip install -r requirements.txt
|
||||
|
||||
if [ "$USE_ROCM" = true ] && [ "$IS_WSL" = true ] ; then
|
||||
@ -60,3 +73,4 @@ if [ "$USE_ROCM" = true ] && [ "$IS_WSL" = true ] ; then
|
||||
fi
|
||||
|
||||
echo "Installation completed successfully!"
|
||||
|
||||
|
@ -25,7 +25,7 @@ psutil
|
||||
jieba_fast
|
||||
jieba
|
||||
split-lang
|
||||
fast_langdetect
|
||||
fast_langdetect>=0.3.0
|
||||
Faster_Whisper
|
||||
wordsegment
|
||||
rotary_embedding_torch
|
||||
|
Loading…
x
Reference in New Issue
Block a user