diff --git a/GPT_SoVITS/AR/models/t2s_model.py b/GPT_SoVITS/AR/models/t2s_model.py index 4725b7a3..2336dbb1 100644 --- a/GPT_SoVITS/AR/models/t2s_model.py +++ b/GPT_SoVITS/AR/models/t2s_model.py @@ -794,7 +794,7 @@ class Text2SemanticDecoder(nn.Module): y_list = [] idx_list = [] for i in range(len(x)): - y, idx = self.infer_panel_naive( + y, idx = next(self.infer_panel_naive( x[i].unsqueeze(0), x_lens[i], prompts[i].unsqueeze(0) if prompts is not None else None, @@ -805,7 +805,7 @@ class Text2SemanticDecoder(nn.Module): temperature, repetition_penalty, **kwargs, - ) + )) y_list.append(y[0]) idx_list.append(idx) @@ -822,6 +822,8 @@ class Text2SemanticDecoder(nn.Module): early_stop_num: int = -1, temperature: float = 1.0, repetition_penalty: float = 1.35, + streaming_mode: bool = False, + chunk_length: int = 24, **kwargs, ): x = self.ar_text_embedding(x) @@ -875,7 +877,9 @@ class Text2SemanticDecoder(nn.Module): .to(device=x.device, dtype=torch.bool) ) + token_counter = 0 for idx in tqdm(range(1500)): + token_counter+=1 if xy_attn_mask is not None: xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, xy_attn_mask, None) else: @@ -900,22 +904,42 @@ class Text2SemanticDecoder(nn.Module): if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS: stop = True + y=y[:, :-1] + token_counter -= 1 + + if idx == 1499: + stop = True + if stop: if y.shape[1] == 0: y = torch.concat([y, torch.zeros_like(samples)], dim=1) print("bad zero prediction") - print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]") + # print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]") + if streaming_mode: + # y=y[:, :-1] + # res_len = (y.shape[1] - prefix_len)%chunk_length + yield (y[:, -token_counter:]) if token_counter!= 0 else None, True break + if streaming_mode and token_counter == chunk_length: + token_counter = 0 + yield y[:, -chunk_length:], False + + ####################### update next step ################################### y_emb = self.ar_audio_embedding(y[:, -1:]) xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[ :, y_len + idx ].to(dtype=y_emb.dtype, device=y_emb.device) - if ref_free: - return y[:, :-1], 0 - return y[:, :-1], idx + + + if not streaming_mode: + if ref_free: + yield y, 0 + yield y, idx + + def infer_panel( self, @@ -930,6 +954,6 @@ class Text2SemanticDecoder(nn.Module): repetition_penalty: float = 1.35, **kwargs, ): - return self.infer_panel_naive( + return next(self.infer_panel_naive( x, x_lens, prompts, bert_feature, top_k, top_p, early_stop_num, temperature, repetition_penalty, **kwargs - ) + )) diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py index d20daee3..8fd0a084 100644 --- a/GPT_SoVITS/TTS_infer_pack/TTS.py +++ b/GPT_SoVITS/TTS_infer_pack/TTS.py @@ -258,6 +258,12 @@ class TTS_Config: v1_languages: list = ["auto", "en", "zh", "ja", "all_zh", "all_ja"] v2_languages: list = ["auto", "auto_yue", "en", "zh", "ja", "yue", "ko", "all_zh", "all_ja", "all_yue", "all_ko"] languages: list = v2_languages + mute_tokens: dict = { + "v1" : 486, + "v2" : 486, + "v3" : 486, + "v4" : 486, + } # "all_zh",#全部按中文识别 # "en",#全部按英文识别#######不变 # "all_ja",#全部按日文识别 @@ -956,7 +962,10 @@ class TTS: "parallel_infer": True, # bool. whether to use parallel inference. "repetition_penalty": 1.35 # float. repetition penalty for T2S model. "sample_steps": 32, # int. number of sampling steps for VITS model V3. - "super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3. + "super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3. + "streaming_mode": False, # bool. return audio chunk by chunk. + "overlap_length": 2, # int. overlap length of semantic tokens for streaming mode. + "chunk_length: 24, # int. chunk length of semantic tokens for streaming mode. (affects audio chunk size) } returns: Tuple[int, np.ndarray]: sampling rate and audio data. @@ -986,19 +995,40 @@ class TTS: repetition_penalty = inputs.get("repetition_penalty", 1.35) sample_steps = inputs.get("sample_steps", 32) super_sampling = inputs.get("super_sampling", False) + streaming_mode = inputs.get("streaming_mode", False) + overlap_length = inputs.get("overlap_length", 2) + chunk_length = inputs.get("chunk_length", 24) - if parallel_infer: + if parallel_infer and not streaming_mode: print(i18n("并行推理模式已开启")) self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_batch_infer + elif not parallel_infer and streaming_mode and not self.configs.use_vocoder: + print(i18n("流式推理模式已开启")) + self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_naive + elif streaming_mode and self.configs.use_vocoder: + print(i18n("SoVits V3/4模型不支持流式推理模式,已自动回退到分段返回模式")) + streaming_mode = False + return_fragment = True + if parallel_infer: + self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_batch_infer + else: + self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_naive_batched + elif parallel_infer and streaming_mode: + print(i18n("不支持同时开启并行推理和流式推理模式,已自动关闭并行推理模式")) + parallel_infer = False + self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_naive else: - print(i18n("并行推理模式已关闭")) + print(i18n("朴素推理模式已开启")) self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_naive_batched - if return_fragment: - print(i18n("分段返回模式已开启")) - if split_bucket: - split_bucket = False - print(i18n("分段返回模式不支持分桶处理,已自动关闭分桶处理")) + if return_fragment and streaming_mode: + print(i18n("流式推理模式不支持分段返回,已自动关闭分段返回")) + return_fragment = False + + if (return_fragment or streaming_mode) and split_bucket: + print(i18n("分段返回模式/流式推理模式不支持分桶处理,已自动关闭分桶处理")) + split_bucket = False + if split_bucket and speed_factor == 1.0 and not (self.configs.use_vocoder and parallel_infer): print(i18n("分桶处理模式已开启")) @@ -1011,9 +1041,9 @@ class TTS: else: print(i18n("分桶处理模式已关闭")) - if fragment_interval < 0.01: - fragment_interval = 0.01 - print(i18n("分段间隔过小,已自动设置为0.01")) + # if fragment_interval < 0.01: + # fragment_interval = 0.01 + # print(i18n("分段间隔过小,已自动设置为0.01")) no_prompt_text = False if prompt_text in [None, ""]: @@ -1071,7 +1101,7 @@ class TTS: ###### text preprocessing ######## t1 = time.perf_counter() data: list = None - if not return_fragment: + if not (return_fragment or streaming_mode): data = self.text_preprocessor.preprocess(text, text_lang, text_split_method, self.configs.version) if len(data) == 0: yield 16000, np.zeros(int(16000), dtype=np.int16) @@ -1134,7 +1164,7 @@ class TTS: output_sr = self.configs.sampling_rate if not self.configs.use_vocoder else self.vocoder_configs["sr"] for item in data: t3 = time.perf_counter() - if return_fragment: + if return_fragment or streaming_mode: item = make_batch(item) if item is None: continue @@ -1156,94 +1186,214 @@ class TTS: self.prompt_cache["prompt_semantic"].expand(len(all_phoneme_ids), -1).to(self.configs.device) ) - print(f"############ {i18n('预测语义Token')} ############") - pred_semantic_list, idx_list = self.t2s_model.model.infer_panel( - all_phoneme_ids, - all_phoneme_lens, - prompt, - all_bert_features, - # prompt_phone_len=ph_offset, - top_k=top_k, - top_p=top_p, - temperature=temperature, - early_stop_num=self.configs.hz * self.configs.max_sec, - max_len=max_len, - repetition_penalty=repetition_penalty, - ) - t4 = time.perf_counter() - t_34 += t4 - t3 + if not streaming_mode: + print(f"############ {i18n('预测语义Token')} ############") + pred_semantic_list, idx_list = self.t2s_model.model.infer_panel( + all_phoneme_ids, + all_phoneme_lens, + prompt, + all_bert_features, + # prompt_phone_len=ph_offset, + top_k=top_k, + top_p=top_p, + temperature=temperature, + early_stop_num=self.configs.hz * self.configs.max_sec, + max_len=max_len, + repetition_penalty=repetition_penalty, + ) + t4 = time.perf_counter() + t_34 += t4 - t3 - refer_audio_spec: torch.Tensor = [ - item.to(dtype=self.precision, device=self.configs.device) - for item in self.prompt_cache["refer_spec"] - ] + refer_audio_spec: torch.Tensor = [ + item.to(dtype=self.precision, device=self.configs.device) + for item in self.prompt_cache["refer_spec"] + ] - batch_audio_fragment = [] + batch_audio_fragment = [] - # ## vits并行推理 method 1 - # pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)] - # pred_semantic_len = torch.LongTensor([item.shape[0] for item in pred_semantic_list]).to(self.configs.device) - # pred_semantic = self.batch_sequences(pred_semantic_list, axis=0, pad_value=0).unsqueeze(0) - # max_len = 0 - # for i in range(0, len(batch_phones)): - # max_len = max(max_len, batch_phones[i].shape[-1]) - # batch_phones = self.batch_sequences(batch_phones, axis=0, pad_value=0, max_length=max_len) - # batch_phones = batch_phones.to(self.configs.device) - # batch_audio_fragment = (self.vits_model.batched_decode( - # pred_semantic, pred_semantic_len, batch_phones, batch_phones_len,refer_audio_spec - # )) - print(f"############ {i18n('合成音频')} ############") - if not self.configs.use_vocoder: - if speed_factor == 1.0: - print(f"{i18n('并行合成中')}...") - # ## vits并行推理 method 2 - pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)] - upsample_rate = math.prod(self.vits_model.upsample_rates) - audio_frag_idx = [ - pred_semantic_list[i].shape[0] * 2 * upsample_rate - for i in range(0, len(pred_semantic_list)) - ] - audio_frag_end_idx = [sum(audio_frag_idx[: i + 1]) for i in range(0, len(audio_frag_idx))] - all_pred_semantic = ( - torch.cat(pred_semantic_list).unsqueeze(0).unsqueeze(0).to(self.configs.device) - ) - _batch_phones = torch.cat(batch_phones).unsqueeze(0).to(self.configs.device) - _batch_audio_fragment = self.vits_model.decode( - all_pred_semantic, _batch_phones, refer_audio_spec, speed=speed_factor - ).detach()[0, 0, :] - audio_frag_end_idx.insert(0, 0) - batch_audio_fragment = [ - _batch_audio_fragment[audio_frag_end_idx[i - 1] : audio_frag_end_idx[i]] - for i in range(1, len(audio_frag_end_idx)) - ] - else: - # ## vits串行推理 - for i, idx in enumerate(tqdm(idx_list)): - phones = batch_phones[i].unsqueeze(0).to(self.configs.device) - _pred_semantic = ( - pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0) - ) # .unsqueeze(0)#mq要多unsqueeze一次 - audio_fragment = self.vits_model.decode( - _pred_semantic, phones, refer_audio_spec, speed=speed_factor - ).detach()[0, 0, :] - batch_audio_fragment.append(audio_fragment) ###试试重建不带上prompt部分 - else: - if parallel_infer: - print(f"{i18n('并行合成中')}...") - audio_fragments = self.using_vocoder_synthesis_batched_infer( - idx_list, pred_semantic_list, batch_phones, speed=speed_factor, sample_steps=sample_steps - ) - batch_audio_fragment.extend(audio_fragments) - else: - for i, idx in enumerate(tqdm(idx_list)): - phones = batch_phones[i].unsqueeze(0).to(self.configs.device) - _pred_semantic = ( - pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0) - ) # .unsqueeze(0)#mq要多unsqueeze一次 - audio_fragment = self.using_vocoder_synthesis( - _pred_semantic, phones, speed=speed_factor, sample_steps=sample_steps + # ## vits并行推理 method 1 + # pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)] + # pred_semantic_len = torch.LongTensor([item.shape[0] for item in pred_semantic_list]).to(self.configs.device) + # pred_semantic = self.batch_sequences(pred_semantic_list, axis=0, pad_value=0).unsqueeze(0) + # max_len = 0 + # for i in range(0, len(batch_phones)): + # max_len = max(max_len, batch_phones[i].shape[-1]) + # batch_phones = self.batch_sequences(batch_phones, axis=0, pad_value=0, max_length=max_len) + # batch_phones = batch_phones.to(self.configs.device) + # batch_audio_fragment = (self.vits_model.batched_decode( + # pred_semantic, pred_semantic_len, batch_phones, batch_phones_len,refer_audio_spec + # )) + print(f"############ {i18n('合成音频')} ############") + if not self.configs.use_vocoder: + if speed_factor == 1.0: + print(f"{i18n('并行合成中')}...") + # ## vits并行推理 method 2 + pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)] + upsample_rate = math.prod(self.vits_model.upsample_rates) + audio_frag_idx = [ + pred_semantic_list[i].shape[0] * 2 * upsample_rate + for i in range(0, len(pred_semantic_list)) + ] + audio_frag_end_idx = [sum(audio_frag_idx[: i + 1]) for i in range(0, len(audio_frag_idx))] + all_pred_semantic = ( + torch.cat(pred_semantic_list).unsqueeze(0).unsqueeze(0).to(self.configs.device) ) - batch_audio_fragment.append(audio_fragment) + _batch_phones = torch.cat(batch_phones).unsqueeze(0).to(self.configs.device) + _batch_audio_fragment = self.vits_model.decode( + all_pred_semantic, _batch_phones, refer_audio_spec, speed=speed_factor + ).detach()[0, 0, :] + audio_frag_end_idx.insert(0, 0) + batch_audio_fragment = [ + _batch_audio_fragment[audio_frag_end_idx[i - 1] : audio_frag_end_idx[i]] + for i in range(1, len(audio_frag_end_idx)) + ] + else: + # ## vits串行推理 + for i, idx in enumerate(tqdm(idx_list)): + phones = batch_phones[i].unsqueeze(0).to(self.configs.device) + _pred_semantic = ( + pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0) + ) # .unsqueeze(0)#mq要多unsqueeze一次 + audio_fragment = self.vits_model.decode( + _pred_semantic, phones, refer_audio_spec, speed=speed_factor + ).detach()[0, 0, :] + batch_audio_fragment.append(audio_fragment) ###试试重建不带上prompt部分 + else: + if parallel_infer: + print(f"{i18n('并行合成中')}...") + audio_fragments = self.using_vocoder_synthesis_batched_infer( + idx_list, pred_semantic_list, batch_phones, speed=speed_factor, sample_steps=sample_steps + ) + batch_audio_fragment.extend(audio_fragments) + else: + for i, idx in enumerate(tqdm(idx_list)): + phones = batch_phones[i].unsqueeze(0).to(self.configs.device) + _pred_semantic = ( + pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0) + ) # .unsqueeze(0)#mq要多unsqueeze一次 + audio_fragment = self.using_vocoder_synthesis( + _pred_semantic, phones, speed=speed_factor, sample_steps=sample_steps + ) + batch_audio_fragment.append(audio_fragment) + + else: + refer_audio_spec: torch.Tensor = [ + item.to(dtype=self.precision, device=self.configs.device) + for item in self.prompt_cache["refer_spec"] + ] + semantic_token_generator =self.t2s_model.model.infer_panel( + all_phoneme_ids[0].unsqueeze(0), + all_phoneme_lens, + prompt, + all_bert_features[0].unsqueeze(0), + # prompt_phone_len=ph_offset, + top_k=top_k, + top_p=top_p, + temperature=temperature, + early_stop_num=self.configs.hz * self.configs.max_sec, + max_len=max_len, + repetition_penalty=repetition_penalty, + streaming_mode=True, + chunk_length=chunk_length, + ) + t4 = time.perf_counter() + t_34 += t4 - t3 + phones = batch_phones[0].unsqueeze(0).to(self.configs.device) + is_first_chunk = True + + if not self.configs.use_vocoder: + if speed_factor == 1.0: + upsample_rate = math.prod(self.vits_model.upsample_rates)*(2 if self.vits_model.semantic_frame_rate == "25hz" else 1) + else: + upsample_rate = math.prod(self.vits_model.upsample_rates)*((2 if self.vits_model.semantic_frame_rate == "25hz" else 1)/speed_factor) + else: + if speed_factor == 1.0: + upsample_rate = self.vocoder_configs["upsample_rate"]*(3.875 if self.configs.version == "v3" else 4) + else: + upsample_rate = self.vocoder_configs["upsample_rate"]*((3.875 if self.configs.version == "v3" else 4)/speed_factor) + + last_audio_chunk = None + last_tokens = None + previous_tokens = [] + overlap_size = math.ceil(overlap_length*upsample_rate) + for semantic_tokens, is_final in semantic_token_generator: + if semantic_tokens is None and last_audio_chunk is not None: + yield self.audio_postprocess( + [[last_audio_chunk[-overlap_size:]]], + output_sr, + None, + speed_factor, + False, + 0.0, + super_sampling if self.configs.use_vocoder and self.configs.version == "v3" else False, + ) + continue + + _semantic_tokens = semantic_tokens + # if is_first_chunk: + # _semantic_tokens = torch.cat([torch.ones((1,overlap_length), dtype=torch.long, device=self.configs.device)*self.configs.mute_tokens[self.configs.version], _semantic_tokens], dim=-1) + # else: + # _semantic_tokens = torch.cat([last_tokens[:, -overlap_length:], _semantic_tokens], dim=-1) + # # _semantic_tokens = torch.cat(previous_tokens+[_semantic_tokens,], dim=-1) + + previous_tokens.append(semantic_tokens) + + _semantic_tokens = torch.cat(previous_tokens, dim=-1) + + + # last_tokens = semantic_tokens + + # print(f"_semantic_tokens shape:{_semantic_tokens.shape}") + + + if not self.configs.use_vocoder: + audio_chunk = self.vits_model.decode( + _semantic_tokens.unsqueeze(0), + phones, refer_audio_spec, + speed=speed_factor, + result_length=semantic_tokens.shape[-1]+overlap_length if not is_first_chunk else None + # result_length=chunk_length if not is_first_chunk else None + ).detach()[0, 0, :] + else: + audio_chunk = self.using_vocoder_synthesis( + _semantic_tokens.unsqueeze(0), phones, + speed=speed_factor, sample_steps=sample_steps, + result_length = semantic_tokens.shape[-1]+overlap_length if not is_first_chunk else None + ) + + + + # if is_first_chunk: + # audio_chunk = audio_chunk[overlap_size:] + # # is_first_chunk = False + + audio_chunk_ = audio_chunk + if is_first_chunk and not is_final: + is_first_chunk = False + audio_chunk_ = audio_chunk_[:-overlap_size] + elif is_first_chunk and is_final: + is_first_chunk = False + elif not is_first_chunk and not is_final: + audio_chunk_ = self.sola_algorithm([last_audio_chunk, audio_chunk_], overlap_size) + audio_chunk_ = ( + audio_chunk_[last_audio_chunk.shape[0]-overlap_size:-overlap_size] if not is_final \ + else audio_chunk_[last_audio_chunk.shape[0]-overlap_size:] + ) + # audio_chunk_ = audio_chunk_[:-overlap_size] if not is_final else audio_chunk_ + + last_audio_chunk = audio_chunk + yield self.audio_postprocess( + [[audio_chunk_]], + output_sr, + None, + speed_factor, + False, + 0.0, + super_sampling if self.configs.use_vocoder and self.configs.version == "v3" else False, + ) + print(f"first_package_delay: {time.perf_counter()-t0:.3f}") + + yield output_sr, np.zeros(int(output_sr*fragment_interval), dtype=np.int16) t5 = time.perf_counter() t_45 += t5 - t4 @@ -1258,17 +1408,18 @@ class TTS: fragment_interval, super_sampling if self.configs.use_vocoder and self.configs.version == "v3" else False, ) + elif streaming_mode:... else: audio.append(batch_audio_fragment) if self.stop_flag: - yield 16000, np.zeros(int(16000), dtype=np.int16) + yield output_sr, np.zeros(int(output_sr), dtype=np.int16) return - if not return_fragment: + if not (return_fragment or streaming_mode): print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t_34, t_45)) if len(audio) == 0: - yield 16000, np.zeros(int(16000), dtype=np.int16) + yield output_sr, np.zeros(int(output_sr), dtype=np.int16) return yield self.audio_postprocess( audio, @@ -1315,16 +1466,17 @@ class TTS: fragment_interval: float = 0.3, super_sampling: bool = False, ) -> Tuple[int, np.ndarray]: - zero_wav = torch.zeros( - int(self.configs.sampling_rate * fragment_interval), dtype=self.precision, device=self.configs.device - ) + if fragment_interval>0: + zero_wav = torch.zeros( + int(self.configs.sampling_rate * fragment_interval), dtype=self.precision, device=self.configs.device + ) for i, batch in enumerate(audio): for j, audio_fragment in enumerate(batch): max_audio = torch.abs(audio_fragment).max() # 简单防止16bit爆音 if max_audio > 1: audio_fragment /= max_audio - audio_fragment: torch.Tensor = torch.cat([audio_fragment, zero_wav], dim=0) + audio_fragment: torch.Tensor = torch.cat([audio_fragment, zero_wav], dim=0) if fragment_interval>0 else audio_fragment audio[i][j] = audio_fragment if split_bucket: @@ -1344,12 +1496,12 @@ class TTS: max_audio = np.abs(audio).max() if max_audio > 1: audio /= max_audio + audio = (audio * 32768).astype(np.int16) t2 = time.perf_counter() print(f"超采样用时:{t2 - t1:.3f}s") else: - audio = audio.cpu().numpy() - - audio = (audio * 32768).astype(np.int16) + audio = audio.float() * 32768 + audio = audio.to(dtype=torch.int16).cpu().numpy() # try: # if speed_factor != 1.0: @@ -1360,7 +1512,7 @@ class TTS: return sr, audio def using_vocoder_synthesis( - self, semantic_tokens: torch.Tensor, phones: torch.Tensor, speed: float = 1.0, sample_steps: int = 32 + self, semantic_tokens: torch.Tensor, phones: torch.Tensor, speed: float = 1.0, sample_steps: int = 32, result_length:int=None ): prompt_semantic_tokens = self.prompt_cache["prompt_semantic"].unsqueeze(0).unsqueeze(0).to(self.configs.device) prompt_phones = torch.LongTensor(self.prompt_cache["phones"]).unsqueeze(0).to(self.configs.device) @@ -1392,7 +1544,7 @@ class TTS: chunk_len = T_chunk - T_min mel2 = mel2.to(self.precision) - fea_todo, ge = self.vits_model.decode_encp(semantic_tokens, phones, refer_audio_spec, ge, speed) + fea_todo, ge = self.vits_model.decode_encp(semantic_tokens, phones, refer_audio_spec, ge, speed, result_length=result_length) cfm_resss = [] idx = 0 @@ -1415,7 +1567,7 @@ class TTS: cfm_res = torch.cat(cfm_resss, 2) cfm_res = denorm_spec(cfm_res) - with torch.inference_mode(): + with torch.no_grad(): wav_gen = self.vocoder(cfm_res) audio = wav_gen[0][0] # .cpu().detach().numpy() diff --git a/GPT_SoVITS/module/models.py b/GPT_SoVITS/module/models.py index 21f60d99..4f170871 100644 --- a/GPT_SoVITS/module/models.py +++ b/GPT_SoVITS/module/models.py @@ -209,7 +209,7 @@ class TextEncoder(nn.Module): self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) - def forward(self, y, y_lengths, text, text_lengths, ge, speed=1, test=None): + def forward(self, y, y_lengths, text, text_lengths, ge, speed=1, test=None, result_length:int=None): y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype) y = self.ssl_proj(y * y_mask) * y_mask @@ -223,6 +223,11 @@ class TextEncoder(nn.Module): text = self.encoder_text(text * text_mask, text_mask) y = self.mrte(y, y_mask, text, text_mask, ge) y = self.encoder2(y * y_mask, y_mask) + + if result_length is not None: + y = y[:, :, -result_length:] + y_mask = y_mask[:, :, -result_length:] + if speed != 1: y = F.interpolate(y, size=int(y.shape[-1] / speed) + 1, mode="linear") y_mask = F.interpolate(y_mask, size=y.shape[-1], mode="nearest") @@ -940,7 +945,7 @@ class SynthesizerTrn(nn.Module): return o, y_mask, (z, z_p, m_p, logs_p) @torch.no_grad() - def decode(self, codes, text, refer, noise_scale=0.5, speed=1): + def decode(self, codes, text, refer, noise_scale=0.5, speed=1, result_length:int=None): def get_ge(refer): ge = None if refer is not None: @@ -967,7 +972,8 @@ class SynthesizerTrn(nn.Module): quantized = self.quantizer.decode(codes) if self.semantic_frame_rate == "25hz": quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest") - x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge, speed) + result_length = (2*result_length) if result_length is not None else None + x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge, speed, result_length=result_length) z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale z = self.flow(z_p, y_mask, g=ge, reverse=True) @@ -1187,7 +1193,7 @@ class SynthesizerTrnV3(nn.Module): return cfm_loss @torch.no_grad() - def decode_encp(self, codes, text, refer, ge=None, speed=1): + def decode_encp(self, codes, text, refer, ge=None, speed=1, result_length:int=None): # print(2333333,refer.shape) # ge=None if ge == None: @@ -1195,17 +1201,21 @@ class SynthesizerTrnV3(nn.Module): refer_mask = torch.unsqueeze(commons.sequence_mask(refer_lengths, refer.size(2)), 1).to(refer.dtype) ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask) y_lengths = torch.LongTensor([int(codes.size(2) * 2)]).to(codes.device) + if speed == 1: - sizee = int(codes.size(2) * (3.875 if self.version=="v3"else 4)) + sizee = int((codes.size(2) if result_length is None else result_length) * (3.875 if self.version=="v3"else 4)) else: - sizee = int(codes.size(2) * (3.875 if self.version=="v3"else 4) / speed) + 1 + sizee = int((codes.size(2) if result_length is None else result_length) * (3.875 if self.version=="v3"else 4) / speed) + 1 y_lengths1 = torch.LongTensor([sizee]).to(codes.device) text_lengths = torch.LongTensor([text.size(-1)]).to(text.device) quantized = self.quantizer.decode(codes) if self.semantic_frame_rate == "25hz": quantized = F.interpolate(quantized, scale_factor=2, mode="nearest") ##BCT - x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge, speed) + result_length = result_length * 2 if result_length is not None else None + + + x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge, speed, result_length=result_length) fea = self.bridge(x) fea = F.interpolate(fea, scale_factor=(1.875 if self.version=="v3"else 2), mode="nearest") ##BCT ####more wn paramter to learn mel diff --git a/api_v2.py b/api_v2.py index 87082074..d623693e 100644 --- a/api_v2.py +++ b/api_v2.py @@ -40,7 +40,10 @@ POST: "parallel_infer": True, # bool. whether to use parallel inference. "repetition_penalty": 1.35 # float. repetition penalty for T2S model. "sample_steps": 32, # int. number of sampling steps for VITS model V3. - "super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3. + "super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3. + "overlap_length": 2, # int. overlap length of semantic tokens for streaming mode. + "chunk_length: 24, # int. chunk length of semantic tokens for streaming mode. (affects audio chunk size) + "return_fragment": False, # bool. step by step return the audio fragment. (old version of streaming mode) } ``` @@ -170,6 +173,9 @@ class TTS_Request(BaseModel): repetition_penalty: float = 1.35 sample_steps: int = 32 super_sampling: bool = False + overlap_length: int = 2 + chunk_length: int = 24 + return_fragment: bool = False ### modify from https://github.com/RVC-Boss/GPT-SoVITS/pull/894/files @@ -325,7 +331,10 @@ async def tts_handle(req: dict): "parallel_infer": True, # bool.(optional) whether to use parallel inference. "repetition_penalty": 1.35 # float.(optional) repetition penalty for T2S model. "sample_steps": 32, # int. number of sampling steps for VITS model V3. - "super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3. + "super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3. + "overlap_length": 2, # int. overlap length of semantic tokens for streaming mode. + "chunk_length: 24, # int. chunk length of semantic tokens for streaming mode. (affects audio chunk size) + "return_fragment": False, # bool. step by step return the audio fragment. (old version of streaming mode) } returns: StreamingResponse: audio stream response. @@ -339,8 +348,10 @@ async def tts_handle(req: dict): if check_res is not None: return check_res - if streaming_mode or return_fragment: - req["return_fragment"] = True + req["streaming_mode"] = streaming_mode + req["return_fragment"] = return_fragment + streaming_mode = streaming_mode or return_fragment + try: tts_generator = tts_pipeline.run(req) @@ -404,6 +415,9 @@ async def tts_get_endpoint( repetition_penalty: float = 1.35, sample_steps: int = 32, super_sampling: bool = False, + overlap_length: int = 2, + chunk_length: int = 24, + return_fragment: bool = False, ): req = { "text": text, @@ -428,6 +442,9 @@ async def tts_get_endpoint( "repetition_penalty": float(repetition_penalty), "sample_steps": int(sample_steps), "super_sampling": super_sampling, + "overlap_length": int(overlap_length), + "chunk_length": int(chunk_length), + "return_fragment": return_fragment, } return await tts_handle(req)