From 6591e86df39282c9cc672bec334fb9c40716bdb2 Mon Sep 17 00:00:00 2001
From: XTer
Date: Sat, 6 Apr 2024 23:39:40 +0800
Subject: [PATCH] Restore the position of make_batch
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 GPT_SoVITS/TTS_infer_pack/TTS.py | 51 ++++++++++++++++----------------
 1 file changed, 26 insertions(+), 25 deletions(-)

diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py
index d3e394ac..509f0592 100644
--- a/GPT_SoVITS/TTS_infer_pack/TTS.py
+++ b/GPT_SoVITS/TTS_infer_pack/TTS.py
@@ -598,30 +598,7 @@ class TTS:
             tuple[int, np.ndarray]: sampling rate and audio data.
         """
 
-        def make_batch(batch_texts):
-            batch_data = []
-            print(i18n("############ 提取文本Bert特征 ############"))
-            for text in tqdm(batch_texts):
-                phones, bert_features, norm_text = self.text_preprocessor.segment_and_extract_feature_for_text(text, text_lang)
-                if phones is None:
-                    continue
-                res={
-                    "phones": phones,
-                    "bert_features": bert_features,
-                    "norm_text": norm_text,
-                }
-                batch_data.append(res)
-            if len(batch_data) == 0:
-                return None
-            batch, _ = self.to_batch(batch_data,
-                        prompt_data=self.prompt_cache if not no_prompt_text else None,
-                        batch_size=batch_size,
-                        threshold=batch_threshold,
-                        split_bucket=False,
-                        device=self.configs.device,
-                        precision=self.precision
-                        )
-            return batch[0]
+        # Wrap everything directly in a single torch.no_grad()
 
         with torch.no_grad():
 
@@ -719,7 +696,31 @@ class TTS:
             if i%batch_size == 0:
                 data.append([])
             data[-1].append(texts[i])
-
+
+        def make_batch(batch_texts):
+            batch_data = []
+            print(i18n("############ 提取文本Bert特征 ############"))
+            for text in tqdm(batch_texts):
+                phones, bert_features, norm_text = self.text_preprocessor.segment_and_extract_feature_for_text(text, text_lang)
+                if phones is None:
+                    continue
+                res={
+                    "phones": phones,
+                    "bert_features": bert_features,
+                    "norm_text": norm_text,
+                }
+                batch_data.append(res)
+            if len(batch_data) == 0:
+                return None
+            batch, _ = self.to_batch(batch_data,
+                        prompt_data=self.prompt_cache if not no_prompt_text else None,
+                        batch_size=batch_size,
+                        threshold=batch_threshold,
+                        split_bucket=False,
+                        device=self.configs.device,
+                        precision=self.precision
+                        )
+            return batch[0]
 
         t2 = ttime()
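
For reviewers: moving make_batch below the text-grouping loop is behavior-preserving because it is a closure, so names like batch_size are resolved when it is called inside the no_grad block, not where the def appears. The sketch below illustrates both points under stand-in names (model, synthesize, and the feature stub are hypothetical, not this repo's API); the single torch.no_grad() keeps autograd from recording any forward pass, which trims memory during inference.

import torch
import torch.nn as nn

# Hypothetical stand-in for the real t2s/vits models.
model = nn.Linear(4, 2)

def synthesize(texts, batch_size=2):
    # Group texts into chunks of batch_size, mirroring the loop that now
    # precedes the restored make_batch definition.
    data = []
    for i, text in enumerate(texts):
        if i % batch_size == 0:
            data.append([])
        data[-1].append(text)

    # Defined after the grouping loop, exactly like the patch: as a
    # closure it looks up batch-related names at call time.
    def make_batch(batch_texts):
        # Stand-in for segment_and_extract_feature_for_text: encode each
        # text as a fixed-size float vector.
        feats = [torch.full((4,), float(len(t))) for t in batch_texts]
        if len(feats) == 0:
            return None
        return torch.stack(feats)

    # One torch.no_grad() around the whole pipeline: no autograd graph is
    # built for any forward pass inside it.
    with torch.no_grad():
        for group in data:
            batch = make_batch(group)
            if batch is None:
                continue
            out = model(batch)
            assert not out.requires_grad  # gradients really are disabled
            print(out.shape)

synthesize(["hello", "world", "again"])

Running it prints torch.Size([2, 2]) then torch.Size([1, 2]), and the assert confirms that no gradient state leaks out of the no_grad block.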