diff --git a/GPT_SoVITS/AR/models/t2s_model.py b/GPT_SoVITS/AR/models/t2s_model.py index 7196d6ab..0caadd04 100644 --- a/GPT_SoVITS/AR/models/t2s_model.py +++ b/GPT_SoVITS/AR/models/t2s_model.py @@ -794,7 +794,7 @@ class Text2SemanticDecoder(nn.Module): y_list = [] idx_list = [] for i in range(len(x)): - y, idx = self.infer_panel_naive( + y, idx = next(self.infer_panel_naive( x[i].unsqueeze(0), x_lens[i], prompts[i].unsqueeze(0) if prompts is not None else None, @@ -805,7 +805,7 @@ class Text2SemanticDecoder(nn.Module): temperature, repetition_penalty, **kwargs, - ) + )) y_list.append(y[0]) idx_list.append(idx) @@ -822,8 +822,15 @@ class Text2SemanticDecoder(nn.Module): early_stop_num: int = -1, temperature: float = 1.0, repetition_penalty: float = 1.35, + streaming_mode: bool = False, + chunk_length: int = 24, **kwargs, ): + mute_emb_sim_matrix = kwargs.get("mute_emb_sim_matrix", None) + chunk_split_thershold = kwargs.get("chunk_split_thershold", 0.3) + check_token_num = 2 + + x = self.ar_text_embedding(x) x = x + self.bert_proj(bert_feature.transpose(1, 2)) x = self.ar_text_position(x) @@ -875,7 +882,10 @@ class Text2SemanticDecoder(nn.Module): .to(device=x.device, dtype=torch.bool) ) + token_counter = 0 + curr_ptr = prefix_len for idx in tqdm(range(1500)): + token_counter+=1 if xy_attn_mask is not None: xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, xy_attn_mask, None) else: @@ -900,22 +910,56 @@ class Text2SemanticDecoder(nn.Module): if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS: stop = True + y=y[:, :-1] + token_counter -= 1 + + if idx == 1499: + stop = True + if stop: if y.shape[1] == 0: y = torch.concat([y, torch.zeros_like(samples)], dim=1) print("bad zero prediction") - print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]") + # print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]") + if streaming_mode: + yield y[:, curr_ptr:] if curr_ptr= chunk_length+check_token_num): + score = mute_emb_sim_matrix[y[0, curr_ptr:]] - chunk_split_thershold + score[score<0]=-1 + score[:-1]=score[:-1]+score[1:] ##考虑连续两个token + argmax_idx = score.argmax() + + if score[argmax_idx]>=0 and argmax_idx+1>=chunk_length: + print(f"\n\ncurr_ptr:{curr_ptr}") + yield y[:, curr_ptr:], False + token_counter -= argmax_idx+1 + curr_ptr += argmax_idx+1 + + + elif streaming_mode and (mute_emb_sim_matrix is None) and (token_counter >= chunk_length): + yield y[:, -token_counter:], False + curr_ptr+=token_counter + token_counter = 0 + + + ####################### update next step ################################### y_emb = self.ar_audio_embedding(y[:, -1:]) xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[ :, y_len + idx ].to(dtype=y_emb.dtype, device=y_emb.device) - if ref_free: - return y[:, :-1], 0 - return y[:, :-1], idx + + + if not streaming_mode: + if ref_free: + yield y, 0 + yield y, idx + + def infer_panel( self, @@ -930,6 +974,6 @@ class Text2SemanticDecoder(nn.Module): repetition_penalty: float = 1.35, **kwargs, ): - return self.infer_panel_naive( + return next(self.infer_panel_naive( x, x_lens, prompts, bert_feature, top_k, top_p, early_stop_num, temperature, repetition_penalty, **kwargs - ) + )) diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py index 0c1d2484..be3d3a19 100644 --- a/GPT_SoVITS/TTS_infer_pack/TTS.py +++ b/GPT_SoVITS/TTS_infer_pack/TTS.py @@ -275,6 +275,15 @@ class TTS_Config: v1_languages: list = ["auto", "en", "zh", "ja", "all_zh", "all_ja"] v2_languages: list = ["auto", "auto_yue", "en", "zh", "ja", "yue", "ko", "all_zh", "all_ja", "all_yue", "all_ko"] languages: list = v2_languages + mute_tokens: dict = { + "v1" : 486, + "v2" : 486, + "v2Pro": 486, + "v2ProPlus": 486, + "v3" : 486, + "v4" : 486, + } + mute_emb_sim_matrix: torch.Tensor = None # "all_zh",#全部按中文识别 # "en",#全部按英文识别#######不变 # "all_ja",#全部按日文识别 @@ -598,6 +607,11 @@ class TTS: if self.configs.is_half and str(self.configs.device) != "cpu": self.t2s_model = self.t2s_model.half() + codebook = t2s_model.model.ar_audio_embedding.weight.clone() + mute_emb = codebook[self.configs.mute_tokens[self.configs.version]].unsqueeze(0) + sim_matrix = F.cosine_similarity(mute_emb.float(), codebook.float(), dim=-1) + self.configs.mute_emb_sim_matrix = sim_matrix + def init_vocoder(self, version: str): if version == "v3": if self.vocoder is not None and self.vocoder.__class__.__name__ == "BigVGAN": @@ -997,18 +1011,22 @@ class TTS: "top_k": 5, # int. top k sampling "top_p": 1, # float. top p sampling "temperature": 1, # float. temperature for sampling - "text_split_method": "cut0", # str. text split method, see text_segmentation_method.py for details. + "text_split_method": "cut1", # str. text split method, see text_segmentation_method.py for details. "batch_size": 1, # int. batch size for inference "batch_threshold": 0.75, # float. threshold for batch splitting. - "split_bucket: True, # bool. whether to split the batch into multiple buckets. - "return_fragment": False, # bool. step by step return the audio fragment. + "split_bucket": True, # bool. whether to split the batch into multiple buckets. "speed_factor":1.0, # float. control the speed of the synthesized audio. "fragment_interval":0.3, # float. to control the interval of the audio fragment. "seed": -1, # int. random seed for reproducibility. "parallel_infer": True, # bool. whether to use parallel inference. - "repetition_penalty": 1.35 # float. repetition penalty for T2S model. + "repetition_penalty": 1.35, # float. repetition penalty for T2S model. "sample_steps": 32, # int. number of sampling steps for VITS model V3. - "super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3. + "super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3. + "return_fragment": False, # bool. step by step return the audio fragment. (Best Quality, Slowest response speed. old version of streaming mode) + "streaming_mode": False, # bool. return audio chunk by chunk. (Medium quality, Slow response speed) + "overlap_length": 2, # int. overlap length of semantic tokens for streaming mode. + "min_chunk_length": 16, # int. The minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size) + "fixed_length_chunk": False, # bool. When turned on, it can achieve faster streaming response, but with lower quality. (lower quality, faster response speed) } returns: Tuple[int, np.ndarray]: sampling rate and audio data. @@ -1024,7 +1042,7 @@ class TTS: top_k: int = inputs.get("top_k", 5) top_p: float = inputs.get("top_p", 1) temperature: float = inputs.get("temperature", 1) - text_split_method: str = inputs.get("text_split_method", "cut0") + text_split_method: str = inputs.get("text_split_method", "cut1") batch_size = inputs.get("batch_size", 1) batch_threshold = inputs.get("batch_threshold", 0.75) speed_factor = inputs.get("speed_factor", 1.0) @@ -1038,19 +1056,43 @@ class TTS: repetition_penalty = inputs.get("repetition_penalty", 1.35) sample_steps = inputs.get("sample_steps", 32) super_sampling = inputs.get("super_sampling", False) + streaming_mode = inputs.get("streaming_mode", False) + overlap_length = inputs.get("overlap_length", 2) + min_chunk_length = inputs.get("min_chunk_length", 16) + fixed_length_chunk = inputs.get("fixed_length_chunk", False) + chunk_split_thershold = 0.0 # 该值代表语义token与mute token的余弦相似度阈值,若大于该阈值,则视为可切分点。 - if parallel_infer: + if parallel_infer and not streaming_mode: print(i18n("并行推理模式已开启")) self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_batch_infer + elif not parallel_infer and streaming_mode and not self.configs.use_vocoder: + print(i18n("流式推理模式已开启")) + self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_naive + elif streaming_mode and self.configs.use_vocoder: + print(i18n("SoVits V3/4模型不支持流式推理模式,已自动回退到分段返回模式")) + streaming_mode = False + return_fragment = True + if parallel_infer: + self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_batch_infer + else: + self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_naive_batched + # self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_naive + elif parallel_infer and streaming_mode: + print(i18n("不支持同时开启并行推理和流式推理模式,已自动关闭并行推理模式")) + parallel_infer = False + self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_naive else: - print(i18n("并行推理模式已关闭")) + print(i18n("朴素推理模式已开启")) self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_naive_batched - if return_fragment: - print(i18n("分段返回模式已开启")) - if split_bucket: - split_bucket = False - print(i18n("分段返回模式不支持分桶处理,已自动关闭分桶处理")) + if return_fragment and streaming_mode: + print(i18n("流式推理模式不支持分段返回,已自动关闭分段返回")) + return_fragment = False + + if (return_fragment or streaming_mode) and split_bucket: + print(i18n("分段返回模式/流式推理模式不支持分桶处理,已自动关闭分桶处理")) + split_bucket = False + if split_bucket and speed_factor == 1.0 and not (self.configs.use_vocoder and parallel_infer): print(i18n("分桶处理模式已开启")) @@ -1063,9 +1105,9 @@ class TTS: else: print(i18n("分桶处理模式已关闭")) - if fragment_interval < 0.01: - fragment_interval = 0.01 - print(i18n("分段间隔过小,已自动设置为0.01")) + # if fragment_interval < 0.01: + # fragment_interval = 0.01 + # print(i18n("分段间隔过小,已自动设置为0.01")) no_prompt_text = False if prompt_text in [None, ""]: @@ -1126,7 +1168,7 @@ class TTS: ###### text preprocessing ######## t1 = time.perf_counter() data: list = None - if not return_fragment: + if not (return_fragment or streaming_mode): data = self.text_preprocessor.preprocess(text, text_lang, text_split_method, self.configs.version) if len(data) == 0: yield 16000, np.zeros(int(16000), dtype=np.int16) @@ -1186,10 +1228,11 @@ class TTS: t_34 = 0.0 t_45 = 0.0 audio = [] + is_first_package = True output_sr = self.configs.sampling_rate if not self.configs.use_vocoder else self.vocoder_configs["sr"] for item in data: t3 = time.perf_counter() - if return_fragment: + if return_fragment or streaming_mode: item = make_batch(item) if item is None: continue @@ -1211,108 +1254,228 @@ class TTS: self.prompt_cache["prompt_semantic"].expand(len(all_phoneme_ids), -1).to(self.configs.device) ) - print(f"############ {i18n('预测语义Token')} ############") - pred_semantic_list, idx_list = self.t2s_model.model.infer_panel( - all_phoneme_ids, - all_phoneme_lens, - prompt, - all_bert_features, - # prompt_phone_len=ph_offset, - top_k=top_k, - top_p=top_p, - temperature=temperature, - early_stop_num=self.configs.hz * self.configs.max_sec, - max_len=max_len, - repetition_penalty=repetition_penalty, - ) - t4 = time.perf_counter() - t_34 += t4 - t3 - refer_audio_spec = [] - if self.is_v2pro: - sv_emb = [] + + sv_emb = [] if self.is_v2pro else None for spec, audio_tensor in self.prompt_cache["refer_spec"]: spec = spec.to(dtype=self.precision, device=self.configs.device) refer_audio_spec.append(spec) if self.is_v2pro: sv_emb.append(self.sv_model.compute_embedding3(audio_tensor)) - batch_audio_fragment = [] + if not streaming_mode: + print(f"############ {i18n('预测语义Token')} ############") + pred_semantic_list, idx_list = self.t2s_model.model.infer_panel( + all_phoneme_ids, + all_phoneme_lens, + prompt, + all_bert_features, + # prompt_phone_len=ph_offset, + top_k=top_k, + top_p=top_p, + temperature=temperature, + early_stop_num=self.configs.hz * self.configs.max_sec, + max_len=max_len, + repetition_penalty=repetition_penalty, + ) + t4 = time.perf_counter() + t_34 += t4 - t3 - # ## vits并行推理 method 1 - # pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)] - # pred_semantic_len = torch.LongTensor([item.shape[0] for item in pred_semantic_list]).to(self.configs.device) - # pred_semantic = self.batch_sequences(pred_semantic_list, axis=0, pad_value=0).unsqueeze(0) - # max_len = 0 - # for i in range(0, len(batch_phones)): - # max_len = max(max_len, batch_phones[i].shape[-1]) - # batch_phones = self.batch_sequences(batch_phones, axis=0, pad_value=0, max_length=max_len) - # batch_phones = batch_phones.to(self.configs.device) - # batch_audio_fragment = (self.vits_model.batched_decode( - # pred_semantic, pred_semantic_len, batch_phones, batch_phones_len,refer_audio_spec - # )) - print(f"############ {i18n('合成音频')} ############") - if not self.configs.use_vocoder: - if speed_factor == 1.0: - print(f"{i18n('并行合成中')}...") - # ## vits并行推理 method 2 - pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)] - upsample_rate = math.prod(self.vits_model.upsample_rates) - audio_frag_idx = [ - pred_semantic_list[i].shape[0] * 2 * upsample_rate - for i in range(0, len(pred_semantic_list)) - ] - audio_frag_end_idx = [sum(audio_frag_idx[: i + 1]) for i in range(0, len(audio_frag_idx))] - all_pred_semantic = ( - torch.cat(pred_semantic_list).unsqueeze(0).unsqueeze(0).to(self.configs.device) - ) - _batch_phones = torch.cat(batch_phones).unsqueeze(0).to(self.configs.device) - if self.is_v2pro != True: - _batch_audio_fragment = self.vits_model.decode( - all_pred_semantic, _batch_phones, refer_audio_spec, speed=speed_factor - ).detach()[0, 0, :] - else: - _batch_audio_fragment = self.vits_model.decode( - all_pred_semantic, _batch_phones, refer_audio_spec, speed=speed_factor, sv_emb=sv_emb - ).detach()[0, 0, :] - audio_frag_end_idx.insert(0, 0) - batch_audio_fragment = [ - _batch_audio_fragment[audio_frag_end_idx[i - 1] : audio_frag_end_idx[i]] - for i in range(1, len(audio_frag_end_idx)) - ] - else: - # ## vits串行推理 - for i, idx in enumerate(tqdm(idx_list)): - phones = batch_phones[i].unsqueeze(0).to(self.configs.device) - _pred_semantic = ( - pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0) - ) # .unsqueeze(0)#mq要多unsqueeze一次 - if self.is_v2pro != True: - audio_fragment = self.vits_model.decode( - _pred_semantic, phones, refer_audio_spec, speed=speed_factor - ).detach()[0, 0, :] - else: - audio_fragment = self.vits_model.decode( - _pred_semantic, phones, refer_audio_spec, speed=speed_factor, sv_emb=sv_emb - ).detach()[0, 0, :] - batch_audio_fragment.append(audio_fragment) ###试试重建不带上prompt部分 - else: - if parallel_infer: - print(f"{i18n('并行合成中')}...") - audio_fragments = self.using_vocoder_synthesis_batched_infer( - idx_list, pred_semantic_list, batch_phones, speed=speed_factor, sample_steps=sample_steps - ) - batch_audio_fragment.extend(audio_fragments) - else: - for i, idx in enumerate(tqdm(idx_list)): - phones = batch_phones[i].unsqueeze(0).to(self.configs.device) - _pred_semantic = ( - pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0) - ) # .unsqueeze(0)#mq要多unsqueeze一次 - audio_fragment = self.using_vocoder_synthesis( - _pred_semantic, phones, speed=speed_factor, sample_steps=sample_steps + + batch_audio_fragment = [] + + # ## vits并行推理 method 1 + # pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)] + # pred_semantic_len = torch.LongTensor([item.shape[0] for item in pred_semantic_list]).to(self.configs.device) + # pred_semantic = self.batch_sequences(pred_semantic_list, axis=0, pad_value=0).unsqueeze(0) + # max_len = 0 + # for i in range(0, len(batch_phones)): + # max_len = max(max_len, batch_phones[i].shape[-1]) + # batch_phones = self.batch_sequences(batch_phones, axis=0, pad_value=0, max_length=max_len) + # batch_phones = batch_phones.to(self.configs.device) + # batch_audio_fragment = (self.vits_model.batched_decode( + # pred_semantic, pred_semantic_len, batch_phones, batch_phones_len,refer_audio_spec + # )) + print(f"############ {i18n('合成音频')} ############") + if not self.configs.use_vocoder: + if speed_factor == 1.0: + print(f"{i18n('并行合成中')}...") + # ## vits并行推理 method 2 + pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)] + upsample_rate = math.prod(self.vits_model.upsample_rates) + audio_frag_idx = [ + pred_semantic_list[i].shape[0] * 2 * upsample_rate + for i in range(0, len(pred_semantic_list)) + ] + audio_frag_end_idx = [sum(audio_frag_idx[: i + 1]) for i in range(0, len(audio_frag_idx))] + all_pred_semantic = ( + torch.cat(pred_semantic_list).unsqueeze(0).unsqueeze(0).to(self.configs.device) ) - batch_audio_fragment.append(audio_fragment) + _batch_phones = torch.cat(batch_phones).unsqueeze(0).to(self.configs.device) + + _batch_audio_fragment = self.vits_model.decode( + all_pred_semantic, _batch_phones, refer_audio_spec, speed=speed_factor, sv_emb=sv_emb + ).detach()[0, 0, :] + + audio_frag_end_idx.insert(0, 0) + batch_audio_fragment = [ + _batch_audio_fragment[audio_frag_end_idx[i - 1] : audio_frag_end_idx[i]] + for i in range(1, len(audio_frag_end_idx)) + ] + else: + # ## vits串行推理 + for i, idx in enumerate(tqdm(idx_list)): + phones = batch_phones[i].unsqueeze(0).to(self.configs.device) + _pred_semantic = ( + pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0) + ) # .unsqueeze(0)#mq要多unsqueeze一次 + audio_fragment = self.vits_model.decode( + _pred_semantic, phones, refer_audio_spec, speed=speed_factor, sv_emb=sv_emb + ).detach()[0, 0, :] + batch_audio_fragment.append(audio_fragment) ###试试重建不带上prompt部分 + else: + if parallel_infer: + print(f"{i18n('并行合成中')}...") + audio_fragments = self.using_vocoder_synthesis_batched_infer( + idx_list, pred_semantic_list, batch_phones, speed=speed_factor, sample_steps=sample_steps + ) + batch_audio_fragment.extend(audio_fragments) + else: + for i, idx in enumerate(tqdm(idx_list)): + phones = batch_phones[i].unsqueeze(0).to(self.configs.device) + _pred_semantic = ( + pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0) + ) # .unsqueeze(0)#mq要多unsqueeze一次 + audio_fragment = self.using_vocoder_synthesis( + _pred_semantic, phones, speed=speed_factor, sample_steps=sample_steps + ) + batch_audio_fragment.append(audio_fragment) + + else: + # refer_audio_spec: torch.Tensor = [ + # item.to(dtype=self.precision, device=self.configs.device) + # for item in self.prompt_cache["refer_spec"] + # ] + semantic_token_generator =self.t2s_model.model.infer_panel( + all_phoneme_ids[0].unsqueeze(0), + all_phoneme_lens, + prompt, + all_bert_features[0].unsqueeze(0), + top_k=top_k, + top_p=top_p, + temperature=temperature, + early_stop_num=self.configs.hz * self.configs.max_sec, + max_len=max_len, + repetition_penalty=repetition_penalty, + streaming_mode=True, + chunk_length=min_chunk_length, + mute_emb_sim_matrix=self.configs.mute_emb_sim_matrix if not fixed_length_chunk else None, + chunk_split_thershold=chunk_split_thershold, + ) + t4 = time.perf_counter() + t_34 += t4 - t3 + phones = batch_phones[0].unsqueeze(0).to(self.configs.device) + is_first_chunk = True + + if not self.configs.use_vocoder: + # if speed_factor == 1.0: + # upsample_rate = math.prod(self.vits_model.upsample_rates)*(2 if self.vits_model.semantic_frame_rate == "25hz" else 1) + # else: + upsample_rate = math.prod(self.vits_model.upsample_rates)*((2 if self.vits_model.semantic_frame_rate == "25hz" else 1)/speed_factor) + else: + # if speed_factor == 1.0: + # upsample_rate = self.vocoder_configs["upsample_rate"]*(3.875 if self.configs.version == "v3" else 4) + # else: + upsample_rate = self.vocoder_configs["upsample_rate"]*((3.875 if self.configs.version == "v3" else 4)/speed_factor) + + last_audio_chunk = None + # last_tokens = None + last_latent = None + previous_tokens = [] + overlap_len = overlap_length + overlap_size = math.ceil(overlap_length*upsample_rate) + for semantic_tokens, is_final in semantic_token_generator: + if semantic_tokens is None and last_audio_chunk is not None: + yield self.audio_postprocess( + [[last_audio_chunk[-overlap_size:]]], + output_sr, + None, + speed_factor, + False, + 0.0, + super_sampling if self.configs.use_vocoder and self.configs.version == "v3" else False, + ) + break + + _semantic_tokens = semantic_tokens + print(f"semantic_tokens shape:{semantic_tokens.shape}") + + previous_tokens.append(semantic_tokens) + + _semantic_tokens = torch.cat(previous_tokens, dim=-1) + + if not is_first_chunk and semantic_tokens.shape[-1] < 10: + overlap_len = overlap_length+(10-semantic_tokens.shape[-1]) + else: + overlap_len = overlap_length + + + if not self.configs.use_vocoder: + token_padding_length = 0 + # token_padding_length = int(phones.shape[-1]*2)-_semantic_tokens.shape[-1] + # if token_padding_length>0: + # _semantic_tokens = F.pad(_semantic_tokens, (0, token_padding_length), "constant", 486) + # else: + # token_padding_length = 0 + + audio_chunk, latent, latent_mask = self.vits_model.decode_streaming( + _semantic_tokens.unsqueeze(0), + phones, refer_audio_spec, + speed=speed_factor, + sv_emb=sv_emb, + result_length=semantic_tokens.shape[-1]+overlap_len if not is_first_chunk else None, + overlap_frames=last_latent[:,:,-overlap_len*(2 if self.vits_model.semantic_frame_rate == "25hz" else 1):] \ + if last_latent is not None else None, + padding_length=token_padding_length + ) + audio_chunk=audio_chunk.detach()[0, 0, :] + else: + raise RuntimeError(i18n("SoVits V3/4模型不支持流式推理模式")) + + if overlap_len>overlap_length: + audio_chunk=audio_chunk[-int((overlap_length+semantic_tokens.shape[-1])*upsample_rate):] + + audio_chunk_ = audio_chunk + if is_first_chunk and not is_final: + is_first_chunk = False + audio_chunk_ = audio_chunk_[:-overlap_size] + elif is_first_chunk and is_final: + is_first_chunk = False + elif not is_first_chunk and not is_final: + audio_chunk_ = self.sola_algorithm([last_audio_chunk, audio_chunk_], overlap_size) + audio_chunk_ = ( + audio_chunk_[last_audio_chunk.shape[0]-overlap_size:-overlap_size] if not is_final \ + else audio_chunk_[last_audio_chunk.shape[0]-overlap_size:] + ) + + last_latent = latent + last_audio_chunk = audio_chunk + yield self.audio_postprocess( + [[audio_chunk_]], + output_sr, + None, + speed_factor, + False, + 0.0, + super_sampling if self.configs.use_vocoder and self.configs.version == "v3" else False, + ) + + if is_first_package: + print(f"first_package_delay: {time.perf_counter()-t0:.3f}") + is_first_package = False + + + yield output_sr, np.zeros(int(output_sr*fragment_interval), dtype=np.int16) t5 = time.perf_counter() t_45 += t5 - t4 @@ -1327,17 +1490,18 @@ class TTS: fragment_interval, super_sampling if self.configs.use_vocoder and self.configs.version == "v3" else False, ) + elif streaming_mode:... else: audio.append(batch_audio_fragment) if self.stop_flag: - yield 16000, np.zeros(int(16000), dtype=np.int16) + yield output_sr, np.zeros(int(output_sr), dtype=np.int16) return - if not return_fragment: + if not (return_fragment or streaming_mode): print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t_34, t_45)) if len(audio) == 0: - yield 16000, np.zeros(int(16000), dtype=np.int16) + yield output_sr, np.zeros(int(output_sr), dtype=np.int16) return yield self.audio_postprocess( audio, @@ -1384,16 +1548,17 @@ class TTS: fragment_interval: float = 0.3, super_sampling: bool = False, ) -> Tuple[int, np.ndarray]: - zero_wav = torch.zeros( - int(self.configs.sampling_rate * fragment_interval), dtype=self.precision, device=self.configs.device - ) + if fragment_interval>0: + zero_wav = torch.zeros( + int(self.configs.sampling_rate * fragment_interval), dtype=self.precision, device=self.configs.device + ) for i, batch in enumerate(audio): for j, audio_fragment in enumerate(batch): max_audio = torch.abs(audio_fragment).max() # 简单防止16bit爆音 if max_audio > 1: audio_fragment /= max_audio - audio_fragment: torch.Tensor = torch.cat([audio_fragment, zero_wav], dim=0) + audio_fragment: torch.Tensor = torch.cat([audio_fragment, zero_wav], dim=0) if fragment_interval>0 else audio_fragment audio[i][j] = audio_fragment if split_bucket: @@ -1413,13 +1578,18 @@ class TTS: max_audio = np.abs(audio).max() if max_audio > 1: audio /= max_audio + audio = (audio * 32768).astype(np.int16) t2 = time.perf_counter() print(f"超采样用时:{t2 - t1:.3f}s") else: + # audio = audio.float() * 32768 + # audio = audio.to(dtype=torch.int16).clamp(-32768, 32767).cpu().numpy() + audio = audio.cpu().numpy() audio = (audio * 32768).astype(np.int16) + # try: # if speed_factor != 1.0: # audio = speed_change(audio, speed=speed_factor, sr=int(sr)) @@ -1612,24 +1782,43 @@ class TTS: self, audio_fragments: List[torch.Tensor], overlap_len: int, + search_len:int= 320 ): + # overlap_len-=search_len + + dtype = audio_fragments[0].dtype + for i in range(len(audio_fragments) - 1): - f1 = audio_fragments[i] - f2 = audio_fragments[i + 1] + f1 = audio_fragments[i].float() + f2 = audio_fragments[i + 1].float() w1 = f1[-overlap_len:] - w2 = f2[:overlap_len] - assert w1.shape == w2.shape - corr = F.conv1d(w1.view(1, 1, -1), w2.view(1, 1, -1), padding=w2.shape[-1] // 2).view(-1)[:-1] - idx = corr.argmax() - f1_ = f1[: -(overlap_len - idx)] + w2 = f2[:overlap_len+search_len] + # w2 = w2[-w2.shape[-1]//2:] + # assert w1.shape == w2.shape + corr_norm = F.conv1d(w2.view(1, 1, -1), w1.view(1, 1, -1)).view(-1) + + corr_den = F.conv1d(w2.view(1, 1, -1)**2, torch.ones_like(w1).view(1, 1, -1)).view(-1)+ 1e-8 + idx = (corr_norm/corr_den.sqrt()).argmax() + + print(f"seg_idx: {idx}") + + # idx = corr.argmax() + f1_ = f1[: -overlap_len] audio_fragments[i] = f1_ f2_ = f2[idx:] - window = torch.hann_window((overlap_len - idx) * 2, device=f1.device, dtype=f1.dtype) - f2_[: (overlap_len - idx)] = ( - window[: (overlap_len - idx)] * f2_[: (overlap_len - idx)] - + window[(overlap_len - idx) :] * f1[-(overlap_len - idx) :] + window = torch.hann_window((overlap_len) * 2, device=f1.device, dtype=f1.dtype) + f2_[: overlap_len] = ( + window[: overlap_len] * f2_[: overlap_len] + + window[overlap_len :] * f1[-overlap_len :] ) + + # window = torch.sin(torch.arange((overlap_len - idx), device=f1.device) * np.pi / (overlap_len - idx)) + # f2_[: (overlap_len - idx)] = ( + # window * f2_[: (overlap_len - idx)] + # + (1-window) * f1[-(overlap_len - idx) :] + # ) + audio_fragments[i + 1] = f2_ - return torch.cat(audio_fragments, 0) + return torch.cat(audio_fragments, 0).to(dtype) diff --git a/GPT_SoVITS/module/models.py b/GPT_SoVITS/module/models.py index 1c8e662f..348ddb3f 100644 --- a/GPT_SoVITS/module/models.py +++ b/GPT_SoVITS/module/models.py @@ -151,6 +151,8 @@ class DurationPredictor(nn.Module): return x * x_mask +WINDOW = {} + class TextEncoder(nn.Module): def __init__( self, @@ -209,7 +211,7 @@ class TextEncoder(nn.Module): self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) - def forward(self, y, y_lengths, text, text_lengths, ge, speed=1, test=None): + def forward(self, y, y_lengths, text, text_lengths, ge, speed=1, test=None, result_length:int=None, overlap_frames:torch.Tensor=None, padding_length:int=None): y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype) y = self.ssl_proj(y * y_mask) * y_mask @@ -222,13 +224,44 @@ class TextEncoder(nn.Module): text = self.text_embedding(text).transpose(1, 2) text = self.encoder_text(text * text_mask, text_mask) y = self.mrte(y, y_mask, text, text_mask, ge) + + if padding_length is not None and padding_length!=0: + y = y[:, :, :-padding_length] + y_mask = y_mask[:, :, :-padding_length] + + y = self.encoder2(y * y_mask, y_mask) + + if result_length is not None: + y = y[:, :, -result_length:] + y_mask = y_mask[:, :, -result_length:] + + if overlap_frames is not None: + overlap_len = overlap_frames.shape[-1] + window = WINDOW.get(overlap_len, None) + if window is None: + # WINDOW[overlap_len] = torch.hann_window(overlap_len*2, device=y.device, dtype=y.dtype) + WINDOW[overlap_len] = torch.sin(torch.arange(overlap_len*2, device=y.device) * torch.pi / (overlap_len*2)) + window = WINDOW[overlap_len] + + + window = window.to(y.device) + y[:,:,:overlap_len] = ( + window[:overlap_len].view(1, 1, -1) * y[:,:,:overlap_len] + + window[overlap_len:].view(1, 1, -1) * overlap_frames + ) + + y_ = y + y_mask_ = y_mask + + + if speed != 1: y = F.interpolate(y, size=int(y.shape[-1] / speed) + 1, mode="linear") y_mask = F.interpolate(y_mask, size=y.shape[-1], mode="nearest") stats = self.proj(y) * y_mask m, logs = torch.split(stats, self.out_channels, dim=1) - return y, m, logs, y_mask + return y, m, logs, y_mask, y_, y_mask_ def extract_latent(self, x): x = self.ssl_proj(x) @@ -921,7 +954,7 @@ class SynthesizerTrn(nn.Module): if self.semantic_frame_rate == "25hz": quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest") - x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge512 if self.is_v2pro else ge) + x, m_p, logs_p, y_mask, _, _ = self.enc_p(quantized, y_lengths, text, text_lengths, ge512 if self.is_v2pro else ge) z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=ge) z_p = self.flow(z, y_mask, g=ge) @@ -949,7 +982,7 @@ class SynthesizerTrn(nn.Module): if self.semantic_frame_rate == "25hz": quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest") - x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge, test=test) + x, m_p, logs_p, y_mask, _, _ = self.enc_p(quantized, y_lengths, text, text_lengths, ge, test=test) z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale z = self.flow(z_p, y_mask, g=ge, reverse=True) @@ -957,6 +990,7 @@ class SynthesizerTrn(nn.Module): o = self.dec((z * y_mask)[:, :, :], g=ge) return o, y_mask, (z, z_p, m_p, logs_p) + @torch.no_grad() def decode(self, codes, text, refer, noise_scale=0.5, speed=1, sv_emb=None): def get_ge(refer, sv_emb): @@ -989,7 +1023,7 @@ class SynthesizerTrn(nn.Module): quantized = self.quantizer.decode(codes) if self.semantic_frame_rate == "25hz": quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest") - x, m_p, logs_p, y_mask = self.enc_p( + x, m_p, logs_p, y_mask, _, _ = self.enc_p( quantized, y_lengths, text, @@ -1004,6 +1038,59 @@ class SynthesizerTrn(nn.Module): o = self.dec((z * y_mask)[:, :, :], g=ge) return o + + @torch.no_grad() + def decode_streaming(self, codes, text, refer, noise_scale=0.5, speed=1, sv_emb=None, result_length:int=None, overlap_frames:torch.Tensor=None, padding_length:int=None): + def get_ge(refer, sv_emb): + ge = None + if refer is not None: + refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device) + refer_mask = torch.unsqueeze(commons.sequence_mask(refer_lengths, refer.size(2)), 1).to(refer.dtype) + if self.version == "v1": + ge = self.ref_enc(refer * refer_mask, refer_mask) + else: + ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask) + if self.is_v2pro: + sv_emb = self.sv_emb(sv_emb) # B*20480->B*512 + ge += sv_emb.unsqueeze(-1) + ge = self.prelu(ge) + return ge + + if type(refer) == list: + ges = [] + for idx, _refer in enumerate(refer): + ge = get_ge(_refer, sv_emb[idx] if self.is_v2pro else None) + ges.append(ge) + ge = torch.stack(ges, 0).mean(0) + else: + ge = get_ge(refer, sv_emb) + + y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device) + text_lengths = torch.LongTensor([text.size(-1)]).to(text.device) + + quantized = self.quantizer.decode(codes) + if self.semantic_frame_rate == "25hz": + quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest") + result_length = (2*result_length) if result_length is not None else None + padding_length = (2*padding_length) if padding_length is not None else None + x, m_p, logs_p, y_mask, y_, y_mask_ = self.enc_p( + quantized, + y_lengths, + text, + text_lengths, + self.ge_to512(ge.transpose(2, 1)).transpose(2, 1) if self.is_v2pro else ge, + speed, + result_length=result_length, + overlap_frames=overlap_frames, + padding_length=padding_length + ) + z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale + + z = self.flow(z_p, y_mask, g=ge, reverse=True) + + o = self.dec((z * y_mask)[:, :, :], g=ge) + return o, y_, y_mask_ + def extract_latent(self, x): ssl = self.ssl_proj(x) quantized, codes, commit_loss, quantized_list = self.quantizer(ssl) @@ -1226,7 +1313,7 @@ class SynthesizerTrnV3(nn.Module): ssl = self.ssl_proj(ssl) quantized, codes, commit_loss, quantized_list = self.quantizer(ssl, layers=[0]) quantized = F.interpolate(quantized, scale_factor=2, mode="nearest") ##BCT - x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge) + x, m_p, logs_p, y_mask, y_, y_mask_ = self.enc_p(quantized, y_lengths, text, text_lengths, ge) fea = self.bridge(x) fea = F.interpolate(fea, scale_factor=(1.875 if self.version == "v3" else 2), mode="nearest") ##BCT fea, y_mask_ = self.wns1( @@ -1260,7 +1347,7 @@ class SynthesizerTrnV3(nn.Module): quantized = self.quantizer.decode(codes) if self.semantic_frame_rate == "25hz": quantized = F.interpolate(quantized, scale_factor=2, mode="nearest") ##BCT - x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge, speed) + x, m_p, logs_p, y_mask, _, _ = self.enc_p(quantized, y_lengths, text, text_lengths, ge, speed) fea = self.bridge(x) fea = F.interpolate(fea, scale_factor=(1.875 if self.version == "v3" else 2), mode="nearest") ##BCT ####more wn paramter to learn mel @@ -1377,7 +1464,7 @@ class SynthesizerTrnV3b(nn.Module): ssl = self.ssl_proj(ssl) quantized, codes, commit_loss, quantized_list = self.quantizer(ssl, layers=[0]) quantized = F.interpolate(quantized, scale_factor=2, mode="nearest") ##BCT - x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge) + x, m_p, logs_p, y_mask, y_, y_mask_ = self.enc_p(quantized, y_lengths, text, text_lengths, ge) z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=ge) z_p = self.flow(z, y_mask, g=ge) z_slice, ids_slice = commons.rand_slice_segments(z, y_lengths, self.segment_size) @@ -1420,7 +1507,7 @@ class SynthesizerTrnV3b(nn.Module): quantized = self.quantizer.decode(codes) if self.semantic_frame_rate == "25hz": quantized = F.interpolate(quantized, scale_factor=2, mode="nearest") ##BCT - x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge) + x, m_p, logs_p, y_mask, y_, y_mask_ = self.enc_p(quantized, y_lengths, text, text_lengths, ge) fea = self.bridge(x) fea = F.interpolate(fea, scale_factor=1.875, mode="nearest") ##BCT ####more wn paramter to learn mel diff --git a/api_v2.py b/api_v2.py index 5947df53..5df2da66 100644 --- a/api_v2.py +++ b/api_v2.py @@ -30,17 +30,22 @@ POST: "top_k": 5, # int. top k sampling "top_p": 1, # float. top p sampling "temperature": 1, # float. temperature for sampling - "text_split_method": "cut0", # str. text split method, see text_segmentation_method.py for details. + "text_split_method": "cut5", # str. text split method, see text_segmentation_method.py for details. "batch_size": 1, # int. batch size for inference "batch_threshold": 0.75, # float. threshold for batch splitting. "split_bucket": True, # bool. whether to split the batch into multiple buckets. "speed_factor":1.0, # float. control the speed of the synthesized audio. - "streaming_mode": False, # bool. whether to return a streaming response. + "fragment_interval":0.3, # float. to control the interval of the audio fragment. "seed": -1, # int. random seed for reproducibility. "parallel_infer": True, # bool. whether to use parallel inference. "repetition_penalty": 1.35, # float. repetition penalty for T2S model. "sample_steps": 32, # int. number of sampling steps for VITS model V3. - "super_sampling": False # bool. whether to use super-sampling for audio when using VITS model V3. + "super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3. + "return_fragment": False, # bool. step by step return the audio fragment. (Best Quality, Slowest response speed. old version of streaming mode) + "streaming_mode": False, # bool. return audio chunk by chunk. (Medium quality, Slow response speed) + "overlap_length": 2, # int. overlap length of semantic tokens for streaming mode. + "min_chunk_length": 16, # int. The minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size) + "fixed_length_chunk": False, # bool. When turned on, it can achieve faster streaming response, but with lower quality. (lower quality, faster response speed) } ``` @@ -121,6 +126,7 @@ from tools.i18n.i18n import I18nAuto from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config from GPT_SoVITS.TTS_infer_pack.text_segmentation_method import get_method_names as get_cut_method_names from pydantic import BaseModel +import threading # print(sys.path) i18n = I18nAuto() @@ -170,12 +176,55 @@ class TTS_Request(BaseModel): repetition_penalty: float = 1.35 sample_steps: int = 32 super_sampling: bool = False + overlap_length: int = 2 + min_chunk_length: int = 16 + return_fragment: bool = False + fixed_length_chunk: bool = False -### modify from https://github.com/RVC-Boss/GPT-SoVITS/pull/894/files def pack_ogg(io_buffer: BytesIO, data: np.ndarray, rate: int): - with sf.SoundFile(io_buffer, mode="w", samplerate=rate, channels=1, format="ogg") as audio_file: - audio_file.write(data) + # Author: AkagawaTsurunaki + # Issue: + # Stack overflow probabilistically occurs + # when the function `sf_writef_short` of `libsndfile_64bit.dll` is called + # using the Python library `soundfile` + # Note: + # This is an issue related to `libsndfile`, not this project itself. + # It happens when you generate a large audio tensor (about 499804 frames in my PC) + # and try to convert it to an ogg file. + # Related: + # https://github.com/RVC-Boss/GPT-SoVITS/issues/1199 + # https://github.com/libsndfile/libsndfile/issues/1023 + # https://github.com/bastibe/python-soundfile/issues/396 + # Suggestion: + # Or split the whole audio data into smaller audio segment to avoid stack overflow? + + def handle_pack_ogg(): + with sf.SoundFile(io_buffer, mode="w", samplerate=rate, channels=1, format="ogg") as audio_file: + audio_file.write(data) + + + + # See: https://docs.python.org/3/library/threading.html + # The stack size of this thread is at least 32768 + # If stack overflow error still occurs, just modify the `stack_size`. + # stack_size = n * 4096, where n should be a positive integer. + # Here we chose n = 4096. + stack_size = 4096 * 4096 + try: + threading.stack_size(stack_size) + pack_ogg_thread = threading.Thread(target=handle_pack_ogg) + pack_ogg_thread.start() + pack_ogg_thread.join() + except RuntimeError as e: + # If changing the thread stack size is unsupported, a RuntimeError is raised. + print("RuntimeError: {}".format(e)) + print("Changing the thread stack size is unsupported.") + except ValueError as e: + # If the specified stack size is invalid, a ValueError is raised and the stack size is unmodified. + print("ValueError: {}".format(e)) + print("The specified stack size is invalid.") + return io_buffer @@ -286,8 +335,8 @@ def check_params(req: dict): ) if media_type not in ["wav", "raw", "ogg", "aac"]: return JSONResponse(status_code=400, content={"message": f"media_type: {media_type} is not supported"}) - elif media_type == "ogg" and not streaming_mode: - return JSONResponse(status_code=400, content={"message": "ogg format is not supported in non-streaming mode"}) + # elif media_type == "ogg" and not streaming_mode: + # return JSONResponse(status_code=400, content={"message": "ogg format is not supported in non-streaming mode"}) if text_split_method not in cut_method_names: return JSONResponse( @@ -307,7 +356,7 @@ async def tts_handle(req: dict): "text": "", # str.(required) text to be synthesized "text_lang: "", # str.(required) language of the text to be synthesized "ref_audio_path": "", # str.(required) reference audio path - "aux_ref_audio_paths": [], # list.(optional) auxiliary reference audio paths for multi-speaker synthesis + "aux_ref_audio_paths": [], # list.(optional) auxiliary reference audio paths for multi-speaker tone fusion "prompt_text": "", # str.(optional) prompt text for the reference audio "prompt_lang": "", # str.(required) language of the prompt text for the reference audio "top_k": 5, # int. top k sampling @@ -316,16 +365,19 @@ async def tts_handle(req: dict): "text_split_method": "cut5", # str. text split method, see text_segmentation_method.py for details. "batch_size": 1, # int. batch size for inference "batch_threshold": 0.75, # float. threshold for batch splitting. - "split_bucket: True, # bool. whether to split the batch into multiple buckets. + "split_bucket": True, # bool. whether to split the batch into multiple buckets. "speed_factor":1.0, # float. control the speed of the synthesized audio. "fragment_interval":0.3, # float. to control the interval of the audio fragment. "seed": -1, # int. random seed for reproducibility. - "media_type": "wav", # str. media type of the output audio, support "wav", "raw", "ogg", "aac". - "streaming_mode": False, # bool. whether to return a streaming response. - "parallel_infer": True, # bool.(optional) whether to use parallel inference. - "repetition_penalty": 1.35 # float.(optional) repetition penalty for T2S model. + "parallel_infer": True, # bool. whether to use parallel inference. + "repetition_penalty": 1.35, # float. repetition penalty for T2S model. "sample_steps": 32, # int. number of sampling steps for VITS model V3. - "super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3. + "super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3. + "return_fragment": False, # bool. step by step return the audio fragment. (Best Quality, Slowest response speed. old version of streaming mode) + "streaming_mode": False, # bool. return audio chunk by chunk. (Medium quality, Slow response speed) + "overlap_length": 2, # int. overlap length of semantic tokens for streaming mode. + "min_chunk_length": 16, # int. The minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size) + "fixed_length_chunk": False, # bool. When turned on, it can achieve faster streaming response, but with lower quality. (lower quality, faster response speed) } returns: StreamingResponse: audio stream response. @@ -339,8 +391,10 @@ async def tts_handle(req: dict): if check_res is not None: return check_res - if streaming_mode or return_fragment: - req["return_fragment"] = True + req["streaming_mode"] = streaming_mode + req["return_fragment"] = return_fragment + streaming_mode = streaming_mode or return_fragment + try: tts_generator = tts_pipeline.run(req) @@ -391,7 +445,7 @@ async def tts_get_endpoint( top_k: int = 5, top_p: float = 1, temperature: float = 1, - text_split_method: str = "cut0", + text_split_method: str = "cut5", batch_size: int = 1, batch_threshold: float = 0.75, split_bucket: bool = True, @@ -399,11 +453,15 @@ async def tts_get_endpoint( fragment_interval: float = 0.3, seed: int = -1, media_type: str = "wav", - streaming_mode: bool = False, parallel_infer: bool = True, repetition_penalty: float = 1.35, sample_steps: int = 32, super_sampling: bool = False, + return_fragment: bool = False, + streaming_mode: bool = False, + overlap_length: int = 2, + min_chunk_length: int = 16, + fixed_length_chunk: bool = False, ): req = { "text": text, @@ -428,6 +486,10 @@ async def tts_get_endpoint( "repetition_penalty": float(repetition_penalty), "sample_steps": int(sample_steps), "super_sampling": super_sampling, + "overlap_length": int(overlap_length), + "min_chunk_length": int(min_chunk_length), + "return_fragment": return_fragment, + "fixed_length_chunk": fixed_length_chunk } return await tts_handle(req)