Fix training error caused by float type of default_batch_size parameter (#2662)

This commit is contained in:
Spr_Aachen 2025-11-28 22:53:43 +08:00 committed by GitHub
parent 6fb441f65e
commit 92d2d337fd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -116,8 +116,8 @@ def set_default():
gpu_info = "\n".join(gpu_infos)
if is_gpu_ok:
minmem = min(mem)
-            default_batch_size = minmem // 2 if version not in v3v4set else minmem // 8
-            default_batch_size_s1 = minmem // 2
+            default_batch_size = int(minmem // 2 if version not in v3v4set else minmem // 8)
+            default_batch_size_s1 = int(minmem // 2)
else:
default_batch_size = default_batch_size_s1 = int(psutil.virtual_memory().total / 1024 / 1024 / 1024 / 4)
if version not in v3v4set: