From 8e161c46fabdf4c291036eaf2fb98ceb201e2187 Mon Sep 17 00:00:00 2001 From: ChasonJiang <1440499136@qq.com> Date: Mon, 5 May 2025 20:05:29 +0800 Subject: [PATCH 01/10] =?UTF-8?q?=E6=9B=B4=E5=A5=BD=E7=9A=84=E6=B5=81?= =?UTF-8?q?=E5=BC=8F=E6=8E=A8=E7=90=86=E6=A8=A1=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- GPT_SoVITS/AR/models/t2s_model.py | 40 +++- GPT_SoVITS/TTS_infer_pack/TTS.py | 384 +++++++++++++++++++++--------- GPT_SoVITS/module/models.py | 24 +- api_v2.py | 25 +- 4 files changed, 338 insertions(+), 135 deletions(-) diff --git a/GPT_SoVITS/AR/models/t2s_model.py b/GPT_SoVITS/AR/models/t2s_model.py index 7196d6ab..4e672503 100644 --- a/GPT_SoVITS/AR/models/t2s_model.py +++ b/GPT_SoVITS/AR/models/t2s_model.py @@ -794,7 +794,7 @@ class Text2SemanticDecoder(nn.Module): y_list = [] idx_list = [] for i in range(len(x)): - y, idx = self.infer_panel_naive( + y, idx = next(self.infer_panel_naive( x[i].unsqueeze(0), x_lens[i], prompts[i].unsqueeze(0) if prompts is not None else None, @@ -805,7 +805,7 @@ class Text2SemanticDecoder(nn.Module): temperature, repetition_penalty, **kwargs, - ) + )) y_list.append(y[0]) idx_list.append(idx) @@ -822,6 +822,8 @@ class Text2SemanticDecoder(nn.Module): early_stop_num: int = -1, temperature: float = 1.0, repetition_penalty: float = 1.35, + streaming_mode: bool = False, + chunk_length: int = 24, **kwargs, ): x = self.ar_text_embedding(x) @@ -875,7 +877,9 @@ class Text2SemanticDecoder(nn.Module): .to(device=x.device, dtype=torch.bool) ) + token_counter = 0 for idx in tqdm(range(1500)): + token_counter+=1 if xy_attn_mask is not None: xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, xy_attn_mask, None) else: @@ -900,22 +904,42 @@ class Text2SemanticDecoder(nn.Module): if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS: stop = True + y=y[:, :-1] + token_counter -= 1 + + if idx == 1499: + stop = True + if stop: if y.shape[1] == 
0: y = torch.concat([y, torch.zeros_like(samples)], dim=1) print("bad zero prediction") - print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]") + # print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]") + if streaming_mode: + # y=y[:, :-1] + # res_len = (y.shape[1] - prefix_len)%chunk_length + yield (y[:, -token_counter:]) if token_counter!= 0 else None, True break + if streaming_mode and token_counter == chunk_length: + token_counter = 0 + yield y[:, -chunk_length:], False + + ####################### update next step ################################### y_emb = self.ar_audio_embedding(y[:, -1:]) xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[ :, y_len + idx ].to(dtype=y_emb.dtype, device=y_emb.device) - if ref_free: - return y[:, :-1], 0 - return y[:, :-1], idx + + + if not streaming_mode: + if ref_free: + yield y, 0 + yield y, idx + + def infer_panel( self, @@ -930,6 +954,6 @@ class Text2SemanticDecoder(nn.Module): repetition_penalty: float = 1.35, **kwargs, ): - return self.infer_panel_naive( + return next(self.infer_panel_naive( x, x_lens, prompts, bert_feature, top_k, top_p, early_stop_num, temperature, repetition_penalty, **kwargs - ) + )) diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py index 0c1d2484..32a9c4dd 100644 --- a/GPT_SoVITS/TTS_infer_pack/TTS.py +++ b/GPT_SoVITS/TTS_infer_pack/TTS.py @@ -275,6 +275,12 @@ class TTS_Config: v1_languages: list = ["auto", "en", "zh", "ja", "all_zh", "all_ja"] v2_languages: list = ["auto", "auto_yue", "en", "zh", "ja", "yue", "ko", "all_zh", "all_ja", "all_yue", "all_ko"] languages: list = v2_languages + mute_tokens: dict = { + "v1" : 486, + "v2" : 486, + "v3" : 486, + "v4" : 486, + } # "all_zh",#全部按中文识别 # "en",#全部按英文识别#######不变 # "all_ja",#全部按日文识别 @@ -1008,7 +1014,10 @@ class TTS: "parallel_infer": True, # bool. whether to use parallel inference. "repetition_penalty": 1.35 # float. repetition penalty for T2S model. 
"sample_steps": 32, # int. number of sampling steps for VITS model V3. - "super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3. + "super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3. + "streaming_mode": False, # bool. return audio chunk by chunk. + "overlap_length": 2, # int. overlap length of semantic tokens for streaming mode. + "chunk_length: 24, # int. chunk length of semantic tokens for streaming mode. (affects audio chunk size) } returns: Tuple[int, np.ndarray]: sampling rate and audio data. @@ -1038,19 +1047,40 @@ class TTS: repetition_penalty = inputs.get("repetition_penalty", 1.35) sample_steps = inputs.get("sample_steps", 32) super_sampling = inputs.get("super_sampling", False) + streaming_mode = inputs.get("streaming_mode", False) + overlap_length = inputs.get("overlap_length", 2) + chunk_length = inputs.get("chunk_length", 24) - if parallel_infer: + if parallel_infer and not streaming_mode: print(i18n("并行推理模式已开启")) self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_batch_infer + elif not parallel_infer and streaming_mode and not self.configs.use_vocoder: + print(i18n("流式推理模式已开启")) + self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_naive + elif streaming_mode and self.configs.use_vocoder: + print(i18n("SoVits V3/4模型不支持流式推理模式,已自动回退到分段返回模式")) + streaming_mode = False + return_fragment = True + if parallel_infer: + self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_batch_infer + else: + self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_naive_batched + elif parallel_infer and streaming_mode: + print(i18n("不支持同时开启并行推理和流式推理模式,已自动关闭并行推理模式")) + parallel_infer = False + self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_naive else: - print(i18n("并行推理模式已关闭")) + print(i18n("朴素推理模式已开启")) self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_naive_batched - if return_fragment: - 
print(i18n("分段返回模式已开启")) - if split_bucket: - split_bucket = False - print(i18n("分段返回模式不支持分桶处理,已自动关闭分桶处理")) + if return_fragment and streaming_mode: + print(i18n("流式推理模式不支持分段返回,已自动关闭分段返回")) + return_fragment = False + + if (return_fragment or streaming_mode) and split_bucket: + print(i18n("分段返回模式/流式推理模式不支持分桶处理,已自动关闭分桶处理")) + split_bucket = False + if split_bucket and speed_factor == 1.0 and not (self.configs.use_vocoder and parallel_infer): print(i18n("分桶处理模式已开启")) @@ -1063,9 +1093,9 @@ class TTS: else: print(i18n("分桶处理模式已关闭")) - if fragment_interval < 0.01: - fragment_interval = 0.01 - print(i18n("分段间隔过小,已自动设置为0.01")) + # if fragment_interval < 0.01: + # fragment_interval = 0.01 + # print(i18n("分段间隔过小,已自动设置为0.01")) no_prompt_text = False if prompt_text in [None, ""]: @@ -1126,7 +1156,7 @@ class TTS: ###### text preprocessing ######## t1 = time.perf_counter() data: list = None - if not return_fragment: + if not (return_fragment or streaming_mode): data = self.text_preprocessor.preprocess(text, text_lang, text_split_method, self.configs.version) if len(data) == 0: yield 16000, np.zeros(int(16000), dtype=np.int16) @@ -1189,7 +1219,7 @@ class TTS: output_sr = self.configs.sampling_rate if not self.configs.use_vocoder else self.vocoder_configs["sr"] for item in data: t3 = time.perf_counter() - if return_fragment: + if return_fragment or streaming_mode: item = make_batch(item) if item is None: continue @@ -1211,22 +1241,23 @@ class TTS: self.prompt_cache["prompt_semantic"].expand(len(all_phoneme_ids), -1).to(self.configs.device) ) - print(f"############ {i18n('预测语义Token')} ############") - pred_semantic_list, idx_list = self.t2s_model.model.infer_panel( - all_phoneme_ids, - all_phoneme_lens, - prompt, - all_bert_features, - # prompt_phone_len=ph_offset, - top_k=top_k, - top_p=top_p, - temperature=temperature, - early_stop_num=self.configs.hz * self.configs.max_sec, - max_len=max_len, - repetition_penalty=repetition_penalty, - ) - t4 = time.perf_counter() - t_34 += t4 - 
t3 + if not streaming_mode: + print(f"############ {i18n('预测语义Token')} ############") + pred_semantic_list, idx_list = self.t2s_model.model.infer_panel( + all_phoneme_ids, + all_phoneme_lens, + prompt, + all_bert_features, + # prompt_phone_len=ph_offset, + top_k=top_k, + top_p=top_p, + temperature=temperature, + early_stop_num=self.configs.hz * self.configs.max_sec, + max_len=max_len, + repetition_penalty=repetition_penalty, + ) + t4 = time.perf_counter() + t_34 += t4 - t3 refer_audio_spec = [] if self.is_v2pro: @@ -1237,82 +1268,201 @@ class TTS: if self.is_v2pro: sv_emb.append(self.sv_model.compute_embedding3(audio_tensor)) - batch_audio_fragment = [] + batch_audio_fragment = [] - # ## vits并行推理 method 1 - # pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)] - # pred_semantic_len = torch.LongTensor([item.shape[0] for item in pred_semantic_list]).to(self.configs.device) - # pred_semantic = self.batch_sequences(pred_semantic_list, axis=0, pad_value=0).unsqueeze(0) - # max_len = 0 - # for i in range(0, len(batch_phones)): - # max_len = max(max_len, batch_phones[i].shape[-1]) - # batch_phones = self.batch_sequences(batch_phones, axis=0, pad_value=0, max_length=max_len) - # batch_phones = batch_phones.to(self.configs.device) - # batch_audio_fragment = (self.vits_model.batched_decode( - # pred_semantic, pred_semantic_len, batch_phones, batch_phones_len,refer_audio_spec - # )) - print(f"############ {i18n('合成音频')} ############") - if not self.configs.use_vocoder: - if speed_factor == 1.0: - print(f"{i18n('并行合成中')}...") - # ## vits并行推理 method 2 - pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)] - upsample_rate = math.prod(self.vits_model.upsample_rates) - audio_frag_idx = [ - pred_semantic_list[i].shape[0] * 2 * upsample_rate - for i in range(0, len(pred_semantic_list)) - ] - audio_frag_end_idx = [sum(audio_frag_idx[: i + 1]) for i in range(0, len(audio_frag_idx))] - all_pred_semantic = ( - 
torch.cat(pred_semantic_list).unsqueeze(0).unsqueeze(0).to(self.configs.device) - ) - _batch_phones = torch.cat(batch_phones).unsqueeze(0).to(self.configs.device) - if self.is_v2pro != True: - _batch_audio_fragment = self.vits_model.decode( - all_pred_semantic, _batch_phones, refer_audio_spec, speed=speed_factor - ).detach()[0, 0, :] - else: - _batch_audio_fragment = self.vits_model.decode( - all_pred_semantic, _batch_phones, refer_audio_spec, speed=speed_factor, sv_emb=sv_emb - ).detach()[0, 0, :] - audio_frag_end_idx.insert(0, 0) - batch_audio_fragment = [ - _batch_audio_fragment[audio_frag_end_idx[i - 1] : audio_frag_end_idx[i]] - for i in range(1, len(audio_frag_end_idx)) - ] - else: - # ## vits串行推理 - for i, idx in enumerate(tqdm(idx_list)): - phones = batch_phones[i].unsqueeze(0).to(self.configs.device) - _pred_semantic = ( - pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0) - ) # .unsqueeze(0)#mq要多unsqueeze一次 - if self.is_v2pro != True: - audio_fragment = self.vits_model.decode( - _pred_semantic, phones, refer_audio_spec, speed=speed_factor - ).detach()[0, 0, :] - else: - audio_fragment = self.vits_model.decode( - _pred_semantic, phones, refer_audio_spec, speed=speed_factor, sv_emb=sv_emb - ).detach()[0, 0, :] - batch_audio_fragment.append(audio_fragment) ###试试重建不带上prompt部分 - else: - if parallel_infer: - print(f"{i18n('并行合成中')}...") - audio_fragments = self.using_vocoder_synthesis_batched_infer( - idx_list, pred_semantic_list, batch_phones, speed=speed_factor, sample_steps=sample_steps - ) - batch_audio_fragment.extend(audio_fragments) - else: - for i, idx in enumerate(tqdm(idx_list)): - phones = batch_phones[i].unsqueeze(0).to(self.configs.device) - _pred_semantic = ( - pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0) - ) # .unsqueeze(0)#mq要多unsqueeze一次 - audio_fragment = self.using_vocoder_synthesis( - _pred_semantic, phones, speed=speed_factor, sample_steps=sample_steps + # ## vits并行推理 method 1 + # pred_semantic_list = [item[-idx:] for item, idx 
in zip(pred_semantic_list, idx_list)] + # pred_semantic_len = torch.LongTensor([item.shape[0] for item in pred_semantic_list]).to(self.configs.device) + # pred_semantic = self.batch_sequences(pred_semantic_list, axis=0, pad_value=0).unsqueeze(0) + # max_len = 0 + # for i in range(0, len(batch_phones)): + # max_len = max(max_len, batch_phones[i].shape[-1]) + # batch_phones = self.batch_sequences(batch_phones, axis=0, pad_value=0, max_length=max_len) + # batch_phones = batch_phones.to(self.configs.device) + # batch_audio_fragment = (self.vits_model.batched_decode( + # pred_semantic, pred_semantic_len, batch_phones, batch_phones_len,refer_audio_spec + # )) + print(f"############ {i18n('合成音频')} ############") + if not self.configs.use_vocoder: + if speed_factor == 1.0: + print(f"{i18n('并行合成中')}...") + # ## vits并行推理 method 2 + pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)] + upsample_rate = math.prod(self.vits_model.upsample_rates) + audio_frag_idx = [ + pred_semantic_list[i].shape[0] * 2 * upsample_rate + for i in range(0, len(pred_semantic_list)) + ] + audio_frag_end_idx = [sum(audio_frag_idx[: i + 1]) for i in range(0, len(audio_frag_idx))] + all_pred_semantic = ( + torch.cat(pred_semantic_list).unsqueeze(0).unsqueeze(0).to(self.configs.device) ) - batch_audio_fragment.append(audio_fragment) + _batch_phones = torch.cat(batch_phones).unsqueeze(0).to(self.configs.device) + if self.is_v2pro != True: + _batch_audio_fragment = self.vits_model.decode( + all_pred_semantic, _batch_phones, refer_audio_spec, speed=speed_factor + ).detach()[0, 0, :] + else: + _batch_audio_fragment = self.vits_model.decode( + all_pred_semantic, _batch_phones, refer_audio_spec, speed=speed_factor, sv_emb=sv_emb + ).detach()[0, 0, :] + audio_frag_end_idx.insert(0, 0) + batch_audio_fragment = [ + _batch_audio_fragment[audio_frag_end_idx[i - 1] : audio_frag_end_idx[i]] + for i in range(1, len(audio_frag_end_idx)) + ] + else: + # ## vits串行推理 + for i, idx in 
enumerate(tqdm(idx_list)): + phones = batch_phones[i].unsqueeze(0).to(self.configs.device) + _pred_semantic = ( + pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0) + ) # .unsqueeze(0)#mq要多unsqueeze一次 + if self.is_v2pro != True: + audio_fragment = self.vits_model.decode( + _pred_semantic, phones, refer_audio_spec, speed=speed_factor + ).detach()[0, 0, :] + else: + audio_fragment = self.vits_model.decode( + _pred_semantic, phones, refer_audio_spec, speed=speed_factor, sv_emb=sv_emb + ).detach()[0, 0, :] + batch_audio_fragment.append(audio_fragment) ###试试重建不带上prompt部分 + else: + if parallel_infer: + print(f"{i18n('并行合成中')}...") + audio_fragments = self.using_vocoder_synthesis_batched_infer( + idx_list, pred_semantic_list, batch_phones, speed=speed_factor, sample_steps=sample_steps + ) + batch_audio_fragment.extend(audio_fragments) + else: + for i, idx in enumerate(tqdm(idx_list)): + phones = batch_phones[i].unsqueeze(0).to(self.configs.device) + _pred_semantic = ( + pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0) + ) # .unsqueeze(0)#mq要多unsqueeze一次 + audio_fragment = self.using_vocoder_synthesis( + _pred_semantic, phones, speed=speed_factor, sample_steps=sample_steps + ) + batch_audio_fragment.append(audio_fragment) + + else: + refer_audio_spec: torch.Tensor = [ + item.to(dtype=self.precision, device=self.configs.device) + for item in self.prompt_cache["refer_spec"] + ] + semantic_token_generator =self.t2s_model.model.infer_panel( + all_phoneme_ids[0].unsqueeze(0), + all_phoneme_lens, + prompt, + all_bert_features[0].unsqueeze(0), + # prompt_phone_len=ph_offset, + top_k=top_k, + top_p=top_p, + temperature=temperature, + early_stop_num=self.configs.hz * self.configs.max_sec, + max_len=max_len, + repetition_penalty=repetition_penalty, + streaming_mode=True, + chunk_length=chunk_length, + ) + t4 = time.perf_counter() + t_34 += t4 - t3 + phones = batch_phones[0].unsqueeze(0).to(self.configs.device) + is_first_chunk = True + + if not self.configs.use_vocoder: + 
if speed_factor == 1.0: + upsample_rate = math.prod(self.vits_model.upsample_rates)*(2 if self.vits_model.semantic_frame_rate == "25hz" else 1) + else: + upsample_rate = math.prod(self.vits_model.upsample_rates)*((2 if self.vits_model.semantic_frame_rate == "25hz" else 1)/speed_factor) + else: + if speed_factor == 1.0: + upsample_rate = self.vocoder_configs["upsample_rate"]*(3.875 if self.configs.version == "v3" else 4) + else: + upsample_rate = self.vocoder_configs["upsample_rate"]*((3.875 if self.configs.version == "v3" else 4)/speed_factor) + + last_audio_chunk = None + last_tokens = None + previous_tokens = [] + overlap_size = math.ceil(overlap_length*upsample_rate) + for semantic_tokens, is_final in semantic_token_generator: + if semantic_tokens is None and last_audio_chunk is not None: + yield self.audio_postprocess( + [[last_audio_chunk[-overlap_size:]]], + output_sr, + None, + speed_factor, + False, + 0.0, + super_sampling if self.configs.use_vocoder and self.configs.version == "v3" else False, + ) + continue + + _semantic_tokens = semantic_tokens + # if is_first_chunk: + # _semantic_tokens = torch.cat([torch.ones((1,overlap_length), dtype=torch.long, device=self.configs.device)*self.configs.mute_tokens[self.configs.version], _semantic_tokens], dim=-1) + # else: + # _semantic_tokens = torch.cat([last_tokens[:, -overlap_length:], _semantic_tokens], dim=-1) + # # _semantic_tokens = torch.cat(previous_tokens+[_semantic_tokens,], dim=-1) + + previous_tokens.append(semantic_tokens) + + _semantic_tokens = torch.cat(previous_tokens, dim=-1) + + + # last_tokens = semantic_tokens + + # print(f"_semantic_tokens shape:{_semantic_tokens.shape}") + + + if not self.configs.use_vocoder: + audio_chunk = self.vits_model.decode( + _semantic_tokens.unsqueeze(0), + phones, refer_audio_spec, + speed=speed_factor, + result_length=semantic_tokens.shape[-1]+overlap_length if not is_first_chunk else None + # result_length=chunk_length if not is_first_chunk else None + ).detach()[0, 
0, :] + else: + audio_chunk = self.using_vocoder_synthesis( + _semantic_tokens.unsqueeze(0), phones, + speed=speed_factor, sample_steps=sample_steps, + result_length = semantic_tokens.shape[-1]+overlap_length if not is_first_chunk else None + ) + + + + # if is_first_chunk: + # audio_chunk = audio_chunk[overlap_size:] + # # is_first_chunk = False + + audio_chunk_ = audio_chunk + if is_first_chunk and not is_final: + is_first_chunk = False + audio_chunk_ = audio_chunk_[:-overlap_size] + elif is_first_chunk and is_final: + is_first_chunk = False + elif not is_first_chunk and not is_final: + audio_chunk_ = self.sola_algorithm([last_audio_chunk, audio_chunk_], overlap_size) + audio_chunk_ = ( + audio_chunk_[last_audio_chunk.shape[0]-overlap_size:-overlap_size] if not is_final \ + else audio_chunk_[last_audio_chunk.shape[0]-overlap_size:] + ) + # audio_chunk_ = audio_chunk_[:-overlap_size] if not is_final else audio_chunk_ + + last_audio_chunk = audio_chunk + yield self.audio_postprocess( + [[audio_chunk_]], + output_sr, + None, + speed_factor, + False, + 0.0, + super_sampling if self.configs.use_vocoder and self.configs.version == "v3" else False, + ) + print(f"first_package_delay: {time.perf_counter()-t0:.3f}") + + yield output_sr, np.zeros(int(output_sr*fragment_interval), dtype=np.int16) t5 = time.perf_counter() t_45 += t5 - t4 @@ -1327,17 +1477,18 @@ class TTS: fragment_interval, super_sampling if self.configs.use_vocoder and self.configs.version == "v3" else False, ) + elif streaming_mode:... 
else: audio.append(batch_audio_fragment) if self.stop_flag: - yield 16000, np.zeros(int(16000), dtype=np.int16) + yield output_sr, np.zeros(int(output_sr), dtype=np.int16) return - if not return_fragment: + if not (return_fragment or streaming_mode): print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t_34, t_45)) if len(audio) == 0: - yield 16000, np.zeros(int(16000), dtype=np.int16) + yield output_sr, np.zeros(int(output_sr), dtype=np.int16) return yield self.audio_postprocess( audio, @@ -1384,16 +1535,17 @@ class TTS: fragment_interval: float = 0.3, super_sampling: bool = False, ) -> Tuple[int, np.ndarray]: - zero_wav = torch.zeros( - int(self.configs.sampling_rate * fragment_interval), dtype=self.precision, device=self.configs.device - ) + if fragment_interval>0: + zero_wav = torch.zeros( + int(self.configs.sampling_rate * fragment_interval), dtype=self.precision, device=self.configs.device + ) for i, batch in enumerate(audio): for j, audio_fragment in enumerate(batch): max_audio = torch.abs(audio_fragment).max() # 简单防止16bit爆音 if max_audio > 1: audio_fragment /= max_audio - audio_fragment: torch.Tensor = torch.cat([audio_fragment, zero_wav], dim=0) + audio_fragment: torch.Tensor = torch.cat([audio_fragment, zero_wav], dim=0) if fragment_interval>0 else audio_fragment audio[i][j] = audio_fragment if split_bucket: @@ -1413,12 +1565,12 @@ class TTS: max_audio = np.abs(audio).max() if max_audio > 1: audio /= max_audio + audio = (audio * 32768).astype(np.int16) t2 = time.perf_counter() print(f"超采样用时:{t2 - t1:.3f}s") else: - audio = audio.cpu().numpy() - - audio = (audio * 32768).astype(np.int16) + audio = audio.float() * 32768 + audio = audio.to(dtype=torch.int16).cpu().numpy() # try: # if speed_factor != 1.0: @@ -1429,7 +1581,7 @@ class TTS: return sr, audio def using_vocoder_synthesis( - self, semantic_tokens: torch.Tensor, phones: torch.Tensor, speed: float = 1.0, sample_steps: int = 32 + self, semantic_tokens: torch.Tensor, phones: torch.Tensor, speed: float = 
1.0, sample_steps: int = 32, result_length:int=None ): prompt_semantic_tokens = self.prompt_cache["prompt_semantic"].unsqueeze(0).unsqueeze(0).to(self.configs.device) prompt_phones = torch.LongTensor(self.prompt_cache["phones"]).unsqueeze(0).to(self.configs.device) @@ -1464,7 +1616,7 @@ class TTS: chunk_len = T_chunk - T_min mel2 = mel2.to(self.precision) - fea_todo, ge = self.vits_model.decode_encp(semantic_tokens, phones, refer_audio_spec, ge, speed) + fea_todo, ge = self.vits_model.decode_encp(semantic_tokens, phones, refer_audio_spec, ge, speed, result_length=result_length) cfm_resss = [] idx = 0 @@ -1487,7 +1639,7 @@ class TTS: cfm_res = torch.cat(cfm_resss, 2) cfm_res = denorm_spec(cfm_res) - with torch.inference_mode(): + with torch.no_grad(): wav_gen = self.vocoder(cfm_res) audio = wav_gen[0][0] # .cpu().detach().numpy() diff --git a/GPT_SoVITS/module/models.py b/GPT_SoVITS/module/models.py index 1c8e662f..2d63ecbd 100644 --- a/GPT_SoVITS/module/models.py +++ b/GPT_SoVITS/module/models.py @@ -209,7 +209,7 @@ class TextEncoder(nn.Module): self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) - def forward(self, y, y_lengths, text, text_lengths, ge, speed=1, test=None): + def forward(self, y, y_lengths, text, text_lengths, ge, speed=1, test=None, result_length:int=None): y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype) y = self.ssl_proj(y * y_mask) * y_mask @@ -223,6 +223,11 @@ class TextEncoder(nn.Module): text = self.encoder_text(text * text_mask, text_mask) y = self.mrte(y, y_mask, text, text_mask, ge) y = self.encoder2(y * y_mask, y_mask) + + if result_length is not None: + y = y[:, :, -result_length:] + y_mask = y_mask[:, :, -result_length:] + if speed != 1: y = F.interpolate(y, size=int(y.shape[-1] / speed) + 1, mode="linear") y_mask = F.interpolate(y_mask, size=y.shape[-1], mode="nearest") @@ -958,7 +963,7 @@ class SynthesizerTrn(nn.Module): return o, y_mask, (z, z_p, m_p, logs_p) @torch.no_grad() - def 
decode(self, codes, text, refer, noise_scale=0.5, speed=1, sv_emb=None): + def decode(self, codes, text, refer, noise_scale=0.5, speed=1, sv_emb=None, result_length:int=None): def get_ge(refer, sv_emb): ge = None if refer is not None: @@ -989,6 +994,7 @@ class SynthesizerTrn(nn.Module): quantized = self.quantizer.decode(codes) if self.semantic_frame_rate == "25hz": quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest") + result_length = (2*result_length) if result_length is not None else None x, m_p, logs_p, y_mask = self.enc_p( quantized, y_lengths, @@ -996,7 +1002,7 @@ class SynthesizerTrn(nn.Module): text_lengths, self.ge_to512(ge.transpose(2, 1)).transpose(2, 1) if self.is_v2pro else ge, speed, - ) + , result_length=result_length) z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale z = self.flow(z_p, y_mask, g=ge, reverse=True) @@ -1242,7 +1248,7 @@ class SynthesizerTrnV3(nn.Module): return cfm_loss @torch.no_grad() - def decode_encp(self, codes, text, refer, ge=None, speed=1): + def decode_encp(self, codes, text, refer, ge=None, speed=1, result_length:int=None): # print(2333333,refer.shape) # ge=None if ge == None: @@ -1250,17 +1256,21 @@ class SynthesizerTrnV3(nn.Module): refer_mask = torch.unsqueeze(commons.sequence_mask(refer_lengths, refer.size(2)), 1).to(refer.dtype) ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask) y_lengths = torch.LongTensor([int(codes.size(2) * 2)]).to(codes.device) + if speed == 1: - sizee = int(codes.size(2) * (3.875 if self.version == "v3" else 4)) + sizee = int((codes.size(2) if result_length is None else result_length) * (3.875 if self.version=="v3"else 4)) else: - sizee = int(codes.size(2) * (3.875 if self.version == "v3" else 4) / speed) + 1 + sizee = int((codes.size(2) if result_length is None else result_length) * (3.875 if self.version=="v3"else 4) / speed) + 1 y_lengths1 = torch.LongTensor([sizee]).to(codes.device) text_lengths = 
torch.LongTensor([text.size(-1)]).to(text.device) quantized = self.quantizer.decode(codes) if self.semantic_frame_rate == "25hz": quantized = F.interpolate(quantized, scale_factor=2, mode="nearest") ##BCT - x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge, speed) + result_length = result_length * 2 if result_length is not None else None + + + x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge, speed, result_length=result_length) fea = self.bridge(x) fea = F.interpolate(fea, scale_factor=(1.875 if self.version == "v3" else 2), mode="nearest") ##BCT ####more wn paramter to learn mel diff --git a/api_v2.py b/api_v2.py index 5947df53..6dcec851 100644 --- a/api_v2.py +++ b/api_v2.py @@ -40,7 +40,10 @@ POST: "parallel_infer": True, # bool. whether to use parallel inference. "repetition_penalty": 1.35, # float. repetition penalty for T2S model. "sample_steps": 32, # int. number of sampling steps for VITS model V3. - "super_sampling": False # bool. whether to use super-sampling for audio when using VITS model V3. + "super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3. + "overlap_length": 2, # int. overlap length of semantic tokens for streaming mode. + "chunk_length: 24, # int. chunk length of semantic tokens for streaming mode. (affects audio chunk size) + "return_fragment": False, # bool. step by step return the audio fragment. (old version of streaming mode) } ``` @@ -170,6 +173,9 @@ class TTS_Request(BaseModel): repetition_penalty: float = 1.35 sample_steps: int = 32 super_sampling: bool = False + overlap_length: int = 2 + chunk_length: int = 24 + return_fragment: bool = False ### modify from https://github.com/RVC-Boss/GPT-SoVITS/pull/894/files @@ -325,7 +331,10 @@ async def tts_handle(req: dict): "parallel_infer": True, # bool.(optional) whether to use parallel inference. "repetition_penalty": 1.35 # float.(optional) repetition penalty for T2S model. 
"sample_steps": 32, # int. number of sampling steps for VITS model V3. - "super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3. + "super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3. + "overlap_length": 2, # int. overlap length of semantic tokens for streaming mode. + "chunk_length: 24, # int. chunk length of semantic tokens for streaming mode. (affects audio chunk size) + "return_fragment": False, # bool. step by step return the audio fragment. (old version of streaming mode) } returns: StreamingResponse: audio stream response. @@ -339,8 +348,10 @@ async def tts_handle(req: dict): if check_res is not None: return check_res - if streaming_mode or return_fragment: - req["return_fragment"] = True + req["streaming_mode"] = streaming_mode + req["return_fragment"] = return_fragment + streaming_mode = streaming_mode or return_fragment + try: tts_generator = tts_pipeline.run(req) @@ -404,6 +415,9 @@ async def tts_get_endpoint( repetition_penalty: float = 1.35, sample_steps: int = 32, super_sampling: bool = False, + overlap_length: int = 2, + chunk_length: int = 24, + return_fragment: bool = False, ): req = { "text": text, @@ -428,6 +442,9 @@ async def tts_get_endpoint( "repetition_penalty": float(repetition_penalty), "sample_steps": int(sample_steps), "super_sampling": super_sampling, + "overlap_length": int(overlap_length), + "chunk_length": int(chunk_length), + "return_fragment": return_fragment, } return await tts_handle(req) From 9ff381b5191715a95ef98b816c2d994c28ff773a Mon Sep 17 00:00:00 2001 From: ChasonJiang <1440499136@qq.com> Date: Mon, 5 May 2025 20:16:59 +0800 Subject: [PATCH 02/10] =?UTF-8?q?=E6=B8=85=E7=90=86=E6=97=A0=E7=94=A8?= =?UTF-8?q?=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- GPT_SoVITS/TTS_infer_pack/TTS.py | 20 ++------------------ 1 file changed, 2 insertions(+), 18 deletions(-) diff --git 
a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py index 32a9c4dd..6c39bf16 100644 --- a/GPT_SoVITS/TTS_infer_pack/TTS.py +++ b/GPT_SoVITS/TTS_infer_pack/TTS.py @@ -1399,22 +1399,12 @@ class TTS: continue _semantic_tokens = semantic_tokens - # if is_first_chunk: - # _semantic_tokens = torch.cat([torch.ones((1,overlap_length), dtype=torch.long, device=self.configs.device)*self.configs.mute_tokens[self.configs.version], _semantic_tokens], dim=-1) - # else: - # _semantic_tokens = torch.cat([last_tokens[:, -overlap_length:], _semantic_tokens], dim=-1) - # # _semantic_tokens = torch.cat(previous_tokens+[_semantic_tokens,], dim=-1) previous_tokens.append(semantic_tokens) _semantic_tokens = torch.cat(previous_tokens, dim=-1) - # last_tokens = semantic_tokens - - # print(f"_semantic_tokens shape:{_semantic_tokens.shape}") - - if not self.configs.use_vocoder: audio_chunk = self.vits_model.decode( _semantic_tokens.unsqueeze(0), @@ -1429,13 +1419,7 @@ class TTS: speed=speed_factor, sample_steps=sample_steps, result_length = semantic_tokens.shape[-1]+overlap_length if not is_first_chunk else None ) - - - # if is_first_chunk: - # audio_chunk = audio_chunk[overlap_size:] - # # is_first_chunk = False - audio_chunk_ = audio_chunk if is_first_chunk and not is_final: is_first_chunk = False @@ -1448,7 +1432,7 @@ class TTS: audio_chunk_[last_audio_chunk.shape[0]-overlap_size:-overlap_size] if not is_final \ else audio_chunk_[last_audio_chunk.shape[0]-overlap_size:] ) - # audio_chunk_ = audio_chunk_[:-overlap_size] if not is_final else audio_chunk_ + last_audio_chunk = audio_chunk yield self.audio_postprocess( @@ -1460,7 +1444,7 @@ class TTS: 0.0, super_sampling if self.configs.use_vocoder and self.configs.version == "v3" else False, ) - print(f"first_package_delay: {time.perf_counter()-t0:.3f}") + # print(f"first_package_delay: {time.perf_counter()-t0:.3f}") yield output_sr, np.zeros(int(output_sr*fragment_interval), dtype=np.int16) From 
0825ae80e10a4837aa230abd208bee33c47198c1 Mon Sep 17 00:00:00 2001 From: ChasonJiang <1440499136@qq.com> Date: Tue, 24 Jun 2025 20:43:46 +0800 Subject: [PATCH 03/10] modified: GPT_SoVITS/AR/models/t2s_model.py modified: GPT_SoVITS/TTS_infer_pack/TTS.py modified: GPT_SoVITS/module/models.py --- GPT_SoVITS/AR/models/t2s_model.py | 42 ++++++++++++- GPT_SoVITS/TTS_infer_pack/TTS.py | 98 +++++++++++++++++++++++-------- GPT_SoVITS/module/models.py | 46 ++++++++++----- 3 files changed, 147 insertions(+), 39 deletions(-) diff --git a/GPT_SoVITS/AR/models/t2s_model.py b/GPT_SoVITS/AR/models/t2s_model.py index 4e672503..806729d0 100644 --- a/GPT_SoVITS/AR/models/t2s_model.py +++ b/GPT_SoVITS/AR/models/t2s_model.py @@ -826,6 +826,13 @@ class Text2SemanticDecoder(nn.Module): chunk_length: int = 24, **kwargs, ): + mute_emb_sim_matrix = kwargs.get("mute_emb_sim_matrix", None) + sim_thershold = kwargs.get("sim_thershold", 0.3) + min_chunk_len = kwargs.get("min_chunk_len", 12) + limited_chunk_len = kwargs.get("limited_chunk_len", False) + only_for_the_first_chunk = kwargs.get("only_for_the_first_chunk", True) + + x = self.ar_text_embedding(x) x = x + self.bert_proj(bert_feature.transpose(1, 2)) x = self.ar_text_position(x) @@ -877,6 +884,7 @@ class Text2SemanticDecoder(nn.Module): .to(device=x.device, dtype=torch.bool) ) + is_yield = False token_counter = 0 for idx in tqdm(range(1500)): token_counter+=1 @@ -921,9 +929,39 @@ class Text2SemanticDecoder(nn.Module): yield (y[:, -token_counter:]) if token_counter!= 0 else None, True break - if streaming_mode and token_counter == chunk_length: + # if streaming_mode and (mute_emb_sim_matrix is not None) and (token_counter > min_chunk_len): + # sim = mute_emb_sim_matrix[y[0,-1]] + # if sim >= sim_thershold: is_yield = True + # elif streaming_mode and (mute_emb_sim_matrix is None): + # is_yield = token_counter == chunk_length + + # if streaming_mode and is_yield: + # is_yield = False + # yield y[:, -token_counter:], False + # token_counter 
= 0 + + if streaming_mode and (mute_emb_sim_matrix is not None) and (token_counter > min_chunk_len): + last_sim = mute_emb_sim_matrix[y[0,-1]] + + if (not limited_chunk_len) and last_sim >= sim_thershold: + yield y[:, -token_counter:], False + token_counter = 0 + # if is_first_package: is_first_package = False + + elif limited_chunk_len and token_counter == chunk_length: + # is_first_package = False + limited_chunk_len = False if only_for_the_first_chunk else limited_chunk_len + sim = mute_emb_sim_matrix[y[0,-(token_counter-min_chunk_len):]] + # print(f"sim:{sim}") + i = chunk_length-(sim.argmax()+min_chunk_len+1) + token_counter = i + yield y[:, -chunk_length:-i] if i!= 0 else y[:, -chunk_length:], False + + + elif streaming_mode and (mute_emb_sim_matrix is None): + is_yield = token_counter == chunk_length + yield y[:, -token_counter:], False token_counter = 0 - yield y[:, -chunk_length:], False ####################### update next step ################################### diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py index 6c39bf16..d2a2d3ce 100644 --- a/GPT_SoVITS/TTS_infer_pack/TTS.py +++ b/GPT_SoVITS/TTS_infer_pack/TTS.py @@ -281,6 +281,7 @@ class TTS_Config: "v3" : 486, "v4" : 486, } + mute_emb_sim_matrix: torch.Tensor = None # "all_zh",#全部按中文识别 # "en",#全部按英文识别#######不变 # "all_ja",#全部按日文识别 @@ -604,6 +605,11 @@ class TTS: if self.configs.is_half and str(self.configs.device) != "cpu": self.t2s_model = self.t2s_model.half() + codebook = t2s_model.model.ar_audio_embedding.weight.clone() + mute_emb = codebook[self.configs.mute_tokens[self.configs.version]].unsqueeze(0) + sim_matrix = F.cosine_similarity(mute_emb.float(), codebook.float(), dim=-1) + self.configs.mute_emb_sim_matrix = sim_matrix + def init_vocoder(self, version: str): if version == "v3": if self.vocoder is not None and self.vocoder.__class__.__name__ == "BigVGAN": @@ -1065,6 +1071,7 @@ class TTS: self.t2s_model.model.infer_panel = 
self.t2s_model.model.infer_panel_batch_infer else: self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_naive_batched + # self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_naive elif parallel_infer and streaming_mode: print(i18n("不支持同时开启并行推理和流式推理模式,已自动关闭并行推理模式")) parallel_infer = False @@ -1216,6 +1223,7 @@ class TTS: t_34 = 0.0 t_45 = 0.0 audio = [] + is_first_package = True output_sr = self.configs.sampling_rate if not self.configs.use_vocoder else self.vocoder_configs["sr"] for item in data: t3 = time.perf_counter() @@ -1299,13 +1307,14 @@ class TTS: ) _batch_phones = torch.cat(batch_phones).unsqueeze(0).to(self.configs.device) if self.is_v2pro != True: - _batch_audio_fragment = self.vits_model.decode( + _batch_audio_fragment, _, _ = self.vits_model.decode( all_pred_semantic, _batch_phones, refer_audio_spec, speed=speed_factor ).detach()[0, 0, :] else: _batch_audio_fragment = self.vits_model.decode( all_pred_semantic, _batch_phones, refer_audio_spec, speed=speed_factor, sv_emb=sv_emb - ).detach()[0, 0, :] + ) + _batch_audio_fragment = _batch_audio_fragment.detach()[0, 0, :] audio_frag_end_idx.insert(0, 0) batch_audio_fragment = [ _batch_audio_fragment[audio_frag_end_idx[i - 1] : audio_frag_end_idx[i]] @@ -1319,18 +1328,19 @@ class TTS: pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0) ) # .unsqueeze(0)#mq要多unsqueeze一次 if self.is_v2pro != True: - audio_fragment = self.vits_model.decode( + audio_fragment, _, _ = self.vits_model.decode( _pred_semantic, phones, refer_audio_spec, speed=speed_factor ).detach()[0, 0, :] else: audio_fragment = self.vits_model.decode( _pred_semantic, phones, refer_audio_spec, speed=speed_factor, sv_emb=sv_emb - ).detach()[0, 0, :] + ) + audio_fragment=audio_fragment.detach()[0, 0, :] batch_audio_fragment.append(audio_fragment) ###试试重建不带上prompt部分 else: if parallel_infer: print(f"{i18n('并行合成中')}...") - audio_fragments = self.using_vocoder_synthesis_batched_infer( + audio_fragments, y, y_mask = 
self.using_vocoder_synthesis_batched_infer( idx_list, pred_semantic_list, batch_phones, speed=speed_factor, sample_steps=sample_steps ) batch_audio_fragment.extend(audio_fragments) @@ -1340,7 +1350,7 @@ class TTS: _pred_semantic = ( pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0) ) # .unsqueeze(0)#mq要多unsqueeze一次 - audio_fragment = self.using_vocoder_synthesis( + audio_fragment, y, y_mask = self.using_vocoder_synthesis( _pred_semantic, phones, speed=speed_factor, sample_steps=sample_steps ) batch_audio_fragment.append(audio_fragment) @@ -1364,6 +1374,9 @@ class TTS: repetition_penalty=repetition_penalty, streaming_mode=True, chunk_length=chunk_length, + mute_emb_sim_matrix=self.configs.mute_emb_sim_matrix, + only_for_the_first_chunk=is_first_package, + limited_chunk_len=True ) t4 = time.perf_counter() t_34 += t4 - t3 @@ -1382,8 +1395,10 @@ class TTS: upsample_rate = self.vocoder_configs["upsample_rate"]*((3.875 if self.configs.version == "v3" else 4)/speed_factor) last_audio_chunk = None - last_tokens = None + # last_tokens = None + last_latent = None previous_tokens = [] + overlap_len = overlap_length overlap_size = math.ceil(overlap_length*upsample_rate) for semantic_tokens, is_final in semantic_token_generator: if semantic_tokens is None and last_audio_chunk is not None: @@ -1396,29 +1411,45 @@ class TTS: 0.0, super_sampling if self.configs.use_vocoder and self.configs.version == "v3" else False, ) - continue + break _semantic_tokens = semantic_tokens + print(f"semantic_tokens shape:{semantic_tokens.shape}") previous_tokens.append(semantic_tokens) _semantic_tokens = torch.cat(previous_tokens, dim=-1) + if not is_first_chunk and semantic_tokens.shape[-1] < 10: + overlap_len = overlap_length+(10-semantic_tokens.shape[-1]) + # overlap_size = math.ceil(overlap_len*upsample_rate) + else: + overlap_len = overlap_length + # overlap_size = math.ceil(overlap_length*upsample_rate) + if not self.configs.use_vocoder: - audio_chunk = self.vits_model.decode( + 
audio_chunk, latent, latent_mask = self.vits_model.decode( _semantic_tokens.unsqueeze(0), phones, refer_audio_spec, speed=speed_factor, - result_length=semantic_tokens.shape[-1]+overlap_length if not is_first_chunk else None + result_length=semantic_tokens.shape[-1]+overlap_len if not is_first_chunk else None, + overlap_frames=last_latent[:,:,-overlap_len*(2 if self.vits_model.semantic_frame_rate == "25hz" else 1):] \ + if last_latent is not None else None, # result_length=chunk_length if not is_first_chunk else None - ).detach()[0, 0, :] + ) + audio_chunk=audio_chunk.detach()[0, 0, :] else: - audio_chunk = self.using_vocoder_synthesis( + audio_chunk, latent, latent_mask = self.using_vocoder_synthesis( _semantic_tokens.unsqueeze(0), phones, speed=speed_factor, sample_steps=sample_steps, - result_length = semantic_tokens.shape[-1]+overlap_length if not is_first_chunk else None + result_length = semantic_tokens.shape[-1]+overlap_len if not is_first_chunk else None, + overlap_frames=last_latent[:,:,-overlap_len*(2 if self.vits_model.semantic_frame_rate == "25hz" else 1):] \ + if last_latent is not None else None, ) + + if overlap_len>overlap_length: + audio_chunk=audio_chunk[-int((overlap_length+semantic_tokens.shape[-1])*upsample_rate):] audio_chunk_ = audio_chunk if is_first_chunk and not is_final: @@ -1433,7 +1464,12 @@ class TTS: else audio_chunk_[last_audio_chunk.shape[0]-overlap_size:] ) + # audio_chunk_ = ( + # audio_chunk_[overlap_size:-overlap_size] if not is_final \ + # else audio_chunk_[overlap_size:] + # ) + last_latent = latent last_audio_chunk = audio_chunk yield self.audio_postprocess( [[audio_chunk_]], @@ -1444,8 +1480,12 @@ class TTS: 0.0, super_sampling if self.configs.use_vocoder and self.configs.version == "v3" else False, ) - # print(f"first_package_delay: {time.perf_counter()-t0:.3f}") + if is_first_package: + print(f"first_package_delay: {time.perf_counter()-t0:.3f}") + is_first_package = False + + yield output_sr, 
np.zeros(int(output_sr*fragment_interval), dtype=np.int16) t5 = time.perf_counter() @@ -1553,8 +1593,13 @@ class TTS: t2 = time.perf_counter() print(f"超采样用时:{t2 - t1:.3f}s") else: - audio = audio.float() * 32768 - audio = audio.to(dtype=torch.int16).cpu().numpy() + # audio = audio.float() * 32768 + # audio = audio.to(dtype=torch.int16).clamp(-32768, 32767).cpu().numpy() + + audio = audio.cpu().numpy() + + audio = (audio * 32768).astype(np.int16) + # try: # if speed_factor != 1.0: @@ -1565,7 +1610,7 @@ class TTS: return sr, audio def using_vocoder_synthesis( - self, semantic_tokens: torch.Tensor, phones: torch.Tensor, speed: float = 1.0, sample_steps: int = 32, result_length:int=None + self, semantic_tokens: torch.Tensor, phones: torch.Tensor, speed: float = 1.0, sample_steps: int = 32, result_length:int=None, overlap_frames:torch.Tensor=None ): prompt_semantic_tokens = self.prompt_cache["prompt_semantic"].unsqueeze(0).unsqueeze(0).to(self.configs.device) prompt_phones = torch.LongTensor(self.prompt_cache["phones"]).unsqueeze(0).to(self.configs.device) @@ -1574,7 +1619,7 @@ class TTS: raw_entry = raw_entry[0] refer_audio_spec = raw_entry.to(dtype=self.precision, device=self.configs.device) - fea_ref, ge = self.vits_model.decode_encp(prompt_semantic_tokens, prompt_phones, refer_audio_spec) + fea_ref, ge, y, y_mask = self.vits_model.decode_encp(prompt_semantic_tokens, prompt_phones, refer_audio_spec) ref_audio: torch.Tensor = self.prompt_cache["raw_audio"] ref_sr = self.prompt_cache["raw_sr"] ref_audio = ref_audio.to(self.configs.device).float() @@ -1600,7 +1645,7 @@ class TTS: chunk_len = T_chunk - T_min mel2 = mel2.to(self.precision) - fea_todo, ge = self.vits_model.decode_encp(semantic_tokens, phones, refer_audio_spec, ge, speed, result_length=result_length) + fea_todo, ge, y, y_mask = self.vits_model.decode_encp(semantic_tokens, phones, refer_audio_spec, ge, speed, result_length=result_length, overlap_frames=overlap_frames) cfm_resss = [] idx = 0 @@ -1627,7 
+1672,7 @@ class TTS: wav_gen = self.vocoder(cfm_res) audio = wav_gen[0][0] # .cpu().detach().numpy() - return audio + return audio, y, y_mask def using_vocoder_synthesis_batched_infer( self, @@ -1644,7 +1689,7 @@ class TTS: raw_entry = raw_entry[0] refer_audio_spec = raw_entry.to(dtype=self.precision, device=self.configs.device) - fea_ref, ge = self.vits_model.decode_encp(prompt_semantic_tokens, prompt_phones, refer_audio_spec) + fea_ref, ge, y, y_mask = self.vits_model.decode_encp(prompt_semantic_tokens, prompt_phones, refer_audio_spec) ref_audio: torch.Tensor = self.prompt_cache["raw_audio"] ref_sr = self.prompt_cache["raw_sr"] ref_audio = ref_audio.to(self.configs.device).float() @@ -1682,7 +1727,7 @@ class TTS: semantic_tokens = ( semantic_tokens_list[i][-idx:].unsqueeze(0).unsqueeze(0) ) # .unsqueeze(0)#mq要多unsqueeze一次 - feat, _ = self.vits_model.decode_encp(semantic_tokens, phones, refer_audio_spec, ge, speed) + feat, _, y, y_mask = self.vits_model.decode_encp(semantic_tokens, phones, refer_audio_spec, ge, speed) feat_list.append(feat) feat_lens.append(feat.shape[2]) @@ -1742,7 +1787,7 @@ class TTS: audio_fragments.append(audio_fragment) audio = audio[feat_len * upsample_rate :] - return audio_fragments + return audio_fragments, y, y_mask def sola_algorithm( self, @@ -1766,6 +1811,13 @@ class TTS: window[: (overlap_len - idx)] * f2_[: (overlap_len - idx)] + window[(overlap_len - idx) :] * f1[-(overlap_len - idx) :] ) + + # window = torch.sin(torch.arange((overlap_len - idx), device=f1.device) * np.pi / (overlap_len - idx)) + # f2_[: (overlap_len - idx)] = ( + # window * f2_[: (overlap_len - idx)] + # + (1-window) * f1[-(overlap_len - idx) :] + # ) + audio_fragments[i + 1] = f2_ return torch.cat(audio_fragments, 0) diff --git a/GPT_SoVITS/module/models.py b/GPT_SoVITS/module/models.py index 2d63ecbd..2076960e 100644 --- a/GPT_SoVITS/module/models.py +++ b/GPT_SoVITS/module/models.py @@ -151,6 +151,8 @@ class DurationPredictor(nn.Module): return x * x_mask 
+HANN_WINDOW = {} + class TextEncoder(nn.Module): def __init__( self, @@ -209,7 +211,7 @@ class TextEncoder(nn.Module): self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) - def forward(self, y, y_lengths, text, text_lengths, ge, speed=1, test=None, result_length:int=None): + def forward(self, y, y_lengths, text, text_lengths, ge, speed=1, test=None, result_length:int=None, overlap_frames:torch.Tensor=None): y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype) y = self.ssl_proj(y * y_mask) * y_mask @@ -227,13 +229,29 @@ class TextEncoder(nn.Module): if result_length is not None: y = y[:, :, -result_length:] y_mask = y_mask[:, :, -result_length:] + + if overlap_frames is not None: + overlap_len = overlap_frames.shape[-1] + window = HANN_WINDOW.get(overlap_len, None) + if window is None: + HANN_WINDOW[overlap_len] = torch.hann_window(overlap_len*2, device=y.device, dtype=y.dtype) + window = HANN_WINDOW[overlap_len] + window = window.to(y.device) + y[:,:,:overlap_len] = ( + window[:overlap_len].view(1, 1, -1) * y[:,:,:overlap_len] + + window[overlap_len:].view(1, 1, -1) * overlap_frames + ) + y_ = y + y_mask_ = y_mask + + if speed != 1: y = F.interpolate(y, size=int(y.shape[-1] / speed) + 1, mode="linear") y_mask = F.interpolate(y_mask, size=y.shape[-1], mode="nearest") stats = self.proj(y) * y_mask m, logs = torch.split(stats, self.out_channels, dim=1) - return y, m, logs, y_mask + return y, m, logs, y_mask, y_, y_mask_ def extract_latent(self, x): x = self.ssl_proj(x) @@ -926,7 +944,7 @@ class SynthesizerTrn(nn.Module): if self.semantic_frame_rate == "25hz": quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest") - x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge512 if self.is_v2pro else ge) + x, m_p, logs_p, y_mask, _, _ = self.enc_p(quantized, y_lengths, text, text_lengths, ge512 if self.is_v2pro else ge) z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=ge) 
z_p = self.flow(z, y_mask, g=ge) @@ -954,7 +972,7 @@ class SynthesizerTrn(nn.Module): if self.semantic_frame_rate == "25hz": quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest") - x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge, test=test) + x, m_p, logs_p, y_mask, _, _ = self.enc_p(quantized, y_lengths, text, text_lengths, ge, test=test) z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale z = self.flow(z_p, y_mask, g=ge, reverse=True) @@ -963,7 +981,7 @@ class SynthesizerTrn(nn.Module): return o, y_mask, (z, z_p, m_p, logs_p) @torch.no_grad() - def decode(self, codes, text, refer, noise_scale=0.5, speed=1, sv_emb=None, result_length:int=None): + def decode(self, codes, text, refer, noise_scale=0.5, speed=1, sv_emb=None, result_length:int=None, overlap_frames:torch.Tensor=None): def get_ge(refer, sv_emb): ge = None if refer is not None: @@ -995,20 +1013,20 @@ class SynthesizerTrn(nn.Module): if self.semantic_frame_rate == "25hz": quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest") result_length = (2*result_length) if result_length is not None else None - x, m_p, logs_p, y_mask = self.enc_p( + x, m_p, logs_p, y_mask, y_, y_mask_ = self.enc_p( quantized, y_lengths, text, text_lengths, self.ge_to512(ge.transpose(2, 1)).transpose(2, 1) if self.is_v2pro else ge, speed, - , result_length=result_length) + , result_length=result_length, overlap_frames=overlap_frames) z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale z = self.flow(z_p, y_mask, g=ge, reverse=True) o = self.dec((z * y_mask)[:, :, :], g=ge) - return o + return o, y_, y_mask_ def extract_latent(self, x): ssl = self.ssl_proj(x) @@ -1232,7 +1250,7 @@ class SynthesizerTrnV3(nn.Module): ssl = self.ssl_proj(ssl) quantized, codes, commit_loss, quantized_list = self.quantizer(ssl, layers=[0]) quantized = F.interpolate(quantized, scale_factor=2, mode="nearest") ##BCT - x, m_p, 
logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge) + x, m_p, logs_p, y_mask, y_, y_mask_ = self.enc_p(quantized, y_lengths, text, text_lengths, ge) fea = self.bridge(x) fea = F.interpolate(fea, scale_factor=(1.875 if self.version == "v3" else 2), mode="nearest") ##BCT fea, y_mask_ = self.wns1( @@ -1248,7 +1266,7 @@ class SynthesizerTrnV3(nn.Module): return cfm_loss @torch.no_grad() - def decode_encp(self, codes, text, refer, ge=None, speed=1, result_length:int=None): + def decode_encp(self, codes, text, refer, ge=None, speed=1, result_length:int=None, overlap_frames:torch.Tensor=None): # print(2333333,refer.shape) # ge=None if ge == None: @@ -1270,12 +1288,12 @@ class SynthesizerTrnV3(nn.Module): result_length = result_length * 2 if result_length is not None else None - x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge, speed, result_length=result_length) + x, m_p, logs_p, y_mask, y_, y_mask_ = self.enc_p(quantized, y_lengths, text, text_lengths, ge, speed, result_length=result_length, overlap_frames=overlap_frames) fea = self.bridge(x) fea = F.interpolate(fea, scale_factor=(1.875 if self.version == "v3" else 2), mode="nearest") ##BCT ####more wn paramter to learn mel fea, y_mask_ = self.wns1(fea, y_lengths1, ge) - return fea, ge + return fea, ge, y_, y_mask_ def extract_latent(self, x): ssl = self.ssl_proj(x) @@ -1387,7 +1405,7 @@ class SynthesizerTrnV3b(nn.Module): ssl = self.ssl_proj(ssl) quantized, codes, commit_loss, quantized_list = self.quantizer(ssl, layers=[0]) quantized = F.interpolate(quantized, scale_factor=2, mode="nearest") ##BCT - x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge) + x, m_p, logs_p, y_mask, y_, y_mask_ = self.enc_p(quantized, y_lengths, text, text_lengths, ge) z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=ge) z_p = self.flow(z, y_mask, g=ge) z_slice, ids_slice = commons.rand_slice_segments(z, y_lengths, self.segment_size) @@ -1430,7 +1448,7 @@ 
class SynthesizerTrnV3b(nn.Module): quantized = self.quantizer.decode(codes) if self.semantic_frame_rate == "25hz": quantized = F.interpolate(quantized, scale_factor=2, mode="nearest") ##BCT - x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge) + x, m_p, logs_p, y_mask, y_, y_mask_ = self.enc_p(quantized, y_lengths, text, text_lengths, ge) fea = self.bridge(x) fea = F.interpolate(fea, scale_factor=1.875, mode="nearest") ##BCT ####more wn paramter to learn mel From d08214dd22e8837b4f9eb92f94a1d419806e182a Mon Sep 17 00:00:00 2001 From: ChasonJiang <1440499136@qq.com> Date: Tue, 1 Jul 2025 22:27:03 +0800 Subject: [PATCH 04/10] modified: GPT_SoVITS/TTS_infer_pack/TTS.py --- GPT_SoVITS/TTS_infer_pack/TTS.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py index d2a2d3ce..813117a2 100644 --- a/GPT_SoVITS/TTS_infer_pack/TTS.py +++ b/GPT_SoVITS/TTS_infer_pack/TTS.py @@ -1799,9 +1799,16 @@ class TTS: f2 = audio_fragments[i + 1] w1 = f1[-overlap_len:] w2 = f2[:overlap_len] - assert w1.shape == w2.shape - corr = F.conv1d(w1.view(1, 1, -1), w2.view(1, 1, -1), padding=w2.shape[-1] // 2).view(-1)[:-1] - idx = corr.argmax() + w2 = w2[-w2.shape[-1]//2:] + # assert w1.shape == w2.shape + corr = F.conv1d(w1.view(1, 1, -1), w2.view(1, 1, -1)).view(-1) + + squared_sum = F.conv1d(w1.view(1, 1, -1)**2, torch.ones_like(w2).view(1, 1, -1)).view(-1)+ 1e-8 + idx = (corr/squared_sum.sqrt()).argmax() + + print(f"seg_idx: {idx}") + + # idx = corr.argmax() f1_ = f1[: -(overlap_len - idx)] audio_fragments[i] = f1_ From af7b95bc9d4e586d769cbd58db1b7582eb07fc5a Mon Sep 17 00:00:00 2001 From: ChasonJiang <1440499136@qq.com> Date: Mon, 24 Nov 2025 18:52:35 +0800 Subject: [PATCH 05/10] modified: .gitignore modified: GPT_SoVITS/AR/models/t2s_model.py modified: GPT_SoVITS/TTS_infer_pack/TTS.py modified: GPT_SoVITS/module/models.py --- GPT_SoVITS/AR/models/t2s_model.py | 49 
++++++++++--------------------- GPT_SoVITS/TTS_infer_pack/TTS.py | 12 +++++--- GPT_SoVITS/module/models.py | 27 ++++++++++++----- 3 files changed, 42 insertions(+), 46 deletions(-) diff --git a/GPT_SoVITS/AR/models/t2s_model.py b/GPT_SoVITS/AR/models/t2s_model.py index 806729d0..ae1fcc3c 100644 --- a/GPT_SoVITS/AR/models/t2s_model.py +++ b/GPT_SoVITS/AR/models/t2s_model.py @@ -828,9 +828,7 @@ class Text2SemanticDecoder(nn.Module): ): mute_emb_sim_matrix = kwargs.get("mute_emb_sim_matrix", None) sim_thershold = kwargs.get("sim_thershold", 0.3) - min_chunk_len = kwargs.get("min_chunk_len", 12) - limited_chunk_len = kwargs.get("limited_chunk_len", False) - only_for_the_first_chunk = kwargs.get("only_for_the_first_chunk", True) + check_token_num = 2 x = self.ar_text_embedding(x) @@ -884,8 +882,8 @@ class Text2SemanticDecoder(nn.Module): .to(device=x.device, dtype=torch.bool) ) - is_yield = False token_counter = 0 + curr_ptr = prefix_len for idx in tqdm(range(1500)): token_counter+=1 if xy_attn_mask is not None: @@ -924,42 +922,25 @@ class Text2SemanticDecoder(nn.Module): print("bad zero prediction") # print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]") if streaming_mode: - # y=y[:, :-1] - # res_len = (y.shape[1] - prefix_len)%chunk_length - yield (y[:, -token_counter:]) if token_counter!= 0 else None, True + yield y[:, curr_ptr:] if curr_ptr min_chunk_len): - # sim = mute_emb_sim_matrix[y[0,-1]] - # if sim >= sim_thershold: is_yield = True - # elif streaming_mode and (mute_emb_sim_matrix is None): - # is_yield = token_counter == chunk_length - # if streaming_mode and is_yield: - # is_yield = False - # yield y[:, -token_counter:], False - # token_counter = 0 + if streaming_mode and (mute_emb_sim_matrix is not None) and (token_counter >= chunk_length+check_token_num): + score = mute_emb_sim_matrix[y[0, curr_ptr:]] - sim_thershold + score[score<0]=-1 + score[:-1]=score[:-1]+score[1:] + argmax_idx = score.argmax() - if streaming_mode and (mute_emb_sim_matrix is not 
None) and (token_counter > min_chunk_len): - last_sim = mute_emb_sim_matrix[y[0,-1]] + if score[argmax_idx]>=0 and argmax_idx+1>=chunk_length: + print(f"\n\ncurr_ptr:{curr_ptr}") + yield y[:, curr_ptr:], False + token_counter -= argmax_idx+1 + curr_ptr += argmax_idx+1 - if (not limited_chunk_len) and last_sim >= sim_thershold: - yield y[:, -token_counter:], False - token_counter = 0 - # if is_first_package: is_first_package = False - elif limited_chunk_len and token_counter == chunk_length: - # is_first_package = False - limited_chunk_len = False if only_for_the_first_chunk else limited_chunk_len - sim = mute_emb_sim_matrix[y[0,-(token_counter-min_chunk_len):]] - # print(f"sim:{sim}") - i = chunk_length-(sim.argmax()+min_chunk_len+1) - token_counter = i - yield y[:, -chunk_length:-i] if i!= 0 else y[:, -chunk_length:], False - - - elif streaming_mode and (mute_emb_sim_matrix is None): - is_yield = token_counter == chunk_length + elif streaming_mode and (mute_emb_sim_matrix is None) and (token_counter >= chunk_length): + token_counter == chunk_length yield y[:, -token_counter:], False token_counter = 0 diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py index 813117a2..2f097f6c 100644 --- a/GPT_SoVITS/TTS_infer_pack/TTS.py +++ b/GPT_SoVITS/TTS_infer_pack/TTS.py @@ -1365,7 +1365,6 @@ class TTS: all_phoneme_lens, prompt, all_bert_features[0].unsqueeze(0), - # prompt_phone_len=ph_offset, top_k=top_k, top_p=top_p, temperature=temperature, @@ -1375,8 +1374,6 @@ class TTS: streaming_mode=True, chunk_length=chunk_length, mute_emb_sim_matrix=self.configs.mute_emb_sim_matrix, - only_for_the_first_chunk=is_first_package, - limited_chunk_len=True ) t4 = time.perf_counter() t_34 += t4 - t3 @@ -1429,6 +1426,13 @@ class TTS: if not self.configs.use_vocoder: + token_padding_length = 0 + # token_padding_length = int(phones.shape[-1]*2)-_semantic_tokens.shape[-1] + # if token_padding_length>0: + # _semantic_tokens = F.pad(_semantic_tokens, (0, 
token_padding_length), "constant", 486) + # else: + # token_padding_length = 0 + audio_chunk, latent, latent_mask = self.vits_model.decode( _semantic_tokens.unsqueeze(0), phones, refer_audio_spec, @@ -1436,7 +1440,7 @@ class TTS: result_length=semantic_tokens.shape[-1]+overlap_len if not is_first_chunk else None, overlap_frames=last_latent[:,:,-overlap_len*(2 if self.vits_model.semantic_frame_rate == "25hz" else 1):] \ if last_latent is not None else None, - # result_length=chunk_length if not is_first_chunk else None + padding_length=token_padding_length ) audio_chunk=audio_chunk.detach()[0, 0, :] else: diff --git a/GPT_SoVITS/module/models.py b/GPT_SoVITS/module/models.py index 2076960e..6cb317f6 100644 --- a/GPT_SoVITS/module/models.py +++ b/GPT_SoVITS/module/models.py @@ -151,7 +151,7 @@ class DurationPredictor(nn.Module): return x * x_mask -HANN_WINDOW = {} +WINDOW = {} class TextEncoder(nn.Module): def __init__( @@ -211,7 +211,7 @@ class TextEncoder(nn.Module): self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) - def forward(self, y, y_lengths, text, text_lengths, ge, speed=1, test=None, result_length:int=None, overlap_frames:torch.Tensor=None): + def forward(self, y, y_lengths, text, text_lengths, ge, speed=1, test=None, result_length:int=None, overlap_frames:torch.Tensor=None, padding_length:int=None): y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype) y = self.ssl_proj(y * y_mask) * y_mask @@ -224,23 +224,33 @@ class TextEncoder(nn.Module): text = self.text_embedding(text).transpose(1, 2) text = self.encoder_text(text * text_mask, text_mask) y = self.mrte(y, y_mask, text, text_mask, ge) + + if padding_length is not None and padding_length!=0: + y = y[:, :, :-padding_length] + y_mask = y_mask[:, :, :-padding_length] + + y = self.encoder2(y * y_mask, y_mask) if result_length is not None: y = y[:, :, -result_length:] y_mask = y_mask[:, :, -result_length:] - + if overlap_frames is not None: overlap_len = 
overlap_frames.shape[-1] - window = HANN_WINDOW.get(overlap_len, None) + window = WINDOW.get(overlap_len, None) if window is None: - HANN_WINDOW[overlap_len] = torch.hann_window(overlap_len*2, device=y.device, dtype=y.dtype) - window = HANN_WINDOW[overlap_len] + # WINDOW[overlap_len] = torch.hann_window(overlap_len*2, device=y.device, dtype=y.dtype) + WINDOW[overlap_len] = torch.sin(torch.arange(overlap_len*2, device=y.device) * torch.pi / (overlap_len*2)) + window = WINDOW[overlap_len] + + window = window.to(y.device) y[:,:,:overlap_len] = ( window[:overlap_len].view(1, 1, -1) * y[:,:,:overlap_len] + window[overlap_len:].view(1, 1, -1) * overlap_frames ) + y_ = y y_mask_ = y_mask @@ -981,7 +991,7 @@ class SynthesizerTrn(nn.Module): return o, y_mask, (z, z_p, m_p, logs_p) @torch.no_grad() - def decode(self, codes, text, refer, noise_scale=0.5, speed=1, sv_emb=None, result_length:int=None, overlap_frames:torch.Tensor=None): + def decode(self, codes, text, refer, noise_scale=0.5, speed=1, sv_emb=None, result_length:int=None, overlap_frames:torch.Tensor=None, padding_length:int=None): def get_ge(refer, sv_emb): ge = None if refer is not None: @@ -1013,6 +1023,7 @@ class SynthesizerTrn(nn.Module): if self.semantic_frame_rate == "25hz": quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest") result_length = (2*result_length) if result_length is not None else None + padding_length = (2*padding_length) if padding_length is not None else None x, m_p, logs_p, y_mask, y_, y_mask_ = self.enc_p( quantized, y_lengths, @@ -1020,7 +1031,7 @@ class SynthesizerTrn(nn.Module): text_lengths, self.ge_to512(ge.transpose(2, 1)).transpose(2, 1) if self.is_v2pro else ge, speed, - , result_length=result_length, overlap_frames=overlap_frames) + , result_length=result_length, overlap_frames=overlap_frames, padding_length=padding_length) z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale z = self.flow(z_p, y_mask, g=ge, reverse=True) From 
08d6ed0d8c5063822f23e3ed191ca8befeed8035 Mon Sep 17 00:00:00 2001 From: ChasonJiang <1440499136@qq.com> Date: Mon, 24 Nov 2025 20:47:32 +0800 Subject: [PATCH 06/10] modified: GPT_SoVITS/AR/models/t2s_model.py modified: GPT_SoVITS/TTS_infer_pack/TTS.py modified: GPT_SoVITS/module/models.py modified: api_v2.py --- GPT_SoVITS/AR/models/t2s_model.py | 6 +- GPT_SoVITS/TTS_infer_pack/TTS.py | 112 +++++++++++++----------------- GPT_SoVITS/module/models.py | 70 ++++++++++++++++--- api_v2.py | 12 ++-- 4 files changed, 118 insertions(+), 82 deletions(-) diff --git a/GPT_SoVITS/AR/models/t2s_model.py b/GPT_SoVITS/AR/models/t2s_model.py index ae1fcc3c..d72aa393 100644 --- a/GPT_SoVITS/AR/models/t2s_model.py +++ b/GPT_SoVITS/AR/models/t2s_model.py @@ -827,7 +827,7 @@ class Text2SemanticDecoder(nn.Module): **kwargs, ): mute_emb_sim_matrix = kwargs.get("mute_emb_sim_matrix", None) - sim_thershold = kwargs.get("sim_thershold", 0.3) + chunk_split_thershold = kwargs.get("chunk_split_thershold", 0.3) check_token_num = 2 @@ -927,9 +927,9 @@ class Text2SemanticDecoder(nn.Module): if streaming_mode and (mute_emb_sim_matrix is not None) and (token_counter >= chunk_length+check_token_num): - score = mute_emb_sim_matrix[y[0, curr_ptr:]] - sim_thershold + score = mute_emb_sim_matrix[y[0, curr_ptr:]] - chunk_split_thershold score[score<0]=-1 - score[:-1]=score[:-1]+score[1:] + score[:-1]=score[:-1]+score[1:] ##考虑连续两个token argmax_idx = score.argmax() if score[argmax_idx]>=0 and argmax_idx+1>=chunk_length: diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py index 2f097f6c..a0c0d6ba 100644 --- a/GPT_SoVITS/TTS_infer_pack/TTS.py +++ b/GPT_SoVITS/TTS_infer_pack/TTS.py @@ -278,6 +278,8 @@ class TTS_Config: mute_tokens: dict = { "v1" : 486, "v2" : 486, + "v2Pro": 486, + "v2ProPlus": 486, "v3" : 486, "v4" : 486, } @@ -1009,7 +1011,7 @@ class TTS: "top_k": 5, # int. top k sampling "top_p": 1, # float. top p sampling "temperature": 1, # float. 
temperature for sampling - "text_split_method": "cut0", # str. text split method, see text_segmentation_method.py for details. + "text_split_method": "cut1", # str. text split method, see text_segmentation_method.py for details. "batch_size": 1, # int. batch size for inference "batch_threshold": 0.75, # float. threshold for batch splitting. "split_bucket: True, # bool. whether to split the batch into multiple buckets. @@ -1023,7 +1025,7 @@ class TTS: "super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3. "streaming_mode": False, # bool. return audio chunk by chunk. "overlap_length": 2, # int. overlap length of semantic tokens for streaming mode. - "chunk_length: 24, # int. chunk length of semantic tokens for streaming mode. (affects audio chunk size) + "min_chunk_length: 16, # int. The minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size) } returns: Tuple[int, np.ndarray]: sampling rate and audio data. @@ -1039,7 +1041,7 @@ class TTS: top_k: int = inputs.get("top_k", 5) top_p: float = inputs.get("top_p", 1) temperature: float = inputs.get("temperature", 1) - text_split_method: str = inputs.get("text_split_method", "cut0") + text_split_method: str = inputs.get("text_split_method", "cut1") batch_size = inputs.get("batch_size", 1) batch_threshold = inputs.get("batch_threshold", 0.75) speed_factor = inputs.get("speed_factor", 1.0) @@ -1055,7 +1057,8 @@ class TTS: super_sampling = inputs.get("super_sampling", False) streaming_mode = inputs.get("streaming_mode", False) overlap_length = inputs.get("overlap_length", 2) - chunk_length = inputs.get("chunk_length", 24) + min_chunk_length = inputs.get("min_chunk_length", 16) + chunk_split_thershold = 0.0 # 该值代表语义token与mute token的余弦相似度阈值,若大于该阈值,则视为可切分点。 if parallel_infer and not streaming_mode: print(i18n("并行推理模式已开启")) @@ -1249,6 +1252,15 @@ class TTS: self.prompt_cache["prompt_semantic"].expand(len(all_phoneme_ids), -1).to(self.configs.device) ) + 
refer_audio_spec = [] + + sv_emb = [] if self.is_v2pro else None + for spec, audio_tensor in self.prompt_cache["refer_spec"]: + spec = spec.to(dtype=self.precision, device=self.configs.device) + refer_audio_spec.append(spec) + if self.is_v2pro: + sv_emb.append(self.sv_model.compute_embedding3(audio_tensor)) + if not streaming_mode: print(f"############ {i18n('预测语义Token')} ############") pred_semantic_list, idx_list = self.t2s_model.model.infer_panel( @@ -1267,14 +1279,6 @@ class TTS: t4 = time.perf_counter() t_34 += t4 - t3 - refer_audio_spec = [] - if self.is_v2pro: - sv_emb = [] - for spec, audio_tensor in self.prompt_cache["refer_spec"]: - spec = spec.to(dtype=self.precision, device=self.configs.device) - refer_audio_spec.append(spec) - if self.is_v2pro: - sv_emb.append(self.sv_model.compute_embedding3(audio_tensor)) batch_audio_fragment = [] @@ -1293,7 +1297,7 @@ class TTS: print(f"############ {i18n('合成音频')} ############") if not self.configs.use_vocoder: if speed_factor == 1.0: - print(f"{i18n('并行合成中')}...") + print(f"{i18n('合成中')}...") # ## vits并行推理 method 2 pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)] upsample_rate = math.prod(self.vits_model.upsample_rates) @@ -1306,15 +1310,11 @@ class TTS: torch.cat(pred_semantic_list).unsqueeze(0).unsqueeze(0).to(self.configs.device) ) _batch_phones = torch.cat(batch_phones).unsqueeze(0).to(self.configs.device) - if self.is_v2pro != True: - _batch_audio_fragment, _, _ = self.vits_model.decode( - all_pred_semantic, _batch_phones, refer_audio_spec, speed=speed_factor - ).detach()[0, 0, :] - else: - _batch_audio_fragment = self.vits_model.decode( + + _batch_audio_fragment = self.vits_model.decode( all_pred_semantic, _batch_phones, refer_audio_spec, speed=speed_factor, sv_emb=sv_emb - ) - _batch_audio_fragment = _batch_audio_fragment.detach()[0, 0, :] + ).detach()[0, 0, :] + audio_frag_end_idx.insert(0, 0) batch_audio_fragment = [ _batch_audio_fragment[audio_frag_end_idx[i - 1] : 
audio_frag_end_idx[i]] @@ -1327,20 +1327,14 @@ class TTS: _pred_semantic = ( pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0) ) # .unsqueeze(0)#mq要多unsqueeze一次 - if self.is_v2pro != True: - audio_fragment, _, _ = self.vits_model.decode( - _pred_semantic, phones, refer_audio_spec, speed=speed_factor - ).detach()[0, 0, :] - else: - audio_fragment = self.vits_model.decode( + audio_fragment = self.vits_model.decode( _pred_semantic, phones, refer_audio_spec, speed=speed_factor, sv_emb=sv_emb - ) - audio_fragment=audio_fragment.detach()[0, 0, :] + ).detach()[0, 0, :] batch_audio_fragment.append(audio_fragment) ###试试重建不带上prompt部分 else: if parallel_infer: print(f"{i18n('并行合成中')}...") - audio_fragments, y, y_mask = self.using_vocoder_synthesis_batched_infer( + audio_fragments = self.using_vocoder_synthesis_batched_infer( idx_list, pred_semantic_list, batch_phones, speed=speed_factor, sample_steps=sample_steps ) batch_audio_fragment.extend(audio_fragments) @@ -1350,16 +1344,16 @@ class TTS: _pred_semantic = ( pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0) ) # .unsqueeze(0)#mq要多unsqueeze一次 - audio_fragment, y, y_mask = self.using_vocoder_synthesis( + audio_fragment = self.using_vocoder_synthesis( _pred_semantic, phones, speed=speed_factor, sample_steps=sample_steps ) batch_audio_fragment.append(audio_fragment) else: - refer_audio_spec: torch.Tensor = [ - item.to(dtype=self.precision, device=self.configs.device) - for item in self.prompt_cache["refer_spec"] - ] + # refer_audio_spec: torch.Tensor = [ + # item.to(dtype=self.precision, device=self.configs.device) + # for item in self.prompt_cache["refer_spec"] + # ] semantic_token_generator =self.t2s_model.model.infer_panel( all_phoneme_ids[0].unsqueeze(0), all_phoneme_lens, @@ -1372,8 +1366,9 @@ class TTS: max_len=max_len, repetition_penalty=repetition_penalty, streaming_mode=True, - chunk_length=chunk_length, + chunk_length=min_chunk_length, mute_emb_sim_matrix=self.configs.mute_emb_sim_matrix, + 
chunk_split_thershold=chunk_split_thershold, ) t4 = time.perf_counter() t_34 += t4 - t3 @@ -1381,15 +1376,15 @@ class TTS: is_first_chunk = True if not self.configs.use_vocoder: - if speed_factor == 1.0: - upsample_rate = math.prod(self.vits_model.upsample_rates)*(2 if self.vits_model.semantic_frame_rate == "25hz" else 1) - else: - upsample_rate = math.prod(self.vits_model.upsample_rates)*((2 if self.vits_model.semantic_frame_rate == "25hz" else 1)/speed_factor) + # if speed_factor == 1.0: + # upsample_rate = math.prod(self.vits_model.upsample_rates)*(2 if self.vits_model.semantic_frame_rate == "25hz" else 1) + # else: + upsample_rate = math.prod(self.vits_model.upsample_rates)*((2 if self.vits_model.semantic_frame_rate == "25hz" else 1)/speed_factor) else: - if speed_factor == 1.0: - upsample_rate = self.vocoder_configs["upsample_rate"]*(3.875 if self.configs.version == "v3" else 4) - else: - upsample_rate = self.vocoder_configs["upsample_rate"]*((3.875 if self.configs.version == "v3" else 4)/speed_factor) + # if speed_factor == 1.0: + # upsample_rate = self.vocoder_configs["upsample_rate"]*(3.875 if self.configs.version == "v3" else 4) + # else: + upsample_rate = self.vocoder_configs["upsample_rate"]*((3.875 if self.configs.version == "v3" else 4)/speed_factor) last_audio_chunk = None # last_tokens = None @@ -1419,10 +1414,8 @@ class TTS: if not is_first_chunk and semantic_tokens.shape[-1] < 10: overlap_len = overlap_length+(10-semantic_tokens.shape[-1]) - # overlap_size = math.ceil(overlap_len*upsample_rate) else: overlap_len = overlap_length - # overlap_size = math.ceil(overlap_length*upsample_rate) if not self.configs.use_vocoder: @@ -1433,10 +1426,11 @@ class TTS: # else: # token_padding_length = 0 - audio_chunk, latent, latent_mask = self.vits_model.decode( + audio_chunk, latent, latent_mask = self.vits_model.decode_steaming( _semantic_tokens.unsqueeze(0), phones, refer_audio_spec, speed=speed_factor, + sv_emb=sv_emb, 
result_length=semantic_tokens.shape[-1]+overlap_len if not is_first_chunk else None, overlap_frames=last_latent[:,:,-overlap_len*(2 if self.vits_model.semantic_frame_rate == "25hz" else 1):] \ if last_latent is not None else None, @@ -1444,13 +1438,7 @@ class TTS: ) audio_chunk=audio_chunk.detach()[0, 0, :] else: - audio_chunk, latent, latent_mask = self.using_vocoder_synthesis( - _semantic_tokens.unsqueeze(0), phones, - speed=speed_factor, sample_steps=sample_steps, - result_length = semantic_tokens.shape[-1]+overlap_len if not is_first_chunk else None, - overlap_frames=last_latent[:,:,-overlap_len*(2 if self.vits_model.semantic_frame_rate == "25hz" else 1):] \ - if last_latent is not None else None, - ) + raise RuntimeError(i18n("SoVits V3/4模型不支持流式推理模式")) if overlap_len>overlap_length: audio_chunk=audio_chunk[-int((overlap_length+semantic_tokens.shape[-1])*upsample_rate):] @@ -1614,7 +1602,7 @@ class TTS: return sr, audio def using_vocoder_synthesis( - self, semantic_tokens: torch.Tensor, phones: torch.Tensor, speed: float = 1.0, sample_steps: int = 32, result_length:int=None, overlap_frames:torch.Tensor=None + self, semantic_tokens: torch.Tensor, phones: torch.Tensor, speed: float = 1.0, sample_steps: int = 32 ): prompt_semantic_tokens = self.prompt_cache["prompt_semantic"].unsqueeze(0).unsqueeze(0).to(self.configs.device) prompt_phones = torch.LongTensor(self.prompt_cache["phones"]).unsqueeze(0).to(self.configs.device) @@ -1623,7 +1611,7 @@ class TTS: raw_entry = raw_entry[0] refer_audio_spec = raw_entry.to(dtype=self.precision, device=self.configs.device) - fea_ref, ge, y, y_mask = self.vits_model.decode_encp(prompt_semantic_tokens, prompt_phones, refer_audio_spec) + fea_ref, ge = self.vits_model.decode_encp(prompt_semantic_tokens, prompt_phones, refer_audio_spec) ref_audio: torch.Tensor = self.prompt_cache["raw_audio"] ref_sr = self.prompt_cache["raw_sr"] ref_audio = ref_audio.to(self.configs.device).float() @@ -1649,7 +1637,7 @@ class TTS: chunk_len = 
T_chunk - T_min mel2 = mel2.to(self.precision) - fea_todo, ge, y, y_mask = self.vits_model.decode_encp(semantic_tokens, phones, refer_audio_spec, ge, speed, result_length=result_length, overlap_frames=overlap_frames) + fea_todo, ge = self.vits_model.decode_encp(semantic_tokens, phones, refer_audio_spec, ge, speed) cfm_resss = [] idx = 0 @@ -1672,11 +1660,11 @@ class TTS: cfm_res = torch.cat(cfm_resss, 2) cfm_res = denorm_spec(cfm_res) - with torch.no_grad(): + with torch.inference_mode(): wav_gen = self.vocoder(cfm_res) audio = wav_gen[0][0] # .cpu().detach().numpy() - return audio, y, y_mask + return audio def using_vocoder_synthesis_batched_infer( self, @@ -1693,7 +1681,7 @@ class TTS: raw_entry = raw_entry[0] refer_audio_spec = raw_entry.to(dtype=self.precision, device=self.configs.device) - fea_ref, ge, y, y_mask = self.vits_model.decode_encp(prompt_semantic_tokens, prompt_phones, refer_audio_spec) + fea_ref, ge = self.vits_model.decode_encp(prompt_semantic_tokens, prompt_phones, refer_audio_spec) ref_audio: torch.Tensor = self.prompt_cache["raw_audio"] ref_sr = self.prompt_cache["raw_sr"] ref_audio = ref_audio.to(self.configs.device).float() @@ -1731,7 +1719,7 @@ class TTS: semantic_tokens = ( semantic_tokens_list[i][-idx:].unsqueeze(0).unsqueeze(0) ) # .unsqueeze(0)#mq要多unsqueeze一次 - feat, _, y, y_mask = self.vits_model.decode_encp(semantic_tokens, phones, refer_audio_spec, ge, speed) + feat, _ = self.vits_model.decode_encp(semantic_tokens, phones, refer_audio_spec, ge, speed) feat_list.append(feat) feat_lens.append(feat.shape[2]) @@ -1791,7 +1779,7 @@ class TTS: audio_fragments.append(audio_fragment) audio = audio[feat_len * upsample_rate :] - return audio_fragments, y, y_mask + return audio_fragments def sola_algorithm( self, diff --git a/GPT_SoVITS/module/models.py b/GPT_SoVITS/module/models.py index 6cb317f6..5049017f 100644 --- a/GPT_SoVITS/module/models.py +++ b/GPT_SoVITS/module/models.py @@ -990,8 +990,57 @@ class SynthesizerTrn(nn.Module): o = 
self.dec((z * y_mask)[:, :, :], g=ge) return o, y_mask, (z, z_p, m_p, logs_p) + @torch.no_grad() - def decode(self, codes, text, refer, noise_scale=0.5, speed=1, sv_emb=None, result_length:int=None, overlap_frames:torch.Tensor=None, padding_length:int=None): + def decode(self, codes, text, refer, noise_scale=0.5, speed=1, sv_emb=None): + def get_ge(refer, sv_emb): + ge = None + if refer is not None: + refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device) + refer_mask = torch.unsqueeze(commons.sequence_mask(refer_lengths, refer.size(2)), 1).to(refer.dtype) + if self.version == "v1": + ge = self.ref_enc(refer * refer_mask, refer_mask) + else: + ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask) + if self.is_v2pro: + sv_emb = self.sv_emb(sv_emb) # B*20480->B*512 + ge += sv_emb.unsqueeze(-1) + ge = self.prelu(ge) + return ge + + if type(refer) == list: + ges = [] + for idx, _refer in enumerate(refer): + ge = get_ge(_refer, sv_emb[idx] if self.is_v2pro else None) + ges.append(ge) + ge = torch.stack(ges, 0).mean(0) + else: + ge = get_ge(refer, sv_emb) + + y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device) + text_lengths = torch.LongTensor([text.size(-1)]).to(text.device) + + quantized = self.quantizer.decode(codes) + if self.semantic_frame_rate == "25hz": + quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest") + x, m_p, logs_p, y_mask, _, _ = self.enc_p( + quantized, + y_lengths, + text, + text_lengths, + self.ge_to512(ge.transpose(2, 1)).transpose(2, 1) if self.is_v2pro else ge, + speed, + ) + z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale + + z = self.flow(z_p, y_mask, g=ge, reverse=True) + + o = self.dec((z * y_mask)[:, :, :], g=ge) + return o + + + @torch.no_grad() + def decode_steaming(self, codes, text, refer, noise_scale=0.5, speed=1, sv_emb=None, result_length:int=None, overlap_frames:torch.Tensor=None, padding_length:int=None): def get_ge(refer, sv_emb): ge = None if 
refer is not None: @@ -1031,7 +1080,10 @@ class SynthesizerTrn(nn.Module): text_lengths, self.ge_to512(ge.transpose(2, 1)).transpose(2, 1) if self.is_v2pro else ge, speed, - , result_length=result_length, overlap_frames=overlap_frames, padding_length=padding_length) + result_length=result_length, + overlap_frames=overlap_frames, + padding_length=padding_length + ) z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale z = self.flow(z_p, y_mask, g=ge, reverse=True) @@ -1277,7 +1329,7 @@ class SynthesizerTrnV3(nn.Module): return cfm_loss @torch.no_grad() - def decode_encp(self, codes, text, refer, ge=None, speed=1, result_length:int=None, overlap_frames:torch.Tensor=None): + def decode_encp(self, codes, text, refer, ge=None, speed=1): # print(2333333,refer.shape) # ge=None if ge == None: @@ -1285,26 +1337,22 @@ class SynthesizerTrnV3(nn.Module): refer_mask = torch.unsqueeze(commons.sequence_mask(refer_lengths, refer.size(2)), 1).to(refer.dtype) ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask) y_lengths = torch.LongTensor([int(codes.size(2) * 2)]).to(codes.device) - if speed == 1: - sizee = int((codes.size(2) if result_length is None else result_length) * (3.875 if self.version=="v3"else 4)) + sizee = int(codes.size(2) * (3.875 if self.version == "v3" else 4)) else: - sizee = int((codes.size(2) if result_length is None else result_length) * (3.875 if self.version=="v3"else 4) / speed) + 1 + sizee = int(codes.size(2) * (3.875 if self.version == "v3" else 4) / speed) + 1 y_lengths1 = torch.LongTensor([sizee]).to(codes.device) text_lengths = torch.LongTensor([text.size(-1)]).to(text.device) quantized = self.quantizer.decode(codes) if self.semantic_frame_rate == "25hz": quantized = F.interpolate(quantized, scale_factor=2, mode="nearest") ##BCT - result_length = result_length * 2 if result_length is not None else None - - - x, m_p, logs_p, y_mask, y_, y_mask_ = self.enc_p(quantized, y_lengths, text, text_lengths, ge, speed, 
result_length=result_length, overlap_frames=overlap_frames) + x, m_p, logs_p, y_mask, _, _ = self.enc_p(quantized, y_lengths, text, text_lengths, ge, speed) fea = self.bridge(x) fea = F.interpolate(fea, scale_factor=(1.875 if self.version == "v3" else 2), mode="nearest") ##BCT ####more wn paramter to learn mel fea, y_mask_ = self.wns1(fea, y_lengths1, ge) - return fea, ge, y_, y_mask_ + return fea, ge def extract_latent(self, x): ssl = self.ssl_proj(x) diff --git a/api_v2.py b/api_v2.py index 6dcec851..7aeb5c16 100644 --- a/api_v2.py +++ b/api_v2.py @@ -30,7 +30,7 @@ POST: "top_k": 5, # int. top k sampling "top_p": 1, # float. top p sampling "temperature": 1, # float. temperature for sampling - "text_split_method": "cut0", # str. text split method, see text_segmentation_method.py for details. + "text_split_method": "cut5", # str. text split method, see text_segmentation_method.py for details. "batch_size": 1, # int. batch size for inference "batch_threshold": 0.75, # float. threshold for batch splitting. "split_bucket": True, # bool. whether to split the batch into multiple buckets. @@ -42,7 +42,7 @@ POST: "sample_steps": 32, # int. number of sampling steps for VITS model V3. "super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3. "overlap_length": 2, # int. overlap length of semantic tokens for streaming mode. - "chunk_length: 24, # int. chunk length of semantic tokens for streaming mode. (affects audio chunk size) + "min_chunk_length: 16, # int. The minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size) "return_fragment": False, # bool. step by step return the audio fragment. 
(old version of streaming mode) } ``` @@ -174,7 +174,7 @@ class TTS_Request(BaseModel): sample_steps: int = 32 super_sampling: bool = False overlap_length: int = 2 - chunk_length: int = 24 + min_chunk_length: int = 16 return_fragment: bool = False @@ -333,7 +333,7 @@ async def tts_handle(req: dict): "sample_steps": 32, # int. number of sampling steps for VITS model V3. "super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3. "overlap_length": 2, # int. overlap length of semantic tokens for streaming mode. - "chunk_length: 24, # int. chunk length of semantic tokens for streaming mode. (affects audio chunk size) + "min_chunk_length: 16, # int. The minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size) "return_fragment": False, # bool. step by step return the audio fragment. (old version of streaming mode) } returns: @@ -416,7 +416,7 @@ async def tts_get_endpoint( sample_steps: int = 32, super_sampling: bool = False, overlap_length: int = 2, - chunk_length: int = 24, + min_chunk_length: int = 16, return_fragment: bool = False, ): req = { @@ -443,7 +443,7 @@ async def tts_get_endpoint( "sample_steps": int(sample_steps), "super_sampling": super_sampling, "overlap_length": int(overlap_length), - "chunk_length": int(chunk_length), + "min_chunk_length": int(min_chunk_length), "return_fragment": return_fragment, } return await tts_handle(req) From fcdd15460d12a9e3059d98837b9fce5c57814601 Mon Sep 17 00:00:00 2001 From: ChasonJiang <1440499136@qq.com> Date: Mon, 24 Nov 2025 20:52:56 +0800 Subject: [PATCH 07/10] modified: GPT_SoVITS/TTS_infer_pack/TTS.py --- GPT_SoVITS/TTS_infer_pack/TTS.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py index a0c0d6ba..1c5897a2 100644 --- a/GPT_SoVITS/TTS_infer_pack/TTS.py +++ b/GPT_SoVITS/TTS_infer_pack/TTS.py @@ -1297,7 +1297,7 @@ class TTS: print(f"############ {i18n('合成音频')} 
############") if not self.configs.use_vocoder: if speed_factor == 1.0: - print(f"{i18n('合成中')}...") + print(f"{i18n('并行合成中')}...") # ## vits并行推理 method 2 pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)] upsample_rate = math.prod(self.vits_model.upsample_rates) From 9b147cd24aa729c8257b94514c58897c312c01e9 Mon Sep 17 00:00:00 2001 From: ChasonJiang <1440499136@qq.com> Date: Mon, 24 Nov 2025 23:07:38 +0800 Subject: [PATCH 08/10] =?UTF-8?q?=E6=9B=B4=E6=AD=A3=E6=8B=BC=E5=86=99?= =?UTF-8?q?=E9=94=99=E8=AF=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- GPT_SoVITS/TTS_infer_pack/TTS.py | 2 +- GPT_SoVITS/module/models.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py index 1c5897a2..1d25e30a 100644 --- a/GPT_SoVITS/TTS_infer_pack/TTS.py +++ b/GPT_SoVITS/TTS_infer_pack/TTS.py @@ -1426,7 +1426,7 @@ class TTS: # else: # token_padding_length = 0 - audio_chunk, latent, latent_mask = self.vits_model.decode_steaming( + audio_chunk, latent, latent_mask = self.vits_model.decode_streaming( _semantic_tokens.unsqueeze(0), phones, refer_audio_spec, speed=speed_factor, diff --git a/GPT_SoVITS/module/models.py b/GPT_SoVITS/module/models.py index 5049017f..348ddb3f 100644 --- a/GPT_SoVITS/module/models.py +++ b/GPT_SoVITS/module/models.py @@ -1040,7 +1040,7 @@ class SynthesizerTrn(nn.Module): @torch.no_grad() - def decode_steaming(self, codes, text, refer, noise_scale=0.5, speed=1, sv_emb=None, result_length:int=None, overlap_frames:torch.Tensor=None, padding_length:int=None): + def decode_streaming(self, codes, text, refer, noise_scale=0.5, speed=1, sv_emb=None, result_length:int=None, overlap_frames:torch.Tensor=None, padding_length:int=None): def get_ge(refer, sv_emb): ge = None if refer is not None: From 6bce575d69ef6d415de4d5e17b2c9e1c9df05545 Mon Sep 17 00:00:00 2001 From: ChasonJiang <1440499136@qq.com> 
Date: Wed, 26 Nov 2025 14:41:42 +0800 Subject: [PATCH 09/10] =?UTF-8?q?=E6=94=AF=E6=8C=81=E5=9B=BA=E5=AE=9Achunk?= =?UTF-8?q?=E9=95=BF=E5=BA=A6=E7=9A=84=E6=B5=81=E5=BC=8F=E6=8E=A8=E7=90=86?= =?UTF-8?q?=EF=BC=8C=E4=BC=98=E5=8C=96sola=E7=AE=97=E6=B3=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- GPT_SoVITS/AR/models/t2s_model.py | 3 +- GPT_SoVITS/TTS_infer_pack/TTS.py | 56 ++++++++++++++++--------------- api_v2.py | 33 ++++++++++-------- 3 files changed, 50 insertions(+), 42 deletions(-) diff --git a/GPT_SoVITS/AR/models/t2s_model.py b/GPT_SoVITS/AR/models/t2s_model.py index d72aa393..0caadd04 100644 --- a/GPT_SoVITS/AR/models/t2s_model.py +++ b/GPT_SoVITS/AR/models/t2s_model.py @@ -940,9 +940,10 @@ class Text2SemanticDecoder(nn.Module): elif streaming_mode and (mute_emb_sim_matrix is None) and (token_counter >= chunk_length): - token_counter == chunk_length yield y[:, -token_counter:], False + curr_ptr+=token_counter token_counter = 0 + ####################### update next step ################################### diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py index 1d25e30a..be3d3a19 100644 --- a/GPT_SoVITS/TTS_infer_pack/TTS.py +++ b/GPT_SoVITS/TTS_infer_pack/TTS.py @@ -1014,18 +1014,19 @@ class TTS: "text_split_method": "cut1", # str. text split method, see text_segmentation_method.py for details. "batch_size": 1, # int. batch size for inference "batch_threshold": 0.75, # float. threshold for batch splitting. - "split_bucket: True, # bool. whether to split the batch into multiple buckets. - "return_fragment": False, # bool. step by step return the audio fragment. + "split_bucket": True, # bool. whether to split the batch into multiple buckets. "speed_factor":1.0, # float. control the speed of the synthesized audio. "fragment_interval":0.3, # float. to control the interval of the audio fragment. "seed": -1, # int. random seed for reproducibility. "parallel_infer": True, # bool. 
whether to use parallel inference. - "repetition_penalty": 1.35 # float. repetition penalty for T2S model. + "repetition_penalty": 1.35, # float. repetition penalty for T2S model. "sample_steps": 32, # int. number of sampling steps for VITS model V3. "super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3. - "streaming_mode": False, # bool. return audio chunk by chunk. + "return_fragment": False, # bool. step by step return the audio fragment. (Best Quality, Slowest response speed. old version of streaming mode) + "streaming_mode": False, # bool. return audio chunk by chunk. (Medium quality, Slow response speed) "overlap_length": 2, # int. overlap length of semantic tokens for streaming mode. - "min_chunk_length: 16, # int. The minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size) + "min_chunk_length": 16, # int. The minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size) + "fixed_length_chunk": False, # bool. When turned on, it can achieve faster streaming response, but with lower quality. (lower quality, faster response speed) } returns: Tuple[int, np.ndarray]: sampling rate and audio data. 
@@ -1058,6 +1059,7 @@ class TTS: streaming_mode = inputs.get("streaming_mode", False) overlap_length = inputs.get("overlap_length", 2) min_chunk_length = inputs.get("min_chunk_length", 16) + fixed_length_chunk = inputs.get("fixed_length_chunk", False) chunk_split_thershold = 0.0 # 该值代表语义token与mute token的余弦相似度阈值,若大于该阈值,则视为可切分点。 if parallel_infer and not streaming_mode: @@ -1367,7 +1369,7 @@ class TTS: repetition_penalty=repetition_penalty, streaming_mode=True, chunk_length=min_chunk_length, - mute_emb_sim_matrix=self.configs.mute_emb_sim_matrix, + mute_emb_sim_matrix=self.configs.mute_emb_sim_matrix if not fixed_length_chunk else None, chunk_split_thershold=chunk_split_thershold, ) t4 = time.perf_counter() @@ -1456,11 +1458,6 @@ class TTS: else audio_chunk_[last_audio_chunk.shape[0]-overlap_size:] ) - # audio_chunk_ = ( - # audio_chunk_[overlap_size:-overlap_size] if not is_final \ - # else audio_chunk_[overlap_size:] - # ) - last_latent = latent last_audio_chunk = audio_chunk yield self.audio_postprocess( @@ -1785,30 +1782,35 @@ class TTS: self, audio_fragments: List[torch.Tensor], overlap_len: int, + search_len:int= 320 ): - for i in range(len(audio_fragments) - 1): - f1 = audio_fragments[i] - f2 = audio_fragments[i + 1] - w1 = f1[-overlap_len:] - w2 = f2[:overlap_len] - w2 = w2[-w2.shape[-1]//2:] - # assert w1.shape == w2.shape - corr = F.conv1d(w1.view(1, 1, -1), w2.view(1, 1, -1)).view(-1) + # overlap_len-=search_len - squared_sum = F.conv1d(w1.view(1, 1, -1)**2, torch.ones_like(w2).view(1, 1, -1)).view(-1)+ 1e-8 - idx = (corr/squared_sum.sqrt()).argmax() + dtype = audio_fragments[0].dtype + + for i in range(len(audio_fragments) - 1): + f1 = audio_fragments[i].float() + f2 = audio_fragments[i + 1].float() + w1 = f1[-overlap_len:] + w2 = f2[:overlap_len+search_len] + # w2 = w2[-w2.shape[-1]//2:] + # assert w1.shape == w2.shape + corr_norm = F.conv1d(w2.view(1, 1, -1), w1.view(1, 1, -1)).view(-1) + + corr_den = F.conv1d(w2.view(1, 1, -1)**2, 
torch.ones_like(w1).view(1, 1, -1)).view(-1)+ 1e-8 + idx = (corr_norm/corr_den.sqrt()).argmax() print(f"seg_idx: {idx}") # idx = corr.argmax() - f1_ = f1[: -(overlap_len - idx)] + f1_ = f1[: -overlap_len] audio_fragments[i] = f1_ f2_ = f2[idx:] - window = torch.hann_window((overlap_len - idx) * 2, device=f1.device, dtype=f1.dtype) - f2_[: (overlap_len - idx)] = ( - window[: (overlap_len - idx)] * f2_[: (overlap_len - idx)] - + window[(overlap_len - idx) :] * f1[-(overlap_len - idx) :] + window = torch.hann_window((overlap_len) * 2, device=f1.device, dtype=f1.dtype) + f2_[: overlap_len] = ( + window[: overlap_len] * f2_[: overlap_len] + + window[overlap_len :] * f1[-overlap_len :] ) # window = torch.sin(torch.arange((overlap_len - idx), device=f1.device) * np.pi / (overlap_len - idx)) @@ -1819,4 +1821,4 @@ class TTS: audio_fragments[i + 1] = f2_ - return torch.cat(audio_fragments, 0) + return torch.cat(audio_fragments, 0).to(dtype) diff --git a/api_v2.py b/api_v2.py index 7aeb5c16..21e9c0c5 100644 --- a/api_v2.py +++ b/api_v2.py @@ -35,15 +35,17 @@ POST: "batch_threshold": 0.75, # float. threshold for batch splitting. "split_bucket": True, # bool. whether to split the batch into multiple buckets. "speed_factor":1.0, # float. control the speed of the synthesized audio. - "streaming_mode": False, # bool. whether to return a streaming response. + "fragment_interval":0.3, # float. to control the interval of the audio fragment. "seed": -1, # int. random seed for reproducibility. "parallel_infer": True, # bool. whether to use parallel inference. "repetition_penalty": 1.35, # float. repetition penalty for T2S model. "sample_steps": 32, # int. number of sampling steps for VITS model V3. "super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3. + "return_fragment": False, # bool. step by step return the audio fragment. (Best Quality, Slowest response speed. old version of streaming mode) + "streaming_mode": False, # bool. 
return audio chunk by chunk. (Medium quality, Slow response speed) "overlap_length": 2, # int. overlap length of semantic tokens for streaming mode. - "min_chunk_length: 16, # int. The minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size) - "return_fragment": False, # bool. step by step return the audio fragment. (old version of streaming mode) + "min_chunk_length": 16, # int. The minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size) + "fixed_length_chunk": False, # bool. When turned on, it can achieve faster streaming response, but with lower quality. (lower quality, faster response speed) } ``` @@ -176,6 +178,7 @@ class TTS_Request(BaseModel): overlap_length: int = 2 min_chunk_length: int = 16 return_fragment: bool = False + fixed_length_chunk: bool = False ### modify from https://github.com/RVC-Boss/GPT-SoVITS/pull/894/files @@ -313,7 +316,7 @@ async def tts_handle(req: dict): "text": "", # str.(required) text to be synthesized "text_lang: "", # str.(required) language of the text to be synthesized "ref_audio_path": "", # str.(required) reference audio path - "aux_ref_audio_paths": [], # list.(optional) auxiliary reference audio paths for multi-speaker synthesis + "aux_ref_audio_paths": [], # list.(optional) auxiliary reference audio paths for multi-speaker tone fusion "prompt_text": "", # str.(optional) prompt text for the reference audio "prompt_lang": "", # str.(required) language of the prompt text for the reference audio "top_k": 5, # int. top k sampling @@ -322,19 +325,19 @@ async def tts_handle(req: dict): "text_split_method": "cut5", # str. text split method, see text_segmentation_method.py for details. "batch_size": 1, # int. batch size for inference "batch_threshold": 0.75, # float. threshold for batch splitting. - "split_bucket: True, # bool. whether to split the batch into multiple buckets. + "split_bucket": True, # bool. whether to split the batch into multiple buckets. 
"speed_factor":1.0, # float. control the speed of the synthesized audio. "fragment_interval":0.3, # float. to control the interval of the audio fragment. "seed": -1, # int. random seed for reproducibility. - "media_type": "wav", # str. media type of the output audio, support "wav", "raw", "ogg", "aac". - "streaming_mode": False, # bool. whether to return a streaming response. - "parallel_infer": True, # bool.(optional) whether to use parallel inference. - "repetition_penalty": 1.35 # float.(optional) repetition penalty for T2S model. + "parallel_infer": True, # bool. whether to use parallel inference. + "repetition_penalty": 1.35, # float. repetition penalty for T2S model. "sample_steps": 32, # int. number of sampling steps for VITS model V3. "super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3. + "return_fragment": False, # bool. step by step return the audio fragment. (Best Quality, Slowest response speed. old version of streaming mode) + "streaming_mode": False, # bool. return audio chunk by chunk. (Medium quality, Slow response speed) "overlap_length": 2, # int. overlap length of semantic tokens for streaming mode. - "min_chunk_length: 16, # int. The minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size) - "return_fragment": False, # bool. step by step return the audio fragment. (old version of streaming mode) + "min_chunk_length": 16, # int. The minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size) + "fixed_length_chunk": False, # bool. When turned on, it can achieve faster streaming response, but with lower quality. (lower quality, faster response speed) } returns: StreamingResponse: audio stream response. 
@@ -402,7 +405,7 @@ async def tts_get_endpoint( top_k: int = 5, top_p: float = 1, temperature: float = 1, - text_split_method: str = "cut0", + text_split_method: str = "cut5", batch_size: int = 1, batch_threshold: float = 0.75, split_bucket: bool = True, @@ -410,14 +413,15 @@ async def tts_get_endpoint( fragment_interval: float = 0.3, seed: int = -1, media_type: str = "wav", - streaming_mode: bool = False, parallel_infer: bool = True, repetition_penalty: float = 1.35, sample_steps: int = 32, super_sampling: bool = False, + return_fragment: bool = False, + streaming_mode: bool = False, overlap_length: int = 2, min_chunk_length: int = 16, - return_fragment: bool = False, + fixed_length_chunk: bool = False, ): req = { "text": text, @@ -445,6 +449,7 @@ async def tts_get_endpoint( "overlap_length": int(overlap_length), "min_chunk_length": int(min_chunk_length), "return_fragment": return_fragment, + "fixed_length_chunk": fixed_length_chunk } return await tts_handle(req) From 365760f7560cf14f528c1e2cc85c8027839f54c9 Mon Sep 17 00:00:00 2001 From: ChasonJiang <1440499136@qq.com> Date: Wed, 26 Nov 2025 15:40:34 +0800 Subject: [PATCH 10/10] =?UTF-8?q?=E4=BF=AE=E5=A4=8Dapi=5Fv2=E7=9A=84ogg?= =?UTF-8?q?=E6=A0=BC=E5=BC=8F=E4=BC=A0=E8=BE=93=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- api_v2.py | 50 +++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 45 insertions(+), 5 deletions(-) diff --git a/api_v2.py b/api_v2.py index 21e9c0c5..5df2da66 100644 --- a/api_v2.py +++ b/api_v2.py @@ -126,6 +126,7 @@ from tools.i18n.i18n import I18nAuto from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config from GPT_SoVITS.TTS_infer_pack.text_segmentation_method import get_method_names as get_cut_method_names from pydantic import BaseModel +import threading # print(sys.path) i18n = I18nAuto() @@ -181,10 +182,49 @@ class TTS_Request(BaseModel): fixed_length_chunk: bool = False -### modify from 
https://github.com/RVC-Boss/GPT-SoVITS/pull/894/files def pack_ogg(io_buffer: BytesIO, data: np.ndarray, rate: int): - with sf.SoundFile(io_buffer, mode="w", samplerate=rate, channels=1, format="ogg") as audio_file: - audio_file.write(data) + # Author: AkagawaTsurunaki + # Issue: + # Stack overflow probabilistically occurs + # when the function `sf_writef_short` of `libsndfile_64bit.dll` is called + # using the Python library `soundfile` + # Note: + # This is an issue related to `libsndfile`, not this project itself. + # It happens when you generate a large audio tensor (about 499804 frames on my PC) + # and try to convert it to an ogg file. + # Related: + # https://github.com/RVC-Boss/GPT-SoVITS/issues/1199 + # https://github.com/libsndfile/libsndfile/issues/1023 + # https://github.com/bastibe/python-soundfile/issues/396 + # Suggestion: + # Or split the whole audio data into smaller audio segments to avoid stack overflow? + + def handle_pack_ogg(): + with sf.SoundFile(io_buffer, mode="w", samplerate=rate, channels=1, format="ogg") as audio_file: + audio_file.write(data) + + + + # See: https://docs.python.org/3/library/threading.html + # The stack size of this thread is at least 32768 + # If a stack overflow error still occurs, just modify the `stack_size`. + # stack_size = n * 4096, where n should be a positive integer. + # Here we chose n = 4096. + stack_size = 4096 * 4096 + try: + threading.stack_size(stack_size) + pack_ogg_thread = threading.Thread(target=handle_pack_ogg) + pack_ogg_thread.start() + pack_ogg_thread.join() + except RuntimeError as e: + # If changing the thread stack size is unsupported, a RuntimeError is raised. + print("RuntimeError: {}".format(e)) + print("Changing the thread stack size is unsupported.") + except ValueError as e: + # If the specified stack size is invalid, a ValueError is raised and the stack size is unmodified.
+ print("ValueError: {}".format(e)) + print("The specified stack size is invalid.") + return io_buffer @@ -295,8 +335,8 @@ def check_params(req: dict): ) if media_type not in ["wav", "raw", "ogg", "aac"]: return JSONResponse(status_code=400, content={"message": f"media_type: {media_type} is not supported"}) - elif media_type == "ogg" and not streaming_mode: - return JSONResponse(status_code=400, content={"message": "ogg format is not supported in non-streaming mode"}) + # elif media_type == "ogg" and not streaming_mode: + # return JSONResponse(status_code=400, content={"message": "ogg format is not supported in non-streaming mode"}) if text_split_method not in cut_method_names: return JSONResponse(