diff --git a/Docker/setup.sh b/Docker/setup.sh index 381d7e40..2b4c093e 100644 --- a/Docker/setup.sh +++ b/Docker/setup.sh @@ -48,13 +48,9 @@ fi source "$HOME/anaconda3/etc/profile.d/conda.sh" -echo "CUDA_VERSION: $CUDA_VERSION" - -if [ "$CUDA_VERSION" = 128 ]; then - echo 1111111 +if [ "$CUDA_VERSION" = 12.8 ]; then pip install torch torchaudio --no-cache-dir --index-url https://download.pytorch.org/whl/cu128 -elif [ "$CUDA_VERSION" = 124 ]; then - echo 2222222 +elif [ "$CUDA_VERSION" = 12.4 ]; then pip install torch==2.5.1 torchaudio==2.5.1 --no-cache-dir --index-url https://download.pytorch.org/whl/cu124 fi diff --git a/install.sh b/install.sh index f141be7f..fdb91c33 100644 --- a/install.sh +++ b/install.sh @@ -69,11 +69,11 @@ while [[ $# -gt 0 ]]; do --device) case "$2" in CU124) - CUDA_VERSION=124 + CUDA=124 USE_CUDA=true ;; CU128) - CUDA_VERSION=128 + CUDA=128 USE_CUDA=true ;; ROCM) @@ -228,9 +228,11 @@ fi if [ "$USE_CUDA" = true ] && [ "$WORKFLOW" = false ]; then echo "Installing PyTorch with CUDA support..." - if [ "$CUDA_VERSION" = 128 ]; then + if [ "$CUDA" = 128 ]; then + echo 11111 pip install torch torchaudio --index-url https://download.pytorch.org/whl/cu128 - elif [ "$CUDA_VERSION" = 124 ]; then + elif [ "$CUDA" = 124 ]; then + echo 22222 pip install torch==2.5.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu124 fi elif [ "$USE_ROCM" = true ] && [ "$WORKFLOW" = false ]; then