mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-06-28 08:41:29 +08:00
Compare commits
6 Commits
9d5b89e7ee
...
6bb625cfe4
Author | SHA1 | Date | |
---|---|---|---|
|
6bb625cfe4 | ||
|
ee4a466f79 | ||
|
b65ea9181e | ||
|
c0ce55a132 | ||
|
13573a1b06 | ||
|
0cc6b8aaef |
@ -12,33 +12,33 @@ import torch
|
|||||||
|
|
||||||
|
|
||||||
def multi_head_attention_forward_patched(
|
def multi_head_attention_forward_patched(
|
||||||
query: Tensor,
|
query,
|
||||||
key: Tensor,
|
key,
|
||||||
value: Tensor,
|
value,
|
||||||
embed_dim_to_check: int,
|
embed_dim_to_check,
|
||||||
num_heads: int,
|
num_heads,
|
||||||
in_proj_weight: Optional[Tensor],
|
in_proj_weight,
|
||||||
in_proj_bias: Optional[Tensor],
|
in_proj_bias,
|
||||||
bias_k: Optional[Tensor],
|
bias_k,
|
||||||
bias_v: Optional[Tensor],
|
bias_v,
|
||||||
add_zero_attn: bool,
|
add_zero_attn,
|
||||||
dropout_p: float,
|
dropout_p: float,
|
||||||
out_proj_weight: Tensor,
|
out_proj_weight,
|
||||||
out_proj_bias: Optional[Tensor],
|
out_proj_bias,
|
||||||
training: bool = True,
|
training = True,
|
||||||
key_padding_mask: Optional[Tensor] = None,
|
key_padding_mask = None,
|
||||||
need_weights: bool = True,
|
need_weights = True,
|
||||||
attn_mask: Optional[Tensor] = None,
|
attn_mask = None,
|
||||||
use_separate_proj_weight: bool = False,
|
use_separate_proj_weight = False,
|
||||||
q_proj_weight: Optional[Tensor] = None,
|
q_proj_weight = None,
|
||||||
k_proj_weight: Optional[Tensor] = None,
|
k_proj_weight = None,
|
||||||
v_proj_weight: Optional[Tensor] = None,
|
v_proj_weight = None,
|
||||||
static_k: Optional[Tensor] = None,
|
static_k = None,
|
||||||
static_v: Optional[Tensor] = None,
|
static_v = None,
|
||||||
average_attn_weights: bool = True,
|
average_attn_weights = True,
|
||||||
is_causal: bool = False,
|
is_causal = False,
|
||||||
cache=None,
|
cache=None,
|
||||||
) -> Tuple[Tensor, Optional[Tensor]]:
|
):
|
||||||
r"""
|
r"""
|
||||||
Args:
|
Args:
|
||||||
query, key, value: map a query and a set of key-value pairs to an output.
|
query, key, value: map a query and a set of key-value pairs to an output.
|
||||||
|
@ -1,6 +1,9 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
|
import LangSegment
|
||||||
|
from text import chinese
|
||||||
|
|
||||||
inp_text = os.environ.get("inp_text")
|
inp_text = os.environ.get("inp_text")
|
||||||
inp_wav_dir = os.environ.get("inp_wav_dir")
|
inp_wav_dir = os.environ.get("inp_wav_dir")
|
||||||
@ -83,24 +86,104 @@ if os.path.exists(txt_path) == False:
|
|||||||
|
|
||||||
return phone_level_feature.T
|
return phone_level_feature.T
|
||||||
|
|
||||||
|
def get_bert_inf(phones:list, word2ph:list, norm_text:str, language:str):
    """Build the BERT feature tensor for a single text segment.

    Args:
        phones: phoneme sequence of the segment; only its length is used here.
        word2ph: forwarded to ``get_bert_feature`` for Chinese segments.
        norm_text: normalized text, forwarded to ``get_bert_feature``.
        language: language code of the segment; any ``"all_"`` substring is
            removed before the comparison.

    Returns:
        For ``"zh"``, the result of ``get_bert_feature(norm_text, word2ph)``
        moved to ``device``. For every other language, a float32 zero tensor
        of shape ``(1024, len(phones))`` moved to ``device``.
    """
    lang_code = language.replace("all_", "")

    if lang_code != "zh":
        # Non-Chinese segments get an all-zero feature with one column per phoneme.
        placeholder = torch.zeros((1024, len(phones)), dtype=torch.float32)
        return placeholder.to(device)

    return get_bert_feature(norm_text, word2ph).to(device)
