Mirror of https://github.com/RVC-Boss/GPT-SoVITS.git (synced 2025-04-06 03:57:44 +08:00)

Commit fe16adc861: Merge pull request #183 from Lion-Wu/mps
Add MPS Backend Support for macOS
GPT_SoVITS/AR/data/bucket_sampler.py

```diff
@@ -41,12 +41,13 @@ class DistributedBucketSampler(Sampler[T_co]):
         if num_replicas is None:
             if not dist.is_available():
                 raise RuntimeError("Requires distributed package to be available")
-            num_replicas = dist.get_world_size()
+            num_replicas = dist.get_world_size() if torch.cuda.is_available() else 1
         if rank is None:
             if not dist.is_available():
                 raise RuntimeError("Requires distributed package to be available")
-            rank = dist.get_rank()
-        torch.cuda.set_device(rank)
+            rank = dist.get_rank() if torch.cuda.is_available() else 0
+        if torch.cuda.is_available():
+            torch.cuda.set_device(rank)
         if rank >= num_replicas or rank < 0:
             raise ValueError(
                 "Invalid rank {}, rank should be in the interval"
```
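With this change the sampler degrades to single-process behavior when CUDA is absent (world size 1, rank 0), so it can run under MPS or CPU without an initialized process group. A standalone sketch of that fallback (the `resolve_world` helper is ours, not the project's):

```python
import torch
import torch.distributed as dist

def resolve_world(num_replicas=None, rank=None):
    """Mirror the PR's fallback: query the process group only when CUDA is up."""
    if num_replicas is None:
        # Assumes dist.init_process_group() has already run on CUDA hosts.
        num_replicas = dist.get_world_size() if torch.cuda.is_available() else 1
    if rank is None:
        rank = dist.get_rank() if torch.cuda.is_available() else 0
    if torch.cuda.is_available():
        torch.cuda.set_device(rank)
    return num_replicas, rank

# On a Mac without CUDA: resolve_world() == (1, 0), no process group needed.
```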
GPT_SoVITS/inference_webui.py

```diff
@@ -40,7 +40,15 @@ from my_utils import load_audio
 from tools.i18n.i18n import I18nAuto
 
 i18n = I18nAuto()
 
-device = "cuda"
+os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'  # make sure this is also set when the inference UI is launched directly
+
+if torch.cuda.is_available():
+    device = "cuda"
+elif torch.backends.mps.is_available():
+    device = "mps"
+else:
+    device = "cpu"
+
 tokenizer = AutoTokenizer.from_pretrained(bert_path)
 bert_model = AutoModelForMaskedLM.from_pretrained(bert_path)
 if is_half == True:
```
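This CUDA, then MPS, then CPU precedence recurs throughout the PR. A minimal reusable sketch (the `pick_device` name is ours, not from the PR):

```python
import os
import torch

# Let MPS fall back to CPU for ops it does not implement; set this early,
# before the MPS backend is first used.
os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")

def pick_device() -> torch.device:
    """Prefer CUDA, then Apple-silicon MPS, then CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")

device = pick_device()
```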
GPT_SoVITS/prepare_datasets/1-get-text.py

```diff
@@ -46,7 +46,7 @@ if os.path.exists(txt_path) == False:
     bert_dir = "%s/3-bert" % (opt_dir)
     os.makedirs(opt_dir, exist_ok=True)
     os.makedirs(bert_dir, exist_ok=True)
-    device = "cuda:0"
+    device = "cuda:0" if torch.cuda.is_available() else "mps"
     tokenizer = AutoTokenizer.from_pretrained(bert_pretrained_dir)
     bert_model = AutoModelForMaskedLM.from_pretrained(bert_pretrained_dir)
     if is_half == True:
```
GPT_SoVITS/prepare_datasets/2-get-hubert-wav32k.py

```diff
@@ -47,7 +47,7 @@ os.makedirs(wav32dir,exist_ok=True)
 
 maxx=0.95
 alpha=0.5
-device="cuda:0"
+device="cuda:0" if torch.cuda.is_available() else "mps"
 model=cnhubert.get_model()
 # is_half=False
 if(is_half==True):
```
GPT_SoVITS/prepare_datasets/3-get-semantic.py

```diff
@@ -38,7 +38,7 @@ semantic_path = "%s/6-name2semantic-%s.tsv" % (opt_dir, i_part)
 if os.path.exists(semantic_path) == False:
     os.makedirs(opt_dir, exist_ok=True)
 
-    device = "cuda:0"
+    device = "cuda:0" if torch.cuda.is_available() else "mps"
     hps = utils.get_hparams_from_file(s2config_path)
     vq_model = SynthesizerTrn(
         hps.data.filter_length // 2 + 1,
```
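All three dataset-preparation scripts use the same one-liner, which assumes MPS whenever CUDA is absent. A guarded variant that also degrades to CPU (our sketch, not part of the PR):

```python
import torch

# "cuda:0" on NVIDIA, "mps" on Apple silicon, otherwise plain CPU.
device = (
    "cuda:0"
    if torch.cuda.is_available()
    else ("mps" if torch.backends.mps.is_available() else "cpu")
)
```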
GPT_SoVITS/s1_train.py

```diff
@@ -116,9 +116,9 @@ def main(args):
         devices=-1,
         benchmark=False,
         fast_dev_run=False,
-        strategy=DDPStrategy(
+        strategy="auto" if torch.backends.mps.is_available() else DDPStrategy(
             process_group_backend="nccl" if platform.system() != "Windows" else "gloo"
-        ),
+        ),  # MPS does not support multi-node training
         precision=config["train"]["precision"],
         logger=logger,
         num_sanity_val_steps=0,
```
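For context, a sketch of how this strategy choice plugs into a pytorch_lightning Trainer (assuming pytorch_lightning 2.x; everything beyond Trainer/DDPStrategy is placeholder):

```python
import platform
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.strategies import DDPStrategy

# "auto" lets Lightning drive the single MPS device; DDP stays for CUDA clusters.
strategy = (
    "auto"
    if torch.backends.mps.is_available()
    else DDPStrategy(
        process_group_backend="nccl" if platform.system() != "Windows" else "gloo"
    )
)

trainer = Trainer(accelerator="auto", devices="auto", strategy=strategy)
```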
GPT_SoVITS/s2_train.py

```diff
@@ -44,9 +44,12 @@ global_step = 0
 
 def main():
     """Assume Single Node Multi GPUs Training Only"""
-    assert torch.cuda.is_available(), "CPU training is not allowed."
+    assert torch.cuda.is_available() or torch.backends.mps.is_available(), "Only GPU training is allowed."
 
-    n_gpus = torch.cuda.device_count()
+    if torch.backends.mps.is_available():
+        n_gpus = 1
+    else:
+        n_gpus = torch.cuda.device_count()
     os.environ["MASTER_ADDR"] = "localhost"
     os.environ["MASTER_PORT"] = str(randint(20000, 55555))
```
```diff
@@ -70,13 +73,14 @@ def run(rank, n_gpus, hps):
         writer_eval = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval"))
 
     dist.init_process_group(
-        backend="gloo" if os.name == "nt" else "nccl",
+        backend="gloo" if os.name == "nt" or torch.backends.mps.is_available() else "nccl",
         init_method="env://",
         world_size=n_gpus,
         rank=rank,
     )
     torch.manual_seed(hps.train.seed)
-    torch.cuda.set_device(rank)
+    if torch.cuda.is_available():
+        torch.cuda.set_device(rank)
 
     train_dataset = TextAudioSpeakerLoader(hps.data) ########
     train_sampler = DistributedBucketSampler(
```
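The training loop still calls into torch.distributed even on a single Apple GPU, so a process group must exist; Gloo is the backend that works without CUDA. A minimal single-process init sketch under that assumption:

```python
import os
import torch
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")

# NCCL needs CUDA devices; Gloo runs on Windows, CPU, and MPS hosts.
backend = "gloo" if os.name == "nt" or torch.backends.mps.is_available() else "nccl"
dist.init_process_group(backend=backend, init_method="env://", world_size=1, rank=0)
```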
```diff
@@ -128,9 +132,14 @@ def run(rank, n_gpus, hps):
         hps.train.segment_size // hps.data.hop_length,
         n_speakers=hps.data.n_speakers,
         **hps.model,
-    ).cuda(rank)
+    ).cuda(rank) if torch.cuda.is_available() else SynthesizerTrn(
+        hps.data.filter_length // 2 + 1,
+        hps.train.segment_size // hps.data.hop_length,
+        n_speakers=hps.data.n_speakers,
+        **hps.model,
+    ).to("mps")
 
-    net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank)
+    net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank) if torch.cuda.is_available() else MultiPeriodDiscriminator(hps.model.use_spectral_norm).to("mps")
     for name, param in net_g.named_parameters():
         if not param.requires_grad:
             print(name, "not requires_grad")
```
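Duplicating the constructor in both branches is easy to let drift out of sync. An equivalent sketch that builds the module once and then moves it (helper name ours, not in the PR):

```python
import torch

def move_to_accelerator(module: torch.nn.Module, rank: int = 0) -> torch.nn.Module:
    """Place a module on CUDA device `rank` when available, otherwise on MPS."""
    if torch.cuda.is_available():
        return module.cuda(rank)
    return module.to("mps")

# net_g = move_to_accelerator(SynthesizerTrn(...), rank)
# net_d = move_to_accelerator(MultiPeriodDiscriminator(hps.model.use_spectral_norm), rank)
```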
```diff
@@ -174,8 +183,12 @@ def run(rank, n_gpus, hps):
         betas=hps.train.betas,
         eps=hps.train.eps,
     )
-    net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True)
-    net_d = DDP(net_d, device_ids=[rank], find_unused_parameters=True)
+    if torch.cuda.is_available():
+        net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True)
+        net_d = DDP(net_d, device_ids=[rank], find_unused_parameters=True)
+    else:
+        net_g = net_g.to("mps")
+        net_d = net_d.to("mps")
 
     try:  # resume automatically if a checkpoint can be loaded
         _, _, _, epoch_str = utils.load_checkpoint(
```
```diff
@@ -205,6 +218,9 @@ def run(rank, n_gpus, hps):
             net_g.module.load_state_dict(
                 torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"],
                 strict=False,
+            ) if torch.cuda.is_available() else net_g.load_state_dict(
+                torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"],
+                strict=False,
             )
         )  ## for testing: do not load the optimizer
     if hps.train.pretrained_s2D != "":
```
```diff
@@ -213,6 +229,8 @@ def run(rank, n_gpus, hps):
         print(
             net_d.module.load_state_dict(
                 torch.load(hps.train.pretrained_s2D, map_location="cpu")["weight"]
+            ) if torch.cuda.is_available() else net_d.load_state_dict(
+                torch.load(hps.train.pretrained_s2D, map_location="cpu")["weight"]
             )
         )
```
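The `net.module` vs bare `net` split exists because DDP wraps the network, and the PR only applies DDP on CUDA. The usual unwrap idiom covers both paths in one place (our sketch, not the project's code):

```python
import torch
from torch.nn.parallel import DistributedDataParallel as DDP

def unwrap(net: torch.nn.Module) -> torch.nn.Module:
    """Return the underlying module whether or not `net` is DDP-wrapped."""
    return net.module if isinstance(net, DDP) else net

# unwrap(net_g).load_state_dict(state_dict, strict=False) then works on both paths.
```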
```diff
@@ -288,18 +306,26 @@ def train_and_evaluate(
         text,
         text_lengths,
     ) in tqdm(enumerate(train_loader)):
-        spec, spec_lengths = spec.cuda(rank, non_blocking=True), spec_lengths.cuda(
-            rank, non_blocking=True
-        )
-        y, y_lengths = y.cuda(rank, non_blocking=True), y_lengths.cuda(
-            rank, non_blocking=True
-        )
-        ssl = ssl.cuda(rank, non_blocking=True)
-        ssl.requires_grad = False
-        # ssl_lengths = ssl_lengths.cuda(rank, non_blocking=True)
-        text, text_lengths = text.cuda(rank, non_blocking=True), text_lengths.cuda(
-            rank, non_blocking=True
-        )
+        if torch.cuda.is_available():
+            spec, spec_lengths = spec.cuda(rank, non_blocking=True), spec_lengths.cuda(
+                rank, non_blocking=True
+            )
+            y, y_lengths = y.cuda(rank, non_blocking=True), y_lengths.cuda(
+                rank, non_blocking=True
+            )
+            ssl = ssl.cuda(rank, non_blocking=True)
+            ssl.requires_grad = False
+            # ssl_lengths = ssl_lengths.cuda(rank, non_blocking=True)
+            text, text_lengths = text.cuda(rank, non_blocking=True), text_lengths.cuda(
+                rank, non_blocking=True
+            )
+        else:
+            spec, spec_lengths = spec.to("mps"), spec_lengths.to("mps")
+            y, y_lengths = y.to("mps"), y_lengths.to("mps")
+            ssl = ssl.to("mps")
+            ssl.requires_grad = False
+            # ssl_lengths = ssl_lengths.cuda(rank, non_blocking=True)
+            text, text_lengths = text.to("mps"), text_lengths.to("mps")
 
         with autocast(enabled=hps.train.fp16_run):
             (
```
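The two branches differ only in how tensors reach the device. A compact sketch that moves the whole batch in one pass (helper name ours, not from the PR):

```python
from typing import Iterable
import torch

def move_all(tensors: Iterable[torch.Tensor], device: torch.device, non_blocking: bool = False):
    """Move a batch of tensors to `device`, preserving order."""
    return tuple(t.to(device, non_blocking=non_blocking) for t in tensors)

# device = torch.device(f"cuda:{rank}") if torch.cuda.is_available() else torch.device("mps")
# spec, spec_lengths, y, y_lengths, ssl, text, text_lengths = move_all(
#     (spec, spec_lengths, y, y_lengths, ssl, text, text_lengths), device, non_blocking=True
# )
```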
```diff
@@ -500,13 +526,21 @@ def evaluate(hps, generator, eval_loader, writer_eval):
             text_lengths,
         ) in enumerate(eval_loader):
             print(111)
-            spec, spec_lengths = spec.cuda(), spec_lengths.cuda()
-            y, y_lengths = y.cuda(), y_lengths.cuda()
-            ssl = ssl.cuda()
-            text, text_lengths = text.cuda(), text_lengths.cuda()
+            if torch.cuda.is_available():
+                spec, spec_lengths = spec.cuda(), spec_lengths.cuda()
+                y, y_lengths = y.cuda(), y_lengths.cuda()
+                ssl = ssl.cuda()
+                text, text_lengths = text.cuda(), text_lengths.cuda()
+            else:
+                spec, spec_lengths = spec.to("mps"), spec_lengths.to("mps")
+                y, y_lengths = y.to("mps"), y_lengths.to("mps")
+                ssl = ssl.to("mps")
+                text, text_lengths = text.to("mps"), text_lengths.to("mps")
             for test in [0, 1]:
                 y_hat, mask, *_ = generator.module.infer(
                     ssl, spec, spec_lengths, text, text_lengths, test=test
-                )
+                ) if torch.cuda.is_available() else generator.infer(
+                    ssl, spec, spec_lengths, text, text_lengths, test=test
+                )
                 y_hat_lengths = mask.sum([1, 2]).long() * hps.data.hop_length
```
README.md
````diff
@@ -43,9 +43,24 @@ If you are a Windows user (tested with win>=10) you can install directly via the
 
 - Python 3.9, PyTorch 2.0.1, CUDA 11
 - Python 3.10.13, PyTorch 2.1.2, CUDA 12.3
+- Python 3.9, PyTorch 2.3.0.dev20240122, macOS 14.3 (Apple Silicon, MPS)
 
 _Note: numba==0.56.4 requires py<3.11_
 
+### For Mac Users
+If you are a Mac user, please install by using the following commands:
+#### Create Environment
+```bash
+conda create -n GPTSoVits python=3.9
+conda activate GPTSoVits
+```
+#### Install Requirements
+```bash
+pip install -r requirements.txt
+pip uninstall torch torchaudio
+pip3 install --pre torch torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu
+```
+_Note: For preprocessing with UVR5, it is recommended to [download the original project GUI](https://github.com/Anjok07/ultimatevocalremovergui) and select GPU for operation. Additionally, there may be memory leak issues when using Mac for inference; restarting the inference webUI can release the memory._
 ### Quick Install with Conda
 
 ```bash
````
````diff
@@ -58,16 +73,9 @@ bash install.sh
 #### Pip Packages
 
 ```bash
-pip install torch numpy scipy tensorboard librosa==0.9.2 numba==0.56.4 pytorch-lightning gradio==3.14.0 ffmpeg-python onnxruntime tqdm cn2an pypinyin pyopenjtalk g2p_en chardet transformers jieba_fast
+pip install -r requirements.txt
 ```
 
-#### Additional Requirements
-
-If you need Chinese ASR (supported by FunASR), install:
-
-```bash
-pip install modelscope torchaudio sentencepiece funasr
-```
-
 #### FFmpeg
 
````
api.py
````diff
@@ -30,25 +30,25 @@ endpoint: `/`
 
 Use the reference audio given by the launch arguments:
 GET:
-    `http://127.0.0.1:9880?text=你所热爱的,就是你的生活。&text_language=zh`
+    `http://127.0.0.1:9880?text=先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。&text_language=zh`
 POST:
 ```json
 {
-    "text": "你所热爱的,就是你的生活。",
+    "text": "先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。",
     "text_language": "zh"
 }
 ```
 
 Manually specify the reference audio for this request:
 GET:
-    `http://127.0.0.1:9880?refer_wav_path=123.wav&prompt_text=一二三。&prompt_language=zh&text=你所热爱的,就是你的生活。&text_language=zh`
+    `http://127.0.0.1:9880?refer_wav_path=123.wav&prompt_text=一二三。&prompt_language=zh&text=先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。&text_language=zh`
 POST:
 ```json
 {
     "refer_wav_path": "123.wav",
     "prompt_text": "一二三。",
     "prompt_language": "zh",
-    "text": "你所热爱的,就是你的生活。",
+    "text": "先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。",
     "text_language": "zh"
 }
 ```
````
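A small Python client sketch for the endpoint documented above (assumes the server is running on the default 127.0.0.1:9880 and that `requests` is installed):

```python
import requests

resp = requests.post(
    "http://127.0.0.1:9880",
    json={
        "refer_wav_path": "123.wav",
        "prompt_text": "一二三。",
        "prompt_language": "zh",
        "text": "先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。",
        "text_language": "zh",
    },
)
resp.raise_for_status()
with open("out.wav", "wb") as f:
    f.write(resp.content)  # the endpoint streams back audio/wav
```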
```diff
@@ -138,7 +138,8 @@ parser.add_argument("-dr", "--default_refer_path", type=str, default="", help="
 parser.add_argument("-dt", "--default_refer_text", type=str, default="", help="默认参考音频文本")
 parser.add_argument("-dl", "--default_refer_language", type=str, default="", help="默认参考音频语种")
 
-parser.add_argument("-d", "--device", type=str, default=g_config.infer_device, help="cuda / cpu")
+parser.add_argument("-d", "--device", type=str, default=g_config.infer_device, help="cuda / cpu / mps")
+parser.add_argument("-a", "--bind_addr", type=str, default="127.0.0.1", help="default: 127.0.0.1")
 parser.add_argument("-p", "--port", type=int, default=g_config.api_port, help="default: 9880")
 parser.add_argument("-fp", "--full_precision", action="store_true", default=False, help="覆盖config.is_half为False, 使用全精度")
```
```diff
@@ -278,7 +279,7 @@ vq_model.eval()
 print(vq_model.load_state_dict(dict_s2["weight"], strict=False))
 hz = 50
 max_sec = config['data']['max_sec']
-t2s_model = Text2SemanticLightningModule(config, "ojbk", is_train=False)
+t2s_model = Text2SemanticLightningModule(config, "****", is_train=False)
 t2s_model.load_state_dict(dict_s1["weight"])
 if is_half:
     t2s_model = t2s_model.half()
```
```diff
@@ -439,6 +440,7 @@ def handle(refer_wav_path, prompt_text, prompt_language, text, text_language):
     wav.seek(0)
 
     torch.cuda.empty_cache()
+    torch.mps.empty_cache()
     return StreamingResponse(wav, media_type="audio/wav")
```
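Calling both cache-release functions unconditionally works on builds that ship both backends, but a guarded variant is safer if either backend may be missing (our sketch, assuming PyTorch >= 2.0 for torch.mps):

```python
import torch

def free_accelerator_cache() -> None:
    """Release cached accelerator memory on whichever backend is active."""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    if torch.backends.mps.is_available():
        torch.mps.empty_cache()
```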
config.py

```diff
@@ -1,5 +1,6 @@
 import sys,os
 
+import torch
 
 # models specified for inference
 sovits_path = ""
```
```diff
@@ -14,7 +15,12 @@ pretrained_gpt_path = "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=
 
 exp_root = "logs"
 python_exec = sys.executable or "python"
-infer_device = "cuda"
+if torch.cuda.is_available():
+    infer_device = "cuda"
+elif torch.backends.mps.is_available():
+    infer_device = "mps"
+else:
+    infer_device = "cpu"
 
 webui_port_main = 9874
 webui_port_uvr5 = 9873
```
docs/cn/README.md

````diff
@@ -38,10 +38,29 @@ https://github.com/RVC-Boss/GPT-SoVITS/assets/129054828/05bee1fa-bdd8-4d85-9350-
 
 If you are a Windows user (tested on win>=10), you can install directly via the prezip: download the [prezip](https://huggingface.co/lj1995/GPT-SoVITS-windows-package/resolve/main/GPT-SoVITS-beta.7z?download=true), extract it, and double-click go-webui.bat to launch GPT-SoVITS-WebUI.
 
-### Python and PyTorch Versions
+### Tested Python and PyTorch Versions
 
-Tested on Python 3.9, PyTorch 2.0.1, and CUDA 11.
+- Python 3.9, PyTorch 2.0.1, CUDA 11
+- Python 3.10.13, PyTorch 2.1.2, CUDA 12.3
+- Python 3.9, PyTorch 2.3.0.dev20240122, macOS 14.3 (Apple silicon, MPS)
+
+_Note: numba==0.56.4 requires python<3.11_
+
+### For Mac Users
+If you are a Mac user, please install using the following commands:
+#### Create Environment
+```bash
+conda create -n GPTSoVits python=3.9
+conda activate GPTSoVits
+```
+#### Install Requirements
+```bash
+pip install -r requirements.txt
+pip uninstall torch torchaudio
+pip3 install --pre torch torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu
+```
+_Note: For preprocessing with UVR5, it is recommended to [download the original project GUI](https://github.com/Anjok07/ultimatevocalremovergui) and select GPU operation. Also, inference on Mac may leak memory; restarting the inference webUI releases it._
 ### Quick Install with Conda
 
 ```bash
````
````diff
@@ -53,15 +72,7 @@ bash install.sh
 #### Pip Packages
 
 ```bash
-pip install torch numpy scipy tensorboard librosa==0.9.2 numba==0.56.4 pytorch-lightning gradio==3.14.0 ffmpeg-python onnxruntime tqdm cn2an pypinyin pyopenjtalk g2p_en chardet transformers
+pip install -r requirements.txt
 ```
 
-#### Additional Requirements
-
-If you need Chinese ASR (supported by FunASR), install:
-
-```bash
-pip install modelscope torchaudio sentencepiece funasr
-```
-
 #### FFmpeg
 
````
docs/ja/README.md

````diff
@@ -37,9 +37,26 @@ https://github.com/RVC-Boss/GPT-SoVITS/assets/129054828/05bee1fa-bdd8-4d85-9350-
 
 If you are a Windows user (tested on win>=10), you can install directly via the prezip: download the [prezip](https://huggingface.co/lj1995/GPT-SoVITS-windows-package/resolve/main/GPT-SoVITS-beta.7z?download=true), extract it, and double-click go-webui.bat to launch GPT-SoVITS-WebUI.
 
 ### Python and PyTorch Versions
+- Python 3.9, PyTorch 2.0.1, CUDA 11
+- Python 3.10.13, PyTorch 2.1.2, CUDA 12.3
+- Python 3.9, PyTorch 2.3.0.dev20240122, macOS 14.3 (Apple Silicon, MPS)
 
-Tested with Python 3.9, PyTorch 2.0.1, and CUDA 11.
 _Note: numba==0.56.4 requires py<3.11_
 
+### For Mac Users
+Mac users, please install using the following commands:
+#### Create Environment
+```bash
+conda create -n GPTSoVits python=3.9
+conda activate GPTSoVits
+```
+#### Pip Packages
+```bash
+pip install -r requirements.txt
+pip uninstall torch torchaudio
+pip3 install --pre torch torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu
+```
+_Note: For preprocessing with UVR5, we recommend [downloading the original project GUI](https://github.com/Anjok07/ultimatevocalremovergui) and selecting GPU operation. Also, memory may leak during inference on Mac; restarting the inference webUI frees it._
 ### Quick Install with Conda
 
 ```bash
````
````diff
@@ -52,15 +69,7 @@ bash install.sh
 #### Pip Packages
 
 ```bash
-pip install torch numpy scipy tensorboard librosa==0.9.2 numba==0.56.4 pytorch-lightning gradio==3.14.0 ffmpeg-python onnxruntime tqdm cn2an pypinyin pyopenjtalk g2p_en chardet transformers
+pip install -r requirements.txt
 ```
 
-#### Additional Requirements
-
-If you need Chinese ASR (supported by FunASR), install:
-
-```bash
-pip install modelscope torchaudio sentencepiece funasr
-```
-
 #### FFmpeg
 
````
webui.py
```diff
@@ -45,14 +45,17 @@ i18n = I18nAuto()
 from scipy.io import wavfile
 from tools.my_utils import load_audio
 from multiprocessing import cpu_count
 
+os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'  # fall back to CPU for steps MPS does not support
+
 n_cpu=cpu_count()
 
-# check for NVIDIA GPUs usable for training and accelerated inference
 ngpu = torch.cuda.device_count()
 gpu_infos = []
 mem = []
 if_gpu_ok = False
 
+# check for NVIDIA GPUs usable for training and accelerated inference
 if torch.cuda.is_available() or ngpu != 0:
     for i in range(ngpu):
         gpu_name = torch.cuda.get_device_name(i)
```
```diff
@@ -61,6 +64,12 @@ if torch.cuda.is_available() or ngpu != 0:
             if_gpu_ok = True  # at least one usable NVIDIA GPU
             gpu_infos.append("%s\t%s" % (i, gpu_name))
             mem.append(int(torch.cuda.get_device_properties(i).total_memory/ 1024/ 1024/ 1024+ 0.4))
+# check whether MPS acceleration is available
+if torch.backends.mps.is_available():
+    if_gpu_ok = True
+    gpu_infos.append("%s\t%s" % ("0", "Apple GPU"))
+    mem.append(psutil.virtual_memory().total/ 1024 / 1024 / 1024)  # measured: using system RAM as VRAM does not run out of memory
+
 if if_gpu_ok and len(gpu_infos) > 0:
     gpu_info = "\n".join(gpu_infos)
     default_batch_size = min(mem) // 2
```
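MPS exposes no dedicated-VRAM query, so the PR sizes batches from system RAM (unified memory on Apple silicon). A self-contained sketch of the same detection (assumes `psutil` is installed, as webui.py already requires):

```python
import psutil
import torch

gpu_infos, mem = [], []

# NVIDIA GPUs: report dedicated VRAM per device, in GiB.
for i in range(torch.cuda.device_count()):
    gpu_infos.append(f"{i}\t{torch.cuda.get_device_name(i)}")
    mem.append(int(torch.cuda.get_device_properties(i).total_memory / 1024**3 + 0.4))

# Apple silicon: unified memory, so system RAM stands in for VRAM.
if not gpu_infos and torch.backends.mps.is_available():
    gpu_infos.append("0\tApple GPU")
    mem.append(psutil.virtual_memory().total / 1024**3)

default_batch_size = int(min(mem)) // 2 if mem else 1
```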