Fix: handle empty bert_list and align BERT/phones length

This commit is contained in:
Bronya Zaychik 2026-04-19 13:28:08 +08:00
parent d9f03dad3e
commit 4a2d88ce4e

27
api.py
View File

@ -599,13 +599,38 @@ def get_phones_and_bert(text, language, version, final=False):
phones_list.append(phones)
norm_text_list.append(norm_text)
bert_list.append(bert)
bert = torch.cat(bert_list, dim=1)
# Fix: handle empty bert_list to avoid torch.cat() crash
if bert_list:
bert = torch.cat(bert_list, dim=1)
else:
phones_total = sum(phones_list, []) if phones_list else []
bert = torch.zeros(
(1024, max(len(phones_total), 1)),
dtype=torch.float16 if is_half == True else torch.float32,
).to(device)
phones = sum(phones_list, [])
norm_text = "".join(norm_text_list)
if not final and len(phones) < 6:
return get_phones_and_bert("." + text, language, version, final=True)
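# Fix: align BERT feature length with phones length to avoid a dimension mismatch,
# e.g. "RuntimeError: The size of tensor a (44) must match the size of tensor b (45)
# at non-singleton dimension 1".
# NOTE(review): the adjustment below calls transpose(1, 2), which needs a tensor with
# at least 3 dims, yet the zero fallback above is 2-D (1024, seq) — confirm the shapes.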
phones_len = len(phones)
bert_len = bert.shape[1]
if phones_len != bert_len:
import torch.nn.functional as F
logger.warning(f"[TTS] BERT length mismatch: phones={phones_len}, bert={bert_len}, adjusting...")
bert = bert.transpose(1, 2) # (1024, seq) -> (seq, 1024)
if phones_len > bert_len:
# Interpolate to enlarge
bert = F.interpolate(bert.unsqueeze(0), size=phones_len, mode='linear', align_corners=False)
else:
# Truncate excess
bert = bert[:, :phones_len, :]
bert = bert.squeeze(0).transpose(0, 1) # (seq, 1024) -> (1024, seq)
bert = bert.to(device)
return phones, bert.to(torch.float16 if is_half == True else torch.float32), norm_text