|
||||||
|
|
||||||
|
def get_phones_and_bert(text:str, language:str, version:str, final:bool=False):
    """Convert text into phonemes, a BERT feature tensor and normalized text.

    Args:
        text: raw input text.
        language: one of
            * ``"en"``, ``"all_zh"``, ``"all_ja"``, ``"all_ko"``, ``"all_yue"``:
              the whole text is treated as a single language;
            * ``"zh"``, ``"ja"``, ``"ko"``, ``"yue"``, ``"auto"``, ``"auto_yue"``:
              the text is split into per-language segments with ``LangSegment``.
        version: forwarded unchanged to ``clean_text``.
        final: accepted for interface compatibility; not used by this function.

    Returns:
        A tuple ``(phones, bert, norm_text)`` where ``phones`` is the phoneme
        sequence, ``bert`` is the feature tensor on ``device`` and
        ``norm_text`` is the normalized text.

    Raises:
        ValueError: if ``language`` is not one of the codes listed above.
    """
    if language in {"en", "all_zh", "all_ja", "all_ko", "all_yue"}:
        language = language.replace("all_", "")
        if language == "en":
            LangSegment.setfilters(["en"])
            formattext = " ".join(tmp["text"] for tmp in LangSegment.getTexts(text))
        else:
            # Chinese, Japanese and Korean ideographs cannot be told apart,
            # so the language chosen by the user is trusted.
            formattext = text

        # Collapse runs of spaces into a single space. The needle must be TWO
        # spaces: with a single-space needle the loop never terminates on any
        # text that contains a space.
        while "  " in formattext:
            formattext = formattext.replace("  ", " ")

        if language in {"zh", "yue"} and re.search(r'[A-Za-z]', formattext):
            # Mixed Chinese/Cantonese + Latin text: upper-case the ASCII
            # letters, normalize, then re-enter with the bare language code so
            # the text goes through the segmenting branch below.
            formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
            formattext = chinese.mix_text_normalize(formattext)
            return get_phones_and_bert(formattext, language, version)

        phones, word2ph, norm_text = clean_text(formattext, language, version)
        if language == "zh":
            bert = get_bert_feature(norm_text, word2ph).to(device)
        else:
            # Only Chinese gets a real BERT feature here; other languages use
            # zeros with one column per phoneme.
            bert = torch.zeros(
                (1024, len(phones)),
                dtype=torch.float32,
            ).to(device)
    elif language in {"zh", "ja", "ko", "yue", "auto", "auto_yue"}:
        LangSegment.setfilters(["zh", "ja", "en", "ko"])

        # First pass: decide the language of every detected segment.
        segments = []
        for tmp in LangSegment.getTexts(text):
            detected = tmp["lang"]
            if language == "auto":
                seg_lang = detected
            elif language == "auto_yue":
                # Detected Chinese is handled as Cantonese in this mode.
                seg_lang = "yue" if detected == "zh" else detected
            elif detected == "en":
                seg_lang = detected
            else:
                # Chinese, Japanese and Korean ideographs cannot be told
                # apart, so the language chosen by the user is trusted.
                seg_lang = language
            segments.append((seg_lang, tmp["text"]))

        # Second pass: convert each segment and concatenate the results.
        phones = []
        bert_list = []
        norm_text_list = []
        for seg_lang, seg_text in segments:
            seg_phones, word2ph, seg_norm_text = clean_text(seg_text, seg_lang, version)
            bert_list.append(get_bert_inf(seg_phones, word2ph, seg_norm_text, seg_lang))
            phones.extend(seg_phones)
            norm_text_list.append(seg_norm_text)
        bert = torch.cat(bert_list, dim=1)
        norm_text = ''.join(norm_text_list)
    else:
        # Without this guard the function would fail at the return statement
        # with an unhelpful UnboundLocalError.
        raise ValueError(f"unsupported language: {language!r}")

    return phones, bert, norm_text
