老显卡判断半精度问题

当前逻辑对于老显卡不友好,比如NVIDIA T500,即使修改配置文件,is_half变量也会被反复赋值为True,现在改为自动判断当前设备是否支持半精度
This commit is contained in:
刘悦 2024-03-23 20:17:17 +08:00 committed by GitHub
parent ed75ecdd6d
commit 748b682fdd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -29,7 +29,24 @@ is_share = os.environ.get("is_share", "False")
is_share = eval(is_share)
if "_CUDA_VISIBLE_DEVICES" in os.environ:
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
# is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
# 创建一个FP16张量
fp16_tensor = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float16)
# 尝试在GPU上执行操作
try:
# 将张量移动到GPU
fp16_tensor = fp16_tensor.cuda()
# 创建一个与fp16_tensor相同大小的FP32张量
fp32_tensor = torch.tensor(fp16_tensor).cuda()
# 执行一个简单的操作,比如加法
result = fp16_tensor + fp32_tensor
is_half = True
print("FP16 is supported on this device.")
except RuntimeError as e:
# 如果发生运行时错误可能是因为设备不支持FP16
print(f"FP16 is not supported on this device. Error: {e}")
is_half = False
gpt_path = os.environ.get("gpt_path", None)
sovits_path = os.environ.get("sovits_path", None)
cnhubert_base_path = os.environ.get("cnhubert_base_path", None)