mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-06-29 00:59:16 +08:00
Compare commits
1 Commits
6bb625cfe4
...
9d5b89e7ee
Author | SHA1 | Date | |
---|---|---|---|
|
9d5b89e7ee |
@ -12,33 +12,33 @@ import torch
|
||||
|
||||
|
||||
def multi_head_attention_forward_patched(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
embed_dim_to_check,
|
||||
num_heads,
|
||||
in_proj_weight,
|
||||
in_proj_bias,
|
||||
bias_k,
|
||||
bias_v,
|
||||
add_zero_attn,
|
||||
query: Tensor,
|
||||
key: Tensor,
|
||||
value: Tensor,
|
||||
embed_dim_to_check: int,
|
||||
num_heads: int,
|
||||
in_proj_weight: Optional[Tensor],
|
||||
in_proj_bias: Optional[Tensor],
|
||||
bias_k: Optional[Tensor],
|
||||
bias_v: Optional[Tensor],
|
||||
add_zero_attn: bool,
|
||||
dropout_p: float,
|
||||
out_proj_weight,
|
||||
out_proj_bias,
|
||||
training = True,
|
||||
key_padding_mask = None,
|
||||
need_weights = True,
|
||||
attn_mask = None,
|
||||
use_separate_proj_weight = False,
|
||||
q_proj_weight = None,
|
||||
k_proj_weight = None,
|
||||
v_proj_weight = None,
|
||||
static_k = None,
|
||||
static_v = None,
|
||||
average_attn_weights = True,
|
||||
is_causal = False,
|
||||
out_proj_weight: Tensor,
|
||||
out_proj_bias: Optional[Tensor],
|
||||
training: bool = True,
|
||||
key_padding_mask: Optional[Tensor] = None,
|
||||
need_weights: bool = True,
|
||||
attn_mask: Optional[Tensor] = None,
|
||||
use_separate_proj_weight: bool = False,
|
||||
q_proj_weight: Optional[Tensor] = None,
|
||||
k_proj_weight: Optional[Tensor] = None,
|
||||
v_proj_weight: Optional[Tensor] = None,
|
||||
static_k: Optional[Tensor] = None,
|
||||
static_v: Optional[Tensor] = None,
|
||||
average_attn_weights: bool = True,
|
||||
is_causal: bool = False,
|
||||
cache=None,
|
||||
):
|
||||
) -> Tuple[Tensor, Optional[Tensor]]:
|
||||
r"""
|
||||
Args:
|
||||
query, key, value: map a query and a set of key-value pairs to an output.
|
||||
|
@ -81,7 +81,7 @@ if os.path.exists(semantic_path) == False:
|
||||
# utils.load_checkpoint(pretrained_s2G, vq_model, None, True)
|
||||
print(
|
||||
vq_model.load_state_dict(
|
||||
torch.load(pretrained_s2G, map_location="cpu", weights_only=False)["weight"], strict=False
|
||||
torch.load(pretrained_s2G, map_location="cpu")["weight"], strict=False
|
||||
)
|
||||
)
|
||||
|
||||
|
59
install.sh
59
install.sh
@ -2,13 +2,8 @@
|
||||
|
||||
# 安装构建工具
|
||||
# Install build tools
|
||||
echo "Installing GCC..."
|
||||
conda install -c conda-forge gcc=14
|
||||
|
||||
echo "Installing G++..."
|
||||
conda install -c conda-forge gxx
|
||||
|
||||
echo "Installing ffmpeg and cmake..."
|
||||
conda install ffmpeg cmake
|
||||
|
||||
# 设置编译环境
|
||||
@ -17,60 +12,10 @@ export CMAKE_MAKE_PROGRAM="$CONDA_PREFIX/bin/cmake"
|
||||
export CC="$CONDA_PREFIX/bin/gcc"
|
||||
export CXX="$CONDA_PREFIX/bin/g++"
|
||||
|
||||
echo "Checking for CUDA installation..."
|
||||
if command -v nvidia-smi &> /dev/null; then
|
||||
USE_CUDA=true
|
||||
echo "CUDA found."
|
||||
else
|
||||
echo "CUDA not found."
|
||||
USE_CUDA=false
|
||||
fi
|
||||
|
||||
|
||||
if [ "$USE_CUDA" = false ]; then
|
||||
echo "Checking for ROCm installation..."
|
||||
if [ -d "/opt/rocm" ]; then
|
||||
USE_ROCM=true
|
||||
echo "ROCm found."
|
||||
if grep -qi "microsoft" /proc/version; then
|
||||
echo "You are running WSL."
|
||||
IS_WSL=true
|
||||
else
|
||||
echo "You are NOT running WSL."
|
||||
IS_WSL=false
|
||||
fi
|
||||
else
|
||||
echo "ROCm not found."
|
||||
USE_ROCM=false
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ "$USE_CUDA" = true ]; then
|
||||
echo "Installing PyTorch with CUDA support..."
|
||||
conda install pytorch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 pytorch-cuda=11.8 -c pytorch -c nvidia
|
||||
elif [ "$USE_ROCM" = true ] ; then
|
||||
echo "Installing PyTorch with ROCm support..."
|
||||
pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/rocm6.2
|
||||
else
|
||||
echo "Installing PyTorch for CPU..."
|
||||
conda install pytorch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 cpuonly -c pytorch
|
||||
fi
|
||||
|
||||
|
||||
echo "Installing Python dependencies from requirements.txt..."
|
||||
conda install pytorch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 pytorch-cuda=11.8 -c pytorch -c nvidia
|
||||
|
||||
# 刷新环境
|
||||
# Refresh environment
|
||||
hash -r
|
||||
|
||||
pip install -r requirements.txt
|
||||
|
||||
if [ "$USE_ROCM" = true ] && [ "$IS_WSL" = true ] ; then
|
||||
echo "Update to WSL compatible runtime lib..."
|
||||
location=`pip show torch | grep Location | awk -F ": " '{print $2}'`
|
||||
cd ${location}/torch/lib/
|
||||
rm libhsa-runtime64.so*
|
||||
cp /opt/rocm/lib/libhsa-runtime64.so.1.2 libhsa-runtime64.so
|
||||
fi
|
||||
|
||||
echo "Installation completed successfully!"
|
||||
|
||||
|
@ -32,7 +32,7 @@ def clean_path(path_str:str):
|
||||
if path_str.endswith(('\\','/')):
|
||||
return clean_path(path_str[0:-1])
|
||||
path_str = path_str.replace('/', os.sep).replace('\\', os.sep)
|
||||
return path_str.strip(" \'\n\"\u202a")#path_str.strip(" ").strip('\'').strip("\n").strip('"').strip(" ").strip("\u202a")
|
||||
return path_str.strip(" ").strip('\'').strip("\n").strip('"').strip(" ").strip("\u202a")
|
||||
|
||||
|
||||
def check_for_existance(file_list:list=None,is_train=False,is_dataset_processing=False):
|
||||
|
Loading…
x
Reference in New Issue
Block a user