|
||||||
|
|
||||||
def process(data, res):
|
def process(data, res):
|
||||||
for name, text, lan in data:
|
for name, text, lan in data:
|
||||||
try:
|
try:
|
||||||
name=clean_path(name)
|
name=clean_path(name)
|
||||||
name = os.path.basename(name)
|
name = os.path.basename(name)
|
||||||
print(name)
|
print(name)
|
||||||
phones, word2ph, norm_text = clean_text(
|
phones, bert_feature, norm_text = get_phones_and_bert(
|
||||||
text.replace("%", "-").replace("¥", ","), lan, version
|
text.replace("%", "-").replace("¥", ","), lan, 'v2'
|
||||||
)
|
)
|
||||||
path_bert = "%s/%s.pt" % (bert_dir, name)
|
path_bert = "%s/%s.pt" % (bert_dir, name)
|
||||||
if os.path.exists(path_bert) == False and lan == "zh":
|
if os.path.exists(path_bert) == False and lan == "zh":
|
||||||
bert_feature = get_bert_feature(norm_text, word2ph)
|
|
||||||
assert bert_feature.shape[-1] == len(phones)
|
assert bert_feature.shape[-1] == len(phones)
|
||||||
# torch.save(bert_feature, path_bert)
|
# torch.save(bert_feature, path_bert)
|
||||||
my_save(bert_feature, path_bert)
|
my_save(bert_feature, path_bert)
|
||||||
phones = " ".join(phones)
|
phones = " ".join(phones)
|
||||||
# res.append([name,phones])
|
# res.append([name,phones])
|
||||||
res.append([name, phones, word2ph, norm_text])
|
res.append([name, phones, None, norm_text])
|
||||||
except:
|
except:
|
||||||
print(name, text, traceback.format_exc())
|
print(name, text, traceback.format_exc())
|
||||||
|
|
||||||
|
@ -81,7 +81,7 @@ if os.path.exists(semantic_path) == False:
|
|||||||
# utils.load_checkpoint(pretrained_s2G, vq_model, None, True)
|
# utils.load_checkpoint(pretrained_s2G, vq_model, None, True)
|
||||||
print(
|
print(
|
||||||
vq_model.load_state_dict(
|
vq_model.load_state_dict(
|
||||||
torch.load(pretrained_s2G, map_location="cpu")["weight"], strict=False
|
torch.load(pretrained_s2G, map_location="cpu", weights_only=False)["weight"], strict=False
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
59
install.sh
59
install.sh
@ -2,8 +2,13 @@
|
|||||||
|
|
||||||
# 安装构建工具
|
# 安装构建工具
|
||||||
# Install build tools
|
# Install build tools
|
||||||
|
echo "Installing GCC..."
|
||||||
conda install -c conda-forge gcc=14
|
conda install -c conda-forge gcc=14
|
||||||
|
|
||||||
|
echo "Installing G++..."
|
||||||
conda install -c conda-forge gxx
|
conda install -c conda-forge gxx
|
||||||
|
|
||||||
|
echo "Installing ffmpeg and cmake..."
|
||||||
conda install ffmpeg cmake
|
conda install ffmpeg cmake
|
||||||
|
|
||||||
# 设置编译环境
|
# 设置编译环境
|
||||||
@ -12,10 +17,60 @@ export CMAKE_MAKE_PROGRAM="$CONDA_PREFIX/bin/cmake"
|
|||||||
export CC="$CONDA_PREFIX/bin/gcc"
|
export CC="$CONDA_PREFIX/bin/gcc"
|
||||||
export CXX="$CONDA_PREFIX/bin/g++"
|
export CXX="$CONDA_PREFIX/bin/g++"
|
||||||
|
|
||||||
conda install pytorch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 pytorch-cuda=11.8 -c pytorch -c nvidia
|
echo "Checking for CUDA installation..."
|
||||||
|
if command -v nvidia-smi &> /dev/null; then
|
||||||
|
USE_CUDA=true
|
||||||
|
echo "CUDA found."
|
||||||
|
else
|
||||||
|
echo "CUDA not found."
|
||||||
|
USE_CUDA=false
|
||||||
|
fi
|
||||||
|
|
||||||
|
|
||||||
|
if [ "$USE_CUDA" = false ]; then
|
||||||
|
echo "Checking for ROCm installation..."
|
||||||
|
if [ -d "/opt/rocm" ]; then
|
||||||
|
USE_ROCM=true
|
||||||
|
echo "ROCm found."
|
||||||
|
if grep -qi "microsoft" /proc/version; then
|
||||||
|
echo "You are running WSL."
|
||||||
|
IS_WSL=true
|
||||||
|
else
|
||||||
|
echo "You are NOT running WSL."
|
||||||
|
IS_WSL=false
|
||||||
|
fi
|
||||||
|
else
|
||||||
|
echo "ROCm not found."
|
||||||
|
USE_ROCM=false
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ "$USE_CUDA" = true ]; then
|
||||||
|
echo "Installing PyTorch with CUDA support..."
|
||||||
|
conda install pytorch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 pytorch-cuda=11.8 -c pytorch -c nvidia
|
||||||
|
elif [ "$USE_ROCM" = true ] ; then
|
||||||
|
echo "Installing PyTorch with ROCm support..."
|
||||||
|
pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/rocm6.2
|
||||||
|
else
|
||||||
|
echo "Installing PyTorch for CPU..."
|
||||||
|
conda install pytorch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 cpuonly -c pytorch
|
||||||
|
fi
|
||||||
|
|
||||||
|
|
||||||
|
echo "Installing Python dependencies from requirements.txt..."
|
||||||
|
|
||||||
# 刷新环境
|
# 刷新环境
|
||||||
# Refresh environment
|
# Refresh environment
|
||||||
hash -r
|
hash -r
|
||||||
|
pip install -r requirements.txt
|
||||||
|
|
||||||
|
if [ "$USE_ROCM" = true ] && [ "$IS_WSL" = true ] ; then
|
||||||
|
echo "Update to WSL compatible runtime lib..."
|
||||||
|
location=`pip show torch | grep Location | awk -F ": " '{print $2}'`
|
||||||
|
cd ${location}/torch/lib/
|
||||||
|
rm libhsa-runtime64.so*
|
||||||
|
cp /opt/rocm/lib/libhsa-runtime64.so.1.2 libhsa-runtime64.so
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "Installation completed successfully!"
|
||||||
|
|
||||||
pip install -r requirements.txt
|
|
@ -32,7 +32,7 @@ def clean_path(path_str:str):
|
|||||||
if path_str.endswith(('\\','/')):
|
if path_str.endswith(('\\','/')):
|
||||||
return clean_path(path_str[0:-1])
|
return clean_path(path_str[0:-1])
|
||||||
path_str = path_str.replace('/', os.sep).replace('\\', os.sep)
|
path_str = path_str.replace('/', os.sep).replace('\\', os.sep)
|
||||||
return path_str.strip(" ").strip('\'').strip("\n").strip('"').strip(" ").strip("\u202a")
|
return path_str.strip(" \'\n\"\u202a")#path_str.strip(" ").strip('\'').strip("\n").strip('"').strip(" ").strip("\u202a")
|
||||||
|
|
||||||
|
|
||||||
def check_for_existance(file_list:list=None,is_train=False,is_dataset_processing=False):
|
def check_for_existance(file_list:list=None,is_train=False,is_dataset_processing=False):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user