Compare commits

..

1 Commits

4 changed files with 29 additions and 84 deletions

View File

@ -12,33 +12,33 @@ import torch
def multi_head_attention_forward_patched( def multi_head_attention_forward_patched(
query, query: Tensor,
key, key: Tensor,
value, value: Tensor,
embed_dim_to_check, embed_dim_to_check: int,
num_heads, num_heads: int,
in_proj_weight, in_proj_weight: Optional[Tensor],
in_proj_bias, in_proj_bias: Optional[Tensor],
bias_k, bias_k: Optional[Tensor],
bias_v, bias_v: Optional[Tensor],
add_zero_attn, add_zero_attn: bool,
dropout_p: float, dropout_p: float,
out_proj_weight, out_proj_weight: Tensor,
out_proj_bias, out_proj_bias: Optional[Tensor],
training = True, training: bool = True,
key_padding_mask = None, key_padding_mask: Optional[Tensor] = None,
need_weights = True, need_weights: bool = True,
attn_mask = None, attn_mask: Optional[Tensor] = None,
use_separate_proj_weight = False, use_separate_proj_weight: bool = False,
q_proj_weight = None, q_proj_weight: Optional[Tensor] = None,
k_proj_weight = None, k_proj_weight: Optional[Tensor] = None,
v_proj_weight = None, v_proj_weight: Optional[Tensor] = None,
static_k = None, static_k: Optional[Tensor] = None,
static_v = None, static_v: Optional[Tensor] = None,
average_attn_weights = True, average_attn_weights: bool = True,
is_causal = False, is_causal: bool = False,
cache=None, cache=None,
): ) -> Tuple[Tensor, Optional[Tensor]]:
r""" r"""
Args: Args:
query, key, value: map a query and a set of key-value pairs to an output. query, key, value: map a query and a set of key-value pairs to an output.

View File

@ -81,7 +81,7 @@ if os.path.exists(semantic_path) == False:
# utils.load_checkpoint(pretrained_s2G, vq_model, None, True) # utils.load_checkpoint(pretrained_s2G, vq_model, None, True)
print( print(
vq_model.load_state_dict( vq_model.load_state_dict(
torch.load(pretrained_s2G, map_location="cpu", weights_only=False)["weight"], strict=False torch.load(pretrained_s2G, map_location="cpu")["weight"], strict=False
) )
) )

View File

@ -2,13 +2,8 @@
# 安装构建工具 # 安装构建工具
# Install build tools # Install build tools
echo "Installing GCC..."
conda install -c conda-forge gcc=14 conda install -c conda-forge gcc=14
echo "Installing G++..."
conda install -c conda-forge gxx conda install -c conda-forge gxx
echo "Installing ffmpeg and cmake..."
conda install ffmpeg cmake conda install ffmpeg cmake
# 设置编译环境 # 设置编译环境
@ -17,60 +12,10 @@ export CMAKE_MAKE_PROGRAM="$CONDA_PREFIX/bin/cmake"
export CC="$CONDA_PREFIX/bin/gcc" export CC="$CONDA_PREFIX/bin/gcc"
export CXX="$CONDA_PREFIX/bin/g++" export CXX="$CONDA_PREFIX/bin/g++"
echo "Checking for CUDA installation..."
if command -v nvidia-smi &> /dev/null; then
USE_CUDA=true
echo "CUDA found."
else
echo "CUDA not found."
USE_CUDA=false
fi
if [ "$USE_CUDA" = false ]; then
echo "Checking for ROCm installation..."
if [ -d "/opt/rocm" ]; then
USE_ROCM=true
echo "ROCm found."
if grep -qi "microsoft" /proc/version; then
echo "You are running WSL."
IS_WSL=true
else
echo "You are NOT running WSL."
IS_WSL=false
fi
else
echo "ROCm not found."
USE_ROCM=false
fi
fi
if [ "$USE_CUDA" = true ]; then
echo "Installing PyTorch with CUDA support..."
conda install pytorch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 pytorch-cuda=11.8 -c pytorch -c nvidia conda install pytorch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 pytorch-cuda=11.8 -c pytorch -c nvidia
elif [ "$USE_ROCM" = true ] ; then
echo "Installing PyTorch with ROCm support..."
pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/rocm6.2
else
echo "Installing PyTorch for CPU..."
conda install pytorch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 cpuonly -c pytorch
fi
echo "Installing Python dependencies from requirements.txt..."
# 刷新环境 # 刷新环境
# Refresh environment # Refresh environment
hash -r hash -r
pip install -r requirements.txt pip install -r requirements.txt
if [ "$USE_ROCM" = true ] && [ "$IS_WSL" = true ] ; then
echo "Update to WSL compatible runtime lib..."
location=`pip show torch | grep Location | awk -F ": " '{print $2}'`
cd ${location}/torch/lib/
rm libhsa-runtime64.so*
cp /opt/rocm/lib/libhsa-runtime64.so.1.2 libhsa-runtime64.so
fi
echo "Installation completed successfully!"

View File

@ -32,7 +32,7 @@ def clean_path(path_str:str):
if path_str.endswith(('\\','/')): if path_str.endswith(('\\','/')):
return clean_path(path_str[0:-1]) return clean_path(path_str[0:-1])
path_str = path_str.replace('/', os.sep).replace('\\', os.sep) path_str = path_str.replace('/', os.sep).replace('\\', os.sep)
return path_str.strip(" \'\n\"\u202a")#path_str.strip(" ").strip('\'').strip("\n").strip('"').strip(" ").strip("\u202a") return path_str.strip(" ").strip('\'').strip("\n").strip('"').strip(" ").strip("\u202a")
def check_for_existance(file_list:list=None,is_train=False,is_dataset_processing=False): def check_for_existance(file_list:list=None,is_train=False,is_dataset_processing=False):