优化代码

This commit is contained in:
chasonjiang 2024-03-14 11:15:09 +08:00
parent 0ff60a947a
commit 8698c28c90

View File

@ -820,7 +820,7 @@ class TTS:
def empty_cache(self): def empty_cache(self):
try: try:
if str(self.configs.device) == "cuda": if "cuda" in str(self.configs.device):
torch.cuda.empty_cache() torch.cuda.empty_cache()
elif str(self.configs.device) == "mps": elif str(self.configs.device) == "mps":
torch.mps.empty_cache() torch.mps.empty_cache()