From fdf794e31d1fd6f91c5cb4fbb0396094491a31ac Mon Sep 17 00:00:00 2001 From: XXXXRT666 <157766680+XXXXRT666@users.noreply.github.com> Date: Sat, 2 Aug 2025 17:47:15 +0800 Subject: [PATCH 01/20] Update WSL Rocm (#2561) --- install.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/install.sh b/install.sh index a2fa751e..7d80ec28 100644 --- a/install.sh +++ b/install.sh @@ -373,7 +373,7 @@ if [ "$USE_ROCM" = true ] && [ "$IS_WSL" = true ]; then location=$(pip show torch | grep Location | awk -F ": " '{print $2}') cd "${location}"/torch/lib/ || exit rm libhsa-runtime64.so* - cp /opt/rocm/lib/libhsa-runtime64.so.1.2 libhsa-runtime64.so + cp "$(readlink -f /opt/rocm/lib/libhsa-runtime64.so)" libhsa-runtime64.so echo -e "${SUCCESS}ROCm Runtime Lib Updated..." fi From 11aa78bd9bda8b53047cfcae03abf7ca94d27391 Mon Sep 17 00:00:00 2001 From: RVC-Boss <129054828+RVC-Boss@users.noreply.github.com> Date: Wed, 10 Sep 2025 15:01:04 +0800 Subject: [PATCH 02/20] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E7=8E=AF=E5=A2=83?= =?UTF-8?q?=E5=8F=98=E9=87=8F=E5=8F=AF=E8=83=BD=E4=B8=8D=E4=B8=BAstr?= =?UTF-8?q?=E7=9A=84=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 修复环境变量可能不为str的问题 --- webui.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/webui.py b/webui.py index 9a6aae5f..cf5d8a3a 100644 --- a/webui.py +++ b/webui.py @@ -343,7 +343,7 @@ def change_tts_inference(bert_path, cnhubert_base_path, gpu_number, gpt_path, so os.environ["sovits_path"] = sovits_path os.environ["cnhubert_base_path"] = cnhubert_base_path os.environ["bert_path"] = bert_path - os.environ["_CUDA_VISIBLE_DEVICES"] = fix_gpu_number(gpu_number) + os.environ["_CUDA_VISIBLE_DEVICES"] = str(fix_gpu_number(gpu_number)) os.environ["is_half"] = str(is_half) os.environ["infer_ttswebui"] = str(webui_port_infer_tts) os.environ["is_share"] = str(is_share) @@ -628,7 +628,7 @@ def open1Bb( data["output_dir"] = 
"%s/logs_s1_%s" % (s1_dir, version) # data["version"]=version - os.environ["_CUDA_VISIBLE_DEVICES"] = fix_gpu_numbers(gpu_numbers.replace("-", ",")) + os.environ["_CUDA_VISIBLE_DEVICES"] = str(fix_gpu_numbers(gpu_numbers.replace("-", ","))) os.environ["hz"] = "25hz" tmp_config_path = "%s/tmp_s1.yaml" % tmp with open(tmp_config_path, "w") as f: @@ -801,7 +801,7 @@ def open1a(inp_text, inp_wav_dir, exp_name, gpu_numbers, bert_pretrained_dir): { "i_part": str(i_part), "all_parts": str(all_parts), - "_CUDA_VISIBLE_DEVICES": fix_gpu_number(gpu_names[i_part]), + "_CUDA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part])), "is_half": str(is_half), } ) @@ -892,7 +892,7 @@ def open1b(version, inp_text, inp_wav_dir, exp_name, gpu_numbers, ssl_pretrained { "i_part": str(i_part), "all_parts": str(all_parts), - "_CUDA_VISIBLE_DEVICES": fix_gpu_number(gpu_names[i_part]), + "_CUDA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part])), } ) os.environ.update(config) @@ -914,7 +914,7 @@ def open1b(version, inp_text, inp_wav_dir, exp_name, gpu_numbers, ssl_pretrained { "i_part": str(i_part), "all_parts": str(all_parts), - "_CUDA_VISIBLE_DEVICES": fix_gpu_number(gpu_names[i_part]), + "_CUDA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part])), } ) os.environ.update(config) @@ -986,7 +986,7 @@ def open1c(version, inp_text, inp_wav_dir, exp_name, gpu_numbers, pretrained_s2G { "i_part": str(i_part), "all_parts": str(all_parts), - "_CUDA_VISIBLE_DEVICES": fix_gpu_number(gpu_names[i_part]), + "_CUDA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part])), } ) os.environ.update(config) @@ -1086,7 +1086,7 @@ def open1abc( { "i_part": str(i_part), "all_parts": str(all_parts), - "_CUDA_VISIBLE_DEVICES": fix_gpu_number(gpu_names[i_part]), + "_CUDA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part])), } ) os.environ.update(config) @@ -1133,7 +1133,7 @@ def open1abc( { "i_part": str(i_part), "all_parts": str(all_parts), - "_CUDA_VISIBLE_DEVICES": fix_gpu_number(gpu_names[i_part]), 
+ "_CUDA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part])), } ) os.environ.update(config) @@ -1155,7 +1155,7 @@ def open1abc( { "i_part": str(i_part), "all_parts": str(all_parts), - "_CUDA_VISIBLE_DEVICES": fix_gpu_number(gpu_names[i_part]), + "_CUDA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part])), } ) os.environ.update(config) @@ -1195,7 +1195,7 @@ def open1abc( { "i_part": str(i_part), "all_parts": str(all_parts), - "_CUDA_VISIBLE_DEVICES": fix_gpu_number(gpu_names[i_part]), + "_CUDA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part])), } ) os.environ.update(config) @@ -1980,3 +1980,4 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css server_port=webui_port_main, # quiet=True, ) + From 92ab59c5533a5dea368ddb8dad89e14474307145 Mon Sep 17 00:00:00 2001 From: ChasonJiang <46401978+ChasonJiang@users.noreply.github.com> Date: Fri, 28 Nov 2025 21:12:41 +0800 Subject: [PATCH 03/20] =?UTF-8?q?=E6=9B=B4=E7=BB=86=E7=B2=92=E5=BA=A6?= =?UTF-8?q?=E7=9A=84=E6=B5=81=E5=BC=8F=E6=8E=A8=E7=90=86=E6=A8=A1=E5=BC=8F?= =?UTF-8?q?=20(#2671)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 更好的流式推理模式 * 清理无用代码 * modified: GPT_SoVITS/AR/models/t2s_model.py modified: GPT_SoVITS/TTS_infer_pack/TTS.py modified: GPT_SoVITS/module/models.py * modified: GPT_SoVITS/TTS_infer_pack/TTS.py * modified: .gitignore modified: GPT_SoVITS/AR/models/t2s_model.py modified: GPT_SoVITS/TTS_infer_pack/TTS.py modified: GPT_SoVITS/module/models.py * modified: GPT_SoVITS/AR/models/t2s_model.py modified: GPT_SoVITS/TTS_infer_pack/TTS.py modified: GPT_SoVITS/module/models.py modified: api_v2.py * modified: GPT_SoVITS/TTS_infer_pack/TTS.py * 更正拼写错误 * 支持固定chunk长度的流式推理,优化sola算法 * 修复api_v2的ogg格式传输问题 --- GPT_SoVITS/AR/models/t2s_model.py | 60 +++- GPT_SoVITS/TTS_infer_pack/TTS.py | 449 +++++++++++++++++++++--------- GPT_SoVITS/module/models.py | 105 ++++++- api_v2.py | 100 +++++-- 4 files changed, 548 insertions(+), 166 
deletions(-) diff --git a/GPT_SoVITS/AR/models/t2s_model.py b/GPT_SoVITS/AR/models/t2s_model.py index 7196d6ab..0caadd04 100644 --- a/GPT_SoVITS/AR/models/t2s_model.py +++ b/GPT_SoVITS/AR/models/t2s_model.py @@ -794,7 +794,7 @@ class Text2SemanticDecoder(nn.Module): y_list = [] idx_list = [] for i in range(len(x)): - y, idx = self.infer_panel_naive( + y, idx = next(self.infer_panel_naive( x[i].unsqueeze(0), x_lens[i], prompts[i].unsqueeze(0) if prompts is not None else None, @@ -805,7 +805,7 @@ class Text2SemanticDecoder(nn.Module): temperature, repetition_penalty, **kwargs, - ) + )) y_list.append(y[0]) idx_list.append(idx) @@ -822,8 +822,15 @@ class Text2SemanticDecoder(nn.Module): early_stop_num: int = -1, temperature: float = 1.0, repetition_penalty: float = 1.35, + streaming_mode: bool = False, + chunk_length: int = 24, **kwargs, ): + mute_emb_sim_matrix = kwargs.get("mute_emb_sim_matrix", None) + chunk_split_thershold = kwargs.get("chunk_split_thershold", 0.3) + check_token_num = 2 + + x = self.ar_text_embedding(x) x = x + self.bert_proj(bert_feature.transpose(1, 2)) x = self.ar_text_position(x) @@ -875,7 +882,10 @@ class Text2SemanticDecoder(nn.Module): .to(device=x.device, dtype=torch.bool) ) + token_counter = 0 + curr_ptr = prefix_len for idx in tqdm(range(1500)): + token_counter+=1 if xy_attn_mask is not None: xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, xy_attn_mask, None) else: @@ -900,22 +910,56 @@ class Text2SemanticDecoder(nn.Module): if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS: stop = True + y=y[:, :-1] + token_counter -= 1 + + if idx == 1499: + stop = True + if stop: if y.shape[1] == 0: y = torch.concat([y, torch.zeros_like(samples)], dim=1) print("bad zero prediction") - print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]") + # print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]") + if streaming_mode: + yield y[:, curr_ptr:] if curr_ptr= chunk_length+check_token_num): + score = 
mute_emb_sim_matrix[y[0, curr_ptr:]] - chunk_split_thershold + score[score<0]=-1 + score[:-1]=score[:-1]+score[1:] ##考虑连续两个token + argmax_idx = score.argmax() + + if score[argmax_idx]>=0 and argmax_idx+1>=chunk_length: + print(f"\n\ncurr_ptr:{curr_ptr}") + yield y[:, curr_ptr:], False + token_counter -= argmax_idx+1 + curr_ptr += argmax_idx+1 + + + elif streaming_mode and (mute_emb_sim_matrix is None) and (token_counter >= chunk_length): + yield y[:, -token_counter:], False + curr_ptr+=token_counter + token_counter = 0 + + + ####################### update next step ################################### y_emb = self.ar_audio_embedding(y[:, -1:]) xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[ :, y_len + idx ].to(dtype=y_emb.dtype, device=y_emb.device) - if ref_free: - return y[:, :-1], 0 - return y[:, :-1], idx + + + if not streaming_mode: + if ref_free: + yield y, 0 + yield y, idx + + def infer_panel( self, @@ -930,6 +974,6 @@ class Text2SemanticDecoder(nn.Module): repetition_penalty: float = 1.35, **kwargs, ): - return self.infer_panel_naive( + return next(self.infer_panel_naive( x, x_lens, prompts, bert_feature, top_k, top_p, early_stop_num, temperature, repetition_penalty, **kwargs - ) + )) diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py index 0c1d2484..be3d3a19 100644 --- a/GPT_SoVITS/TTS_infer_pack/TTS.py +++ b/GPT_SoVITS/TTS_infer_pack/TTS.py @@ -275,6 +275,15 @@ class TTS_Config: v1_languages: list = ["auto", "en", "zh", "ja", "all_zh", "all_ja"] v2_languages: list = ["auto", "auto_yue", "en", "zh", "ja", "yue", "ko", "all_zh", "all_ja", "all_yue", "all_ko"] languages: list = v2_languages + mute_tokens: dict = { + "v1" : 486, + "v2" : 486, + "v2Pro": 486, + "v2ProPlus": 486, + "v3" : 486, + "v4" : 486, + } + mute_emb_sim_matrix: torch.Tensor = None # "all_zh",#全部按中文识别 # "en",#全部按英文识别#######不变 # "all_ja",#全部按日文识别 @@ -598,6 +607,11 @@ class TTS: if self.configs.is_half 
and str(self.configs.device) != "cpu": self.t2s_model = self.t2s_model.half() + codebook = t2s_model.model.ar_audio_embedding.weight.clone() + mute_emb = codebook[self.configs.mute_tokens[self.configs.version]].unsqueeze(0) + sim_matrix = F.cosine_similarity(mute_emb.float(), codebook.float(), dim=-1) + self.configs.mute_emb_sim_matrix = sim_matrix + def init_vocoder(self, version: str): if version == "v3": if self.vocoder is not None and self.vocoder.__class__.__name__ == "BigVGAN": @@ -997,18 +1011,22 @@ class TTS: "top_k": 5, # int. top k sampling "top_p": 1, # float. top p sampling "temperature": 1, # float. temperature for sampling - "text_split_method": "cut0", # str. text split method, see text_segmentation_method.py for details. + "text_split_method": "cut1", # str. text split method, see text_segmentation_method.py for details. "batch_size": 1, # int. batch size for inference "batch_threshold": 0.75, # float. threshold for batch splitting. - "split_bucket: True, # bool. whether to split the batch into multiple buckets. - "return_fragment": False, # bool. step by step return the audio fragment. + "split_bucket": True, # bool. whether to split the batch into multiple buckets. "speed_factor":1.0, # float. control the speed of the synthesized audio. "fragment_interval":0.3, # float. to control the interval of the audio fragment. "seed": -1, # int. random seed for reproducibility. "parallel_infer": True, # bool. whether to use parallel inference. - "repetition_penalty": 1.35 # float. repetition penalty for T2S model. + "repetition_penalty": 1.35, # float. repetition penalty for T2S model. "sample_steps": 32, # int. number of sampling steps for VITS model V3. - "super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3. + "super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3. + "return_fragment": False, # bool. step by step return the audio fragment. 
(Best Quality, Slowest response speed. old version of streaming mode) + "streaming_mode": False, # bool. return audio chunk by chunk. (Medium quality, Slow response speed) + "overlap_length": 2, # int. overlap length of semantic tokens for streaming mode. + "min_chunk_length": 16, # int. The minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size) + "fixed_length_chunk": False, # bool. When turned on, it can achieve faster streaming response, but with lower quality. (lower quality, faster response speed) } returns: Tuple[int, np.ndarray]: sampling rate and audio data. @@ -1024,7 +1042,7 @@ class TTS: top_k: int = inputs.get("top_k", 5) top_p: float = inputs.get("top_p", 1) temperature: float = inputs.get("temperature", 1) - text_split_method: str = inputs.get("text_split_method", "cut0") + text_split_method: str = inputs.get("text_split_method", "cut1") batch_size = inputs.get("batch_size", 1) batch_threshold = inputs.get("batch_threshold", 0.75) speed_factor = inputs.get("speed_factor", 1.0) @@ -1038,19 +1056,43 @@ class TTS: repetition_penalty = inputs.get("repetition_penalty", 1.35) sample_steps = inputs.get("sample_steps", 32) super_sampling = inputs.get("super_sampling", False) + streaming_mode = inputs.get("streaming_mode", False) + overlap_length = inputs.get("overlap_length", 2) + min_chunk_length = inputs.get("min_chunk_length", 16) + fixed_length_chunk = inputs.get("fixed_length_chunk", False) + chunk_split_thershold = 0.0 # 该值代表语义token与mute token的余弦相似度阈值,若大于该阈值,则视为可切分点。 - if parallel_infer: + if parallel_infer and not streaming_mode: print(i18n("并行推理模式已开启")) self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_batch_infer + elif not parallel_infer and streaming_mode and not self.configs.use_vocoder: + print(i18n("流式推理模式已开启")) + self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_naive + elif streaming_mode and self.configs.use_vocoder: + print(i18n("SoVits V3/4模型不支持流式推理模式,已自动回退到分段返回模式")) + 
streaming_mode = False + return_fragment = True + if parallel_infer: + self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_batch_infer + else: + self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_naive_batched + # self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_naive + elif parallel_infer and streaming_mode: + print(i18n("不支持同时开启并行推理和流式推理模式,已自动关闭并行推理模式")) + parallel_infer = False + self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_naive else: - print(i18n("并行推理模式已关闭")) + print(i18n("朴素推理模式已开启")) self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_naive_batched - if return_fragment: - print(i18n("分段返回模式已开启")) - if split_bucket: - split_bucket = False - print(i18n("分段返回模式不支持分桶处理,已自动关闭分桶处理")) + if return_fragment and streaming_mode: + print(i18n("流式推理模式不支持分段返回,已自动关闭分段返回")) + return_fragment = False + + if (return_fragment or streaming_mode) and split_bucket: + print(i18n("分段返回模式/流式推理模式不支持分桶处理,已自动关闭分桶处理")) + split_bucket = False + if split_bucket and speed_factor == 1.0 and not (self.configs.use_vocoder and parallel_infer): print(i18n("分桶处理模式已开启")) @@ -1063,9 +1105,9 @@ class TTS: else: print(i18n("分桶处理模式已关闭")) - if fragment_interval < 0.01: - fragment_interval = 0.01 - print(i18n("分段间隔过小,已自动设置为0.01")) + # if fragment_interval < 0.01: + # fragment_interval = 0.01 + # print(i18n("分段间隔过小,已自动设置为0.01")) no_prompt_text = False if prompt_text in [None, ""]: @@ -1126,7 +1168,7 @@ class TTS: ###### text preprocessing ######## t1 = time.perf_counter() data: list = None - if not return_fragment: + if not (return_fragment or streaming_mode): data = self.text_preprocessor.preprocess(text, text_lang, text_split_method, self.configs.version) if len(data) == 0: yield 16000, np.zeros(int(16000), dtype=np.int16) @@ -1186,10 +1228,11 @@ class TTS: t_34 = 0.0 t_45 = 0.0 audio = [] + is_first_package = True output_sr = self.configs.sampling_rate if not self.configs.use_vocoder else 
self.vocoder_configs["sr"] for item in data: t3 = time.perf_counter() - if return_fragment: + if return_fragment or streaming_mode: item = make_batch(item) if item is None: continue @@ -1211,108 +1254,228 @@ class TTS: self.prompt_cache["prompt_semantic"].expand(len(all_phoneme_ids), -1).to(self.configs.device) ) - print(f"############ {i18n('预测语义Token')} ############") - pred_semantic_list, idx_list = self.t2s_model.model.infer_panel( - all_phoneme_ids, - all_phoneme_lens, - prompt, - all_bert_features, - # prompt_phone_len=ph_offset, - top_k=top_k, - top_p=top_p, - temperature=temperature, - early_stop_num=self.configs.hz * self.configs.max_sec, - max_len=max_len, - repetition_penalty=repetition_penalty, - ) - t4 = time.perf_counter() - t_34 += t4 - t3 - refer_audio_spec = [] - if self.is_v2pro: - sv_emb = [] + + sv_emb = [] if self.is_v2pro else None for spec, audio_tensor in self.prompt_cache["refer_spec"]: spec = spec.to(dtype=self.precision, device=self.configs.device) refer_audio_spec.append(spec) if self.is_v2pro: sv_emb.append(self.sv_model.compute_embedding3(audio_tensor)) - batch_audio_fragment = [] + if not streaming_mode: + print(f"############ {i18n('预测语义Token')} ############") + pred_semantic_list, idx_list = self.t2s_model.model.infer_panel( + all_phoneme_ids, + all_phoneme_lens, + prompt, + all_bert_features, + # prompt_phone_len=ph_offset, + top_k=top_k, + top_p=top_p, + temperature=temperature, + early_stop_num=self.configs.hz * self.configs.max_sec, + max_len=max_len, + repetition_penalty=repetition_penalty, + ) + t4 = time.perf_counter() + t_34 += t4 - t3 - # ## vits并行推理 method 1 - # pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)] - # pred_semantic_len = torch.LongTensor([item.shape[0] for item in pred_semantic_list]).to(self.configs.device) - # pred_semantic = self.batch_sequences(pred_semantic_list, axis=0, pad_value=0).unsqueeze(0) - # max_len = 0 - # for i in range(0, len(batch_phones)): - # max_len = 
max(max_len, batch_phones[i].shape[-1]) - # batch_phones = self.batch_sequences(batch_phones, axis=0, pad_value=0, max_length=max_len) - # batch_phones = batch_phones.to(self.configs.device) - # batch_audio_fragment = (self.vits_model.batched_decode( - # pred_semantic, pred_semantic_len, batch_phones, batch_phones_len,refer_audio_spec - # )) - print(f"############ {i18n('合成音频')} ############") - if not self.configs.use_vocoder: - if speed_factor == 1.0: - print(f"{i18n('并行合成中')}...") - # ## vits并行推理 method 2 - pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)] - upsample_rate = math.prod(self.vits_model.upsample_rates) - audio_frag_idx = [ - pred_semantic_list[i].shape[0] * 2 * upsample_rate - for i in range(0, len(pred_semantic_list)) - ] - audio_frag_end_idx = [sum(audio_frag_idx[: i + 1]) for i in range(0, len(audio_frag_idx))] - all_pred_semantic = ( - torch.cat(pred_semantic_list).unsqueeze(0).unsqueeze(0).to(self.configs.device) - ) - _batch_phones = torch.cat(batch_phones).unsqueeze(0).to(self.configs.device) - if self.is_v2pro != True: - _batch_audio_fragment = self.vits_model.decode( - all_pred_semantic, _batch_phones, refer_audio_spec, speed=speed_factor - ).detach()[0, 0, :] - else: - _batch_audio_fragment = self.vits_model.decode( - all_pred_semantic, _batch_phones, refer_audio_spec, speed=speed_factor, sv_emb=sv_emb - ).detach()[0, 0, :] - audio_frag_end_idx.insert(0, 0) - batch_audio_fragment = [ - _batch_audio_fragment[audio_frag_end_idx[i - 1] : audio_frag_end_idx[i]] - for i in range(1, len(audio_frag_end_idx)) - ] - else: - # ## vits串行推理 - for i, idx in enumerate(tqdm(idx_list)): - phones = batch_phones[i].unsqueeze(0).to(self.configs.device) - _pred_semantic = ( - pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0) - ) # .unsqueeze(0)#mq要多unsqueeze一次 - if self.is_v2pro != True: - audio_fragment = self.vits_model.decode( - _pred_semantic, phones, refer_audio_spec, speed=speed_factor - ).detach()[0, 0, :] - 
else: - audio_fragment = self.vits_model.decode( - _pred_semantic, phones, refer_audio_spec, speed=speed_factor, sv_emb=sv_emb - ).detach()[0, 0, :] - batch_audio_fragment.append(audio_fragment) ###试试重建不带上prompt部分 - else: - if parallel_infer: - print(f"{i18n('并行合成中')}...") - audio_fragments = self.using_vocoder_synthesis_batched_infer( - idx_list, pred_semantic_list, batch_phones, speed=speed_factor, sample_steps=sample_steps - ) - batch_audio_fragment.extend(audio_fragments) - else: - for i, idx in enumerate(tqdm(idx_list)): - phones = batch_phones[i].unsqueeze(0).to(self.configs.device) - _pred_semantic = ( - pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0) - ) # .unsqueeze(0)#mq要多unsqueeze一次 - audio_fragment = self.using_vocoder_synthesis( - _pred_semantic, phones, speed=speed_factor, sample_steps=sample_steps + + batch_audio_fragment = [] + + # ## vits并行推理 method 1 + # pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)] + # pred_semantic_len = torch.LongTensor([item.shape[0] for item in pred_semantic_list]).to(self.configs.device) + # pred_semantic = self.batch_sequences(pred_semantic_list, axis=0, pad_value=0).unsqueeze(0) + # max_len = 0 + # for i in range(0, len(batch_phones)): + # max_len = max(max_len, batch_phones[i].shape[-1]) + # batch_phones = self.batch_sequences(batch_phones, axis=0, pad_value=0, max_length=max_len) + # batch_phones = batch_phones.to(self.configs.device) + # batch_audio_fragment = (self.vits_model.batched_decode( + # pred_semantic, pred_semantic_len, batch_phones, batch_phones_len,refer_audio_spec + # )) + print(f"############ {i18n('合成音频')} ############") + if not self.configs.use_vocoder: + if speed_factor == 1.0: + print(f"{i18n('并行合成中')}...") + # ## vits并行推理 method 2 + pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)] + upsample_rate = math.prod(self.vits_model.upsample_rates) + audio_frag_idx = [ + pred_semantic_list[i].shape[0] * 2 * upsample_rate + for 
i in range(0, len(pred_semantic_list)) + ] + audio_frag_end_idx = [sum(audio_frag_idx[: i + 1]) for i in range(0, len(audio_frag_idx))] + all_pred_semantic = ( + torch.cat(pred_semantic_list).unsqueeze(0).unsqueeze(0).to(self.configs.device) ) - batch_audio_fragment.append(audio_fragment) + _batch_phones = torch.cat(batch_phones).unsqueeze(0).to(self.configs.device) + + _batch_audio_fragment = self.vits_model.decode( + all_pred_semantic, _batch_phones, refer_audio_spec, speed=speed_factor, sv_emb=sv_emb + ).detach()[0, 0, :] + + audio_frag_end_idx.insert(0, 0) + batch_audio_fragment = [ + _batch_audio_fragment[audio_frag_end_idx[i - 1] : audio_frag_end_idx[i]] + for i in range(1, len(audio_frag_end_idx)) + ] + else: + # ## vits串行推理 + for i, idx in enumerate(tqdm(idx_list)): + phones = batch_phones[i].unsqueeze(0).to(self.configs.device) + _pred_semantic = ( + pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0) + ) # .unsqueeze(0)#mq要多unsqueeze一次 + audio_fragment = self.vits_model.decode( + _pred_semantic, phones, refer_audio_spec, speed=speed_factor, sv_emb=sv_emb + ).detach()[0, 0, :] + batch_audio_fragment.append(audio_fragment) ###试试重建不带上prompt部分 + else: + if parallel_infer: + print(f"{i18n('并行合成中')}...") + audio_fragments = self.using_vocoder_synthesis_batched_infer( + idx_list, pred_semantic_list, batch_phones, speed=speed_factor, sample_steps=sample_steps + ) + batch_audio_fragment.extend(audio_fragments) + else: + for i, idx in enumerate(tqdm(idx_list)): + phones = batch_phones[i].unsqueeze(0).to(self.configs.device) + _pred_semantic = ( + pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0) + ) # .unsqueeze(0)#mq要多unsqueeze一次 + audio_fragment = self.using_vocoder_synthesis( + _pred_semantic, phones, speed=speed_factor, sample_steps=sample_steps + ) + batch_audio_fragment.append(audio_fragment) + + else: + # refer_audio_spec: torch.Tensor = [ + # item.to(dtype=self.precision, device=self.configs.device) + # for item in self.prompt_cache["refer_spec"] + 
# ] + semantic_token_generator =self.t2s_model.model.infer_panel( + all_phoneme_ids[0].unsqueeze(0), + all_phoneme_lens, + prompt, + all_bert_features[0].unsqueeze(0), + top_k=top_k, + top_p=top_p, + temperature=temperature, + early_stop_num=self.configs.hz * self.configs.max_sec, + max_len=max_len, + repetition_penalty=repetition_penalty, + streaming_mode=True, + chunk_length=min_chunk_length, + mute_emb_sim_matrix=self.configs.mute_emb_sim_matrix if not fixed_length_chunk else None, + chunk_split_thershold=chunk_split_thershold, + ) + t4 = time.perf_counter() + t_34 += t4 - t3 + phones = batch_phones[0].unsqueeze(0).to(self.configs.device) + is_first_chunk = True + + if not self.configs.use_vocoder: + # if speed_factor == 1.0: + # upsample_rate = math.prod(self.vits_model.upsample_rates)*(2 if self.vits_model.semantic_frame_rate == "25hz" else 1) + # else: + upsample_rate = math.prod(self.vits_model.upsample_rates)*((2 if self.vits_model.semantic_frame_rate == "25hz" else 1)/speed_factor) + else: + # if speed_factor == 1.0: + # upsample_rate = self.vocoder_configs["upsample_rate"]*(3.875 if self.configs.version == "v3" else 4) + # else: + upsample_rate = self.vocoder_configs["upsample_rate"]*((3.875 if self.configs.version == "v3" else 4)/speed_factor) + + last_audio_chunk = None + # last_tokens = None + last_latent = None + previous_tokens = [] + overlap_len = overlap_length + overlap_size = math.ceil(overlap_length*upsample_rate) + for semantic_tokens, is_final in semantic_token_generator: + if semantic_tokens is None and last_audio_chunk is not None: + yield self.audio_postprocess( + [[last_audio_chunk[-overlap_size:]]], + output_sr, + None, + speed_factor, + False, + 0.0, + super_sampling if self.configs.use_vocoder and self.configs.version == "v3" else False, + ) + break + + _semantic_tokens = semantic_tokens + print(f"semantic_tokens shape:{semantic_tokens.shape}") + + previous_tokens.append(semantic_tokens) + + _semantic_tokens = torch.cat(previous_tokens, 
dim=-1) + + if not is_first_chunk and semantic_tokens.shape[-1] < 10: + overlap_len = overlap_length+(10-semantic_tokens.shape[-1]) + else: + overlap_len = overlap_length + + + if not self.configs.use_vocoder: + token_padding_length = 0 + # token_padding_length = int(phones.shape[-1]*2)-_semantic_tokens.shape[-1] + # if token_padding_length>0: + # _semantic_tokens = F.pad(_semantic_tokens, (0, token_padding_length), "constant", 486) + # else: + # token_padding_length = 0 + + audio_chunk, latent, latent_mask = self.vits_model.decode_streaming( + _semantic_tokens.unsqueeze(0), + phones, refer_audio_spec, + speed=speed_factor, + sv_emb=sv_emb, + result_length=semantic_tokens.shape[-1]+overlap_len if not is_first_chunk else None, + overlap_frames=last_latent[:,:,-overlap_len*(2 if self.vits_model.semantic_frame_rate == "25hz" else 1):] \ + if last_latent is not None else None, + padding_length=token_padding_length + ) + audio_chunk=audio_chunk.detach()[0, 0, :] + else: + raise RuntimeError(i18n("SoVits V3/4模型不支持流式推理模式")) + + if overlap_len>overlap_length: + audio_chunk=audio_chunk[-int((overlap_length+semantic_tokens.shape[-1])*upsample_rate):] + + audio_chunk_ = audio_chunk + if is_first_chunk and not is_final: + is_first_chunk = False + audio_chunk_ = audio_chunk_[:-overlap_size] + elif is_first_chunk and is_final: + is_first_chunk = False + elif not is_first_chunk and not is_final: + audio_chunk_ = self.sola_algorithm([last_audio_chunk, audio_chunk_], overlap_size) + audio_chunk_ = ( + audio_chunk_[last_audio_chunk.shape[0]-overlap_size:-overlap_size] if not is_final \ + else audio_chunk_[last_audio_chunk.shape[0]-overlap_size:] + ) + + last_latent = latent + last_audio_chunk = audio_chunk + yield self.audio_postprocess( + [[audio_chunk_]], + output_sr, + None, + speed_factor, + False, + 0.0, + super_sampling if self.configs.use_vocoder and self.configs.version == "v3" else False, + ) + + if is_first_package: + print(f"first_package_delay: 
{time.perf_counter()-t0:.3f}") + is_first_package = False + + + yield output_sr, np.zeros(int(output_sr*fragment_interval), dtype=np.int16) t5 = time.perf_counter() t_45 += t5 - t4 @@ -1327,17 +1490,18 @@ class TTS: fragment_interval, super_sampling if self.configs.use_vocoder and self.configs.version == "v3" else False, ) + elif streaming_mode:... else: audio.append(batch_audio_fragment) if self.stop_flag: - yield 16000, np.zeros(int(16000), dtype=np.int16) + yield output_sr, np.zeros(int(output_sr), dtype=np.int16) return - if not return_fragment: + if not (return_fragment or streaming_mode): print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t_34, t_45)) if len(audio) == 0: - yield 16000, np.zeros(int(16000), dtype=np.int16) + yield output_sr, np.zeros(int(output_sr), dtype=np.int16) return yield self.audio_postprocess( audio, @@ -1384,16 +1548,17 @@ class TTS: fragment_interval: float = 0.3, super_sampling: bool = False, ) -> Tuple[int, np.ndarray]: - zero_wav = torch.zeros( - int(self.configs.sampling_rate * fragment_interval), dtype=self.precision, device=self.configs.device - ) + if fragment_interval>0: + zero_wav = torch.zeros( + int(self.configs.sampling_rate * fragment_interval), dtype=self.precision, device=self.configs.device + ) for i, batch in enumerate(audio): for j, audio_fragment in enumerate(batch): max_audio = torch.abs(audio_fragment).max() # 简单防止16bit爆音 if max_audio > 1: audio_fragment /= max_audio - audio_fragment: torch.Tensor = torch.cat([audio_fragment, zero_wav], dim=0) + audio_fragment: torch.Tensor = torch.cat([audio_fragment, zero_wav], dim=0) if fragment_interval>0 else audio_fragment audio[i][j] = audio_fragment if split_bucket: @@ -1413,13 +1578,18 @@ class TTS: max_audio = np.abs(audio).max() if max_audio > 1: audio /= max_audio + audio = (audio * 32768).astype(np.int16) t2 = time.perf_counter() print(f"超采样用时:{t2 - t1:.3f}s") else: + # audio = audio.float() * 32768 + # audio = audio.to(dtype=torch.int16).clamp(-32768, 
32767).cpu().numpy() + audio = audio.cpu().numpy() audio = (audio * 32768).astype(np.int16) + # try: # if speed_factor != 1.0: # audio = speed_change(audio, speed=speed_factor, sr=int(sr)) @@ -1612,24 +1782,43 @@ class TTS: self, audio_fragments: List[torch.Tensor], overlap_len: int, + search_len:int= 320 ): + # overlap_len-=search_len + + dtype = audio_fragments[0].dtype + for i in range(len(audio_fragments) - 1): - f1 = audio_fragments[i] - f2 = audio_fragments[i + 1] + f1 = audio_fragments[i].float() + f2 = audio_fragments[i + 1].float() w1 = f1[-overlap_len:] - w2 = f2[:overlap_len] - assert w1.shape == w2.shape - corr = F.conv1d(w1.view(1, 1, -1), w2.view(1, 1, -1), padding=w2.shape[-1] // 2).view(-1)[:-1] - idx = corr.argmax() - f1_ = f1[: -(overlap_len - idx)] + w2 = f2[:overlap_len+search_len] + # w2 = w2[-w2.shape[-1]//2:] + # assert w1.shape == w2.shape + corr_norm = F.conv1d(w2.view(1, 1, -1), w1.view(1, 1, -1)).view(-1) + + corr_den = F.conv1d(w2.view(1, 1, -1)**2, torch.ones_like(w1).view(1, 1, -1)).view(-1)+ 1e-8 + idx = (corr_norm/corr_den.sqrt()).argmax() + + print(f"seg_idx: {idx}") + + # idx = corr.argmax() + f1_ = f1[: -overlap_len] audio_fragments[i] = f1_ f2_ = f2[idx:] - window = torch.hann_window((overlap_len - idx) * 2, device=f1.device, dtype=f1.dtype) - f2_[: (overlap_len - idx)] = ( - window[: (overlap_len - idx)] * f2_[: (overlap_len - idx)] - + window[(overlap_len - idx) :] * f1[-(overlap_len - idx) :] + window = torch.hann_window((overlap_len) * 2, device=f1.device, dtype=f1.dtype) + f2_[: overlap_len] = ( + window[: overlap_len] * f2_[: overlap_len] + + window[overlap_len :] * f1[-overlap_len :] ) + + # window = torch.sin(torch.arange((overlap_len - idx), device=f1.device) * np.pi / (overlap_len - idx)) + # f2_[: (overlap_len - idx)] = ( + # window * f2_[: (overlap_len - idx)] + # + (1-window) * f1[-(overlap_len - idx) :] + # ) + audio_fragments[i + 1] = f2_ - return torch.cat(audio_fragments, 0) + return torch.cat(audio_fragments, 
0).to(dtype) diff --git a/GPT_SoVITS/module/models.py b/GPT_SoVITS/module/models.py index 1c8e662f..348ddb3f 100644 --- a/GPT_SoVITS/module/models.py +++ b/GPT_SoVITS/module/models.py @@ -151,6 +151,8 @@ class DurationPredictor(nn.Module): return x * x_mask +WINDOW = {} + class TextEncoder(nn.Module): def __init__( self, @@ -209,7 +211,7 @@ class TextEncoder(nn.Module): self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) - def forward(self, y, y_lengths, text, text_lengths, ge, speed=1, test=None): + def forward(self, y, y_lengths, text, text_lengths, ge, speed=1, test=None, result_length:int=None, overlap_frames:torch.Tensor=None, padding_length:int=None): y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype) y = self.ssl_proj(y * y_mask) * y_mask @@ -222,13 +224,44 @@ class TextEncoder(nn.Module): text = self.text_embedding(text).transpose(1, 2) text = self.encoder_text(text * text_mask, text_mask) y = self.mrte(y, y_mask, text, text_mask, ge) + + if padding_length is not None and padding_length!=0: + y = y[:, :, :-padding_length] + y_mask = y_mask[:, :, :-padding_length] + + y = self.encoder2(y * y_mask, y_mask) + + if result_length is not None: + y = y[:, :, -result_length:] + y_mask = y_mask[:, :, -result_length:] + + if overlap_frames is not None: + overlap_len = overlap_frames.shape[-1] + window = WINDOW.get(overlap_len, None) + if window is None: + # WINDOW[overlap_len] = torch.hann_window(overlap_len*2, device=y.device, dtype=y.dtype) + WINDOW[overlap_len] = torch.sin(torch.arange(overlap_len*2, device=y.device) * torch.pi / (overlap_len*2)) + window = WINDOW[overlap_len] + + + window = window.to(y.device) + y[:,:,:overlap_len] = ( + window[:overlap_len].view(1, 1, -1) * y[:,:,:overlap_len] + + window[overlap_len:].view(1, 1, -1) * overlap_frames + ) + + y_ = y + y_mask_ = y_mask + + + if speed != 1: y = F.interpolate(y, size=int(y.shape[-1] / speed) + 1, mode="linear") y_mask = F.interpolate(y_mask, size=y.shape[-1], 
mode="nearest") stats = self.proj(y) * y_mask m, logs = torch.split(stats, self.out_channels, dim=1) - return y, m, logs, y_mask + return y, m, logs, y_mask, y_, y_mask_ def extract_latent(self, x): x = self.ssl_proj(x) @@ -921,7 +954,7 @@ class SynthesizerTrn(nn.Module): if self.semantic_frame_rate == "25hz": quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest") - x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge512 if self.is_v2pro else ge) + x, m_p, logs_p, y_mask, _, _ = self.enc_p(quantized, y_lengths, text, text_lengths, ge512 if self.is_v2pro else ge) z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=ge) z_p = self.flow(z, y_mask, g=ge) @@ -949,7 +982,7 @@ class SynthesizerTrn(nn.Module): if self.semantic_frame_rate == "25hz": quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest") - x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge, test=test) + x, m_p, logs_p, y_mask, _, _ = self.enc_p(quantized, y_lengths, text, text_lengths, ge, test=test) z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale z = self.flow(z_p, y_mask, g=ge, reverse=True) @@ -957,6 +990,7 @@ class SynthesizerTrn(nn.Module): o = self.dec((z * y_mask)[:, :, :], g=ge) return o, y_mask, (z, z_p, m_p, logs_p) + @torch.no_grad() def decode(self, codes, text, refer, noise_scale=0.5, speed=1, sv_emb=None): def get_ge(refer, sv_emb): @@ -989,7 +1023,7 @@ class SynthesizerTrn(nn.Module): quantized = self.quantizer.decode(codes) if self.semantic_frame_rate == "25hz": quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest") - x, m_p, logs_p, y_mask = self.enc_p( + x, m_p, logs_p, y_mask, _, _ = self.enc_p( quantized, y_lengths, text, @@ -1004,6 +1038,59 @@ class SynthesizerTrn(nn.Module): o = self.dec((z * y_mask)[:, :, :], g=ge) return o + + @torch.no_grad() + def decode_streaming(self, codes, text, refer, noise_scale=0.5, 
speed=1, sv_emb=None, result_length:int=None, overlap_frames:torch.Tensor=None, padding_length:int=None): + def get_ge(refer, sv_emb): + ge = None + if refer is not None: + refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device) + refer_mask = torch.unsqueeze(commons.sequence_mask(refer_lengths, refer.size(2)), 1).to(refer.dtype) + if self.version == "v1": + ge = self.ref_enc(refer * refer_mask, refer_mask) + else: + ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask) + if self.is_v2pro: + sv_emb = self.sv_emb(sv_emb) # B*20480->B*512 + ge += sv_emb.unsqueeze(-1) + ge = self.prelu(ge) + return ge + + if type(refer) == list: + ges = [] + for idx, _refer in enumerate(refer): + ge = get_ge(_refer, sv_emb[idx] if self.is_v2pro else None) + ges.append(ge) + ge = torch.stack(ges, 0).mean(0) + else: + ge = get_ge(refer, sv_emb) + + y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device) + text_lengths = torch.LongTensor([text.size(-1)]).to(text.device) + + quantized = self.quantizer.decode(codes) + if self.semantic_frame_rate == "25hz": + quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest") + result_length = (2*result_length) if result_length is not None else None + padding_length = (2*padding_length) if padding_length is not None else None + x, m_p, logs_p, y_mask, y_, y_mask_ = self.enc_p( + quantized, + y_lengths, + text, + text_lengths, + self.ge_to512(ge.transpose(2, 1)).transpose(2, 1) if self.is_v2pro else ge, + speed, + result_length=result_length, + overlap_frames=overlap_frames, + padding_length=padding_length + ) + z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale + + z = self.flow(z_p, y_mask, g=ge, reverse=True) + + o = self.dec((z * y_mask)[:, :, :], g=ge) + return o, y_, y_mask_ + def extract_latent(self, x): ssl = self.ssl_proj(x) quantized, codes, commit_loss, quantized_list = self.quantizer(ssl) @@ -1226,7 +1313,7 @@ class SynthesizerTrnV3(nn.Module): ssl = 
self.ssl_proj(ssl) quantized, codes, commit_loss, quantized_list = self.quantizer(ssl, layers=[0]) quantized = F.interpolate(quantized, scale_factor=2, mode="nearest") ##BCT - x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge) + x, m_p, logs_p, y_mask, y_, y_mask_ = self.enc_p(quantized, y_lengths, text, text_lengths, ge) fea = self.bridge(x) fea = F.interpolate(fea, scale_factor=(1.875 if self.version == "v3" else 2), mode="nearest") ##BCT fea, y_mask_ = self.wns1( @@ -1260,7 +1347,7 @@ class SynthesizerTrnV3(nn.Module): quantized = self.quantizer.decode(codes) if self.semantic_frame_rate == "25hz": quantized = F.interpolate(quantized, scale_factor=2, mode="nearest") ##BCT - x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge, speed) + x, m_p, logs_p, y_mask, _, _ = self.enc_p(quantized, y_lengths, text, text_lengths, ge, speed) fea = self.bridge(x) fea = F.interpolate(fea, scale_factor=(1.875 if self.version == "v3" else 2), mode="nearest") ##BCT ####more wn paramter to learn mel @@ -1377,7 +1464,7 @@ class SynthesizerTrnV3b(nn.Module): ssl = self.ssl_proj(ssl) quantized, codes, commit_loss, quantized_list = self.quantizer(ssl, layers=[0]) quantized = F.interpolate(quantized, scale_factor=2, mode="nearest") ##BCT - x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge) + x, m_p, logs_p, y_mask, y_, y_mask_ = self.enc_p(quantized, y_lengths, text, text_lengths, ge) z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=ge) z_p = self.flow(z, y_mask, g=ge) z_slice, ids_slice = commons.rand_slice_segments(z, y_lengths, self.segment_size) @@ -1420,7 +1507,7 @@ class SynthesizerTrnV3b(nn.Module): quantized = self.quantizer.decode(codes) if self.semantic_frame_rate == "25hz": quantized = F.interpolate(quantized, scale_factor=2, mode="nearest") ##BCT - x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge) + x, m_p, logs_p, y_mask, y_, y_mask_ = 
self.enc_p(quantized, y_lengths, text, text_lengths, ge) fea = self.bridge(x) fea = F.interpolate(fea, scale_factor=1.875, mode="nearest") ##BCT ####more wn paramter to learn mel diff --git a/api_v2.py b/api_v2.py index 5947df53..5df2da66 100644 --- a/api_v2.py +++ b/api_v2.py @@ -30,17 +30,22 @@ POST: "top_k": 5, # int. top k sampling "top_p": 1, # float. top p sampling "temperature": 1, # float. temperature for sampling - "text_split_method": "cut0", # str. text split method, see text_segmentation_method.py for details. + "text_split_method": "cut5", # str. text split method, see text_segmentation_method.py for details. "batch_size": 1, # int. batch size for inference "batch_threshold": 0.75, # float. threshold for batch splitting. "split_bucket": True, # bool. whether to split the batch into multiple buckets. "speed_factor":1.0, # float. control the speed of the synthesized audio. - "streaming_mode": False, # bool. whether to return a streaming response. + "fragment_interval":0.3, # float. to control the interval of the audio fragment. "seed": -1, # int. random seed for reproducibility. "parallel_infer": True, # bool. whether to use parallel inference. "repetition_penalty": 1.35, # float. repetition penalty for T2S model. "sample_steps": 32, # int. number of sampling steps for VITS model V3. - "super_sampling": False # bool. whether to use super-sampling for audio when using VITS model V3. + "super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3. + "return_fragment": False, # bool. step by step return the audio fragment. (Best Quality, Slowest response speed. old version of streaming mode) + "streaming_mode": False, # bool. return audio chunk by chunk. (Medium quality, Slow response speed) + "overlap_length": 2, # int. overlap length of semantic tokens for streaming mode. + "min_chunk_length": 16, # int. The minimum chunk length of semantic tokens for streaming mode. 
(affects audio chunk size) + "fixed_length_chunk": False, # bool. When turned on, it can achieve faster streaming response, but with lower quality. (lower quality, faster response speed) } ``` @@ -121,6 +126,7 @@ from tools.i18n.i18n import I18nAuto from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config from GPT_SoVITS.TTS_infer_pack.text_segmentation_method import get_method_names as get_cut_method_names from pydantic import BaseModel +import threading # print(sys.path) i18n = I18nAuto() @@ -170,12 +176,55 @@ class TTS_Request(BaseModel): repetition_penalty: float = 1.35 sample_steps: int = 32 super_sampling: bool = False + overlap_length: int = 2 + min_chunk_length: int = 16 + return_fragment: bool = False + fixed_length_chunk: bool = False -### modify from https://github.com/RVC-Boss/GPT-SoVITS/pull/894/files def pack_ogg(io_buffer: BytesIO, data: np.ndarray, rate: int): - with sf.SoundFile(io_buffer, mode="w", samplerate=rate, channels=1, format="ogg") as audio_file: - audio_file.write(data) + # Author: AkagawaTsurunaki + # Issue: + # Stack overflow probabilistically occurs + # when the function `sf_writef_short` of `libsndfile_64bit.dll` is called + # using the Python library `soundfile` + # Note: + # This is an issue related to `libsndfile`, not this project itself. + # It happens when you generate a large audio tensor (about 499804 frames in my PC) + # and try to convert it to an ogg file. + # Related: + # https://github.com/RVC-Boss/GPT-SoVITS/issues/1199 + # https://github.com/libsndfile/libsndfile/issues/1023 + # https://github.com/bastibe/python-soundfile/issues/396 + # Suggestion: + # Or split the whole audio data into smaller audio segment to avoid stack overflow? 
+ + def handle_pack_ogg(): + with sf.SoundFile(io_buffer, mode="w", samplerate=rate, channels=1, format="ogg") as audio_file: + audio_file.write(data) + + + + # See: https://docs.python.org/3/library/threading.html + # The stack size of this thread is at least 32768 + # If stack overflow error still occurs, just modify the `stack_size`. + # stack_size = n * 4096, where n should be a positive integer. + # Here we chose n = 4096. + stack_size = 4096 * 4096 + try: + threading.stack_size(stack_size) + pack_ogg_thread = threading.Thread(target=handle_pack_ogg) + pack_ogg_thread.start() + pack_ogg_thread.join() + except RuntimeError as e: + # If changing the thread stack size is unsupported, a RuntimeError is raised. + print("RuntimeError: {}".format(e)) + print("Changing the thread stack size is unsupported.") + except ValueError as e: + # If the specified stack size is invalid, a ValueError is raised and the stack size is unmodified. + print("ValueError: {}".format(e)) + print("The specified stack size is invalid.") + return io_buffer @@ -286,8 +335,8 @@ def check_params(req: dict): ) if media_type not in ["wav", "raw", "ogg", "aac"]: return JSONResponse(status_code=400, content={"message": f"media_type: {media_type} is not supported"}) - elif media_type == "ogg" and not streaming_mode: - return JSONResponse(status_code=400, content={"message": "ogg format is not supported in non-streaming mode"}) + # elif media_type == "ogg" and not streaming_mode: + # return JSONResponse(status_code=400, content={"message": "ogg format is not supported in non-streaming mode"}) if text_split_method not in cut_method_names: return JSONResponse( @@ -307,7 +356,7 @@ async def tts_handle(req: dict): "text": "", # str.(required) text to be synthesized "text_lang: "", # str.(required) language of the text to be synthesized "ref_audio_path": "", # str.(required) reference audio path - "aux_ref_audio_paths": [], # list.(optional) auxiliary reference audio paths for multi-speaker synthesis + 
"aux_ref_audio_paths": [], # list.(optional) auxiliary reference audio paths for multi-speaker tone fusion "prompt_text": "", # str.(optional) prompt text for the reference audio "prompt_lang": "", # str.(required) language of the prompt text for the reference audio "top_k": 5, # int. top k sampling @@ -316,16 +365,19 @@ async def tts_handle(req: dict): "text_split_method": "cut5", # str. text split method, see text_segmentation_method.py for details. "batch_size": 1, # int. batch size for inference "batch_threshold": 0.75, # float. threshold for batch splitting. - "split_bucket: True, # bool. whether to split the batch into multiple buckets. + "split_bucket": True, # bool. whether to split the batch into multiple buckets. "speed_factor":1.0, # float. control the speed of the synthesized audio. "fragment_interval":0.3, # float. to control the interval of the audio fragment. "seed": -1, # int. random seed for reproducibility. - "media_type": "wav", # str. media type of the output audio, support "wav", "raw", "ogg", "aac". - "streaming_mode": False, # bool. whether to return a streaming response. - "parallel_infer": True, # bool.(optional) whether to use parallel inference. - "repetition_penalty": 1.35 # float.(optional) repetition penalty for T2S model. + "parallel_infer": True, # bool. whether to use parallel inference. + "repetition_penalty": 1.35, # float. repetition penalty for T2S model. "sample_steps": 32, # int. number of sampling steps for VITS model V3. - "super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3. + "super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3. + "return_fragment": False, # bool. step by step return the audio fragment. (Best Quality, Slowest response speed. old version of streaming mode) + "streaming_mode": False, # bool. return audio chunk by chunk. (Medium quality, Slow response speed) + "overlap_length": 2, # int. 
overlap length of semantic tokens for streaming mode. + "min_chunk_length": 16, # int. The minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size) + "fixed_length_chunk": False, # bool. When turned on, it can achieve faster streaming response, but with lower quality. (lower quality, faster response speed) } returns: StreamingResponse: audio stream response. @@ -339,8 +391,10 @@ async def tts_handle(req: dict): if check_res is not None: return check_res - if streaming_mode or return_fragment: - req["return_fragment"] = True + req["streaming_mode"] = streaming_mode + req["return_fragment"] = return_fragment + streaming_mode = streaming_mode or return_fragment + try: tts_generator = tts_pipeline.run(req) @@ -391,7 +445,7 @@ async def tts_get_endpoint( top_k: int = 5, top_p: float = 1, temperature: float = 1, - text_split_method: str = "cut0", + text_split_method: str = "cut5", batch_size: int = 1, batch_threshold: float = 0.75, split_bucket: bool = True, @@ -399,11 +453,15 @@ async def tts_get_endpoint( fragment_interval: float = 0.3, seed: int = -1, media_type: str = "wav", - streaming_mode: bool = False, parallel_infer: bool = True, repetition_penalty: float = 1.35, sample_steps: int = 32, super_sampling: bool = False, + return_fragment: bool = False, + streaming_mode: bool = False, + overlap_length: int = 2, + min_chunk_length: int = 16, + fixed_length_chunk: bool = False, ): req = { "text": text, @@ -428,6 +486,10 @@ async def tts_get_endpoint( "repetition_penalty": float(repetition_penalty), "sample_steps": int(sample_steps), "super_sampling": super_sampling, + "overlap_length": int(overlap_length), + "min_chunk_length": int(min_chunk_length), + "return_fragment": return_fragment, + "fixed_length_chunk": fixed_length_chunk } return await tts_handle(req) From e00ca92140542e6d947b9f660e24ed757aabc793 Mon Sep 17 00:00:00 2001 From: KamioRinn <63162909+KamioRinn@users.noreply.github.com> Date: Fri, 28 Nov 2025 21:22:43 +0800 Subject: 
[PATCH 04/20] Fix ASMD (#2636) --- GPT_SoVITS/text/en_normalization/expend.py | 48 +++++++++++++++++++++- 1 file changed, 47 insertions(+), 1 deletion(-) diff --git a/GPT_SoVITS/text/en_normalization/expend.py b/GPT_SoVITS/text/en_normalization/expend.py index bbd607cd..9f7c1d21 100644 --- a/GPT_SoVITS/text/en_normalization/expend.py +++ b/GPT_SoVITS/text/en_normalization/expend.py @@ -238,6 +238,46 @@ def _expand_number(m): return _inflect.number_to_words(num, andword="") +# 加减乘除 +RE_ASMD = re.compile( + r"((-?)((\d+)(\.\d+)?[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|(\.\d+[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|([A-Za-z][⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*))\s+([\+\-\×÷=])\s+((-?)((\d+)(\.\d+)?[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|(\.\d+[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|([A-Za-z][⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*))" +) +# RE_ASMD = re.compile( +# r"\b((-?)((\d+)(\.\d+)?[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|(\.\d+[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|([A-Za-z][⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*))([\+\-\×÷=])((-?)((\d+)(\.\d+)?[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|(\.\d+[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|([A-Za-z][⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*))\b" +# ) + +asmd_map = {"+": " plus ", "-": " minus ", "×": " times ", "÷": " divided by ", "=": " Equals "} + + +def replace_asmd(match) -> str: + """ + Args: + match (re.Match) + Returns: + str + """ + result = match.group(1) + asmd_map[match.group(8)] + match.group(9) + return result + + +RE_INTEGER = re.compile(r"(?:^|\s+)(-)" r"(\d+)") + + +def replace_negative_num(match) -> str: + """ + Args: + match (re.Match) + Returns: + str + """ + sign = match.group(1) + number = match.group(2) + sign: str = "negative " if sign else "" + result = f"{sign}{number}" + return result + + + def normalize(text): """ !!! 所有的处理都需要正确的输入 !!! @@ -245,7 +285,13 @@ def normalize(text): """ text = re.sub(_ordinal_number_re, _convert_ordinal, text) - text = re.sub(r"(? 
Date: Fri, 28 Nov 2025 21:36:57 +0800 Subject: [PATCH 05/20] =?UTF-8?q?=20=E5=B0=9D=E8=AF=95=20stream=20infer=20?= =?UTF-8?q?(#2469)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 尝试 stream infer * 在 stream_infer 脚本中绘制生成的音频 * stream_infer 增加导出部分。 * stream_infer: 更方便找规律的图 * stream_infer: 在拼接音频时进行相关性搜索,减少拼接带来基频断裂的情况 * stream_infer: 导出 `find_best_audio_offset_fast` * stream_infer: 优化波形显示,方便对比 * stream_v2pro.py 从命令行读取参数 * stream_v2pro.py 减少用于导出的文本长度 * stream_v2pro: 修复由于 spectrogram_torch 输入是 half 导致 spec 溢出最终没有声音的问题 * stream_v2pro: 新增 --lang 参数提示参考文字的语言类型 --- GPT_SoVITS/export_torch_script.py | 46 +-- GPT_SoVITS/stream_v2pro.py | 611 ++++++++++++++++++++++++++++++ 2 files changed, 624 insertions(+), 33 deletions(-) create mode 100644 GPT_SoVITS/stream_v2pro.py diff --git a/GPT_SoVITS/export_torch_script.py b/GPT_SoVITS/export_torch_script.py index e4406f28..786e22d1 100644 --- a/GPT_SoVITS/export_torch_script.py +++ b/GPT_SoVITS/export_torch_script.py @@ -261,41 +261,21 @@ class T2SBlock: attn = F.scaled_dot_product_attention(q, k, v, ~attn_mask) - attn = attn.permute(2, 0, 1, 3).reshape(batch_size * q_len, self.hidden_dim) - attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0) + # attn = attn.permute(2, 0, 1, 3).reshape(batch_size * q_len, self.hidden_dim) + # attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0) + attn = attn.transpose(1, 2).reshape(batch_size, q_len, -1) attn = F.linear(self.to_mask(attn, padding_mask), self.out_w, self.out_b) - if padding_mask is not None: - for i in range(batch_size): - # mask = padding_mask[i,:,0] - if self.false.device != padding_mask.device: - self.false = self.false.to(padding_mask.device) - idx = torch.where(padding_mask[i, :, 0] == self.false)[0] - x_item = x[i, idx, :].unsqueeze(0) - attn_item = attn[i, idx, :].unsqueeze(0) - x_item = x_item + attn_item - x_item = F.layer_norm(x_item, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1) 
- x_item = x_item + self.mlp.forward(x_item) - x_item = F.layer_norm( - x_item, - [self.hidden_dim], - self.norm_w2, - self.norm_b2, - self.norm_eps2, - ) - x[i, idx, :] = x_item.squeeze(0) - x = self.to_mask(x, padding_mask) - else: - x = x + attn - x = F.layer_norm(x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1) - x = x + self.mlp.forward(x) - x = F.layer_norm( - x, - [self.hidden_dim], - self.norm_w2, - self.norm_b2, - self.norm_eps2, - ) + x = x + attn + x = F.layer_norm(x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1) + x = x + self.mlp.forward(x) + x = F.layer_norm( + x, + [self.hidden_dim], + self.norm_w2, + self.norm_b2, + self.norm_eps2, + ) return x, k_cache, v_cache def decode_next_token(self, x: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor): diff --git a/GPT_SoVITS/stream_v2pro.py b/GPT_SoVITS/stream_v2pro.py new file mode 100644 index 00000000..7cd4351e --- /dev/null +++ b/GPT_SoVITS/stream_v2pro.py @@ -0,0 +1,611 @@ +# 这是一个实验性质的实现,旨在探索 stream infer 的可能性。(xiao hai xie zhe wan de) +from typing import List +from export_torch_script import ExportERes2NetV2, SSLModel, T2SModel, VitsModel, get_raw_t2s_model, init_sv_cn, resamplex, sample, spectrogram_torch +import export_torch_script +from my_utils import load_audio +import torch +from torch import LongTensor, Tensor, nn +from torch.nn import functional as F + +import soundfile +from inference_webui import get_phones_and_bert +import matplotlib.pyplot as plt + + +class StreamT2SModel(nn.Module): + def __init__(self, t2s: T2SModel): + super(StreamT2SModel, self).__init__() + self.t2s = t2s + + @torch.jit.export + def pre_infer( + self, + prompts: LongTensor, + ref_seq: LongTensor, + text_seq: LongTensor, + ref_bert: torch.Tensor, + text_bert: torch.Tensor, + top_k: int, + ) -> tuple[int, Tensor, Tensor, List[Tensor], List[Tensor]]: + bert = torch.cat([ref_bert.T, text_bert.T], 1) + all_phoneme_ids = torch.cat([ref_seq, text_seq], 1) + bert = bert.unsqueeze(0) 
+ + x = self.t2s.ar_text_embedding(all_phoneme_ids) + x = x + self.t2s.bert_proj(bert.transpose(1, 2)) + x: torch.Tensor = self.t2s.ar_text_position(x) + + # [1,N,512] [1,N] + # y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts) + y = prompts + # x_example = x[:,:,0] * 0.0 + + x_len = x.shape[1] + x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool) + + y_emb = self.t2s.ar_audio_embedding(y) + y_len: int = y_emb.shape[1] + prefix_len = y.shape[1] + y_pos = self.t2s.ar_audio_position(y_emb) + xy_pos = torch.concat([x, y_pos], dim=1) + + bsz = x.shape[0] + src_len = x_len + y_len + x_attn_mask_pad = F.pad( + x_attn_mask, + (0, y_len), ###xx的纯0扩展到xx纯0+xy纯1,(x,x+y) + value=True, + ) + y_attn_mask = F.pad( ###yy的右上1扩展到左边xy的0,(y,x+y) + torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1), + (x_len, 0), + value=False, + ) + xy_attn_mask = ( + torch.concat([x_attn_mask_pad, y_attn_mask], dim=0) + .unsqueeze(0) + .expand(bsz * self.t2s.num_head, -1, -1) + .view(bsz, self.t2s.num_head, src_len, src_len) + .to(device=x.device, dtype=torch.bool) + ) + + xy_dec, k_cache, v_cache = self.t2s.t2s_transformer.process_prompt( + xy_pos, xy_attn_mask, None + ) + + logits = self.t2s.ar_predict_layer(xy_dec[:, -1]) + logits = logits[:, :-1] + samples = sample( + logits, y, top_k=top_k, top_p=1, repetition_penalty=1.35, temperature=1.0 + )[0] + y = torch.concat([y, samples], dim=1) + y_emb: Tensor = self.t2s.ar_audio_embedding(y[:, -1:]) + xy_pos: Tensor = ( + y_emb * self.t2s.ar_audio_position.x_scale + + self.t2s.ar_audio_position.alpha + * self.t2s.ar_audio_position.pe[:, y_len].to( + dtype=y_emb.dtype, device=y_emb.device + ) + ) + + return y_len, y, xy_pos, k_cache, v_cache + + @torch.jit.export + def decode_next_token( + self, + idx: int, # 记住从1开始 到1500 + top_k: int, + y_len: int, + y: Tensor, + xy_pos: Tensor, + k_cache: List[Tensor], + v_cache: List[Tensor], + ) -> tuple[Tensor, Tensor, int, List[Tensor], List[Tensor]]: + # [1, N] [N_layer, N, 
1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N] + # y, k, v, y_emb, logits, samples = self.stage_decoder(y, k, v, y_emb, x_example) + xy_dec, k_cache, v_cache = self.t2s.t2s_transformer.decode_next_token( + xy_pos, k_cache, v_cache + ) + logits = self.t2s.ar_predict_layer(xy_dec[:, -1]) + + if idx < 11: ###至少预测出10个token不然不给停止(0.4s) + logits = logits[:, :-1] + + samples = sample( + logits, y, top_k=top_k, top_p=1, repetition_penalty=1.35, temperature=1.0 + )[0] + + y = torch.concat([y, samples], dim=1) + last_token = int(samples[0, 0]) + + # if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num: + # stop = True + if torch.argmax(logits, dim=-1)[0] == self.t2s.EOS or samples[0, 0] == self.t2s.EOS: + return y[:,:-1], xy_pos, self.t2s.EOS, k_cache, v_cache + + # if stop: + # if y.shape[1] == 0: + # y = torch.concat([y, torch.zeros_like(samples)], dim=1) + # break + + y_emb = self.t2s.ar_audio_embedding(y[:, -1:]) + xy_pos = ( + y_emb * self.t2s.ar_audio_position.x_scale + + self.t2s.ar_audio_position.alpha + * self.t2s.ar_audio_position.pe[:, y_len + idx].to( + dtype=y_emb.dtype, device=y_emb.device + ) + ) + return y, xy_pos, last_token, k_cache, v_cache + + def forward( + self, + idx: int, # 记住从1开始 到1500 + top_k: int, + y_len: int, + y: Tensor, + xy_pos: Tensor, + k_cache: List[Tensor], + v_cache: List[Tensor], + ): + return self.decode_next_token(idx,top_k,y_len,y,xy_pos,k_cache,v_cache) + + +class StepVitsModel(nn.Module): + def __init__(self, vits: VitsModel,sv_model:ExportERes2NetV2): + super().__init__() + self.hps = vits.hps + self.vq_model = vits.vq_model + self.hann_window = vits.hann_window + self.sv = sv_model + + def ref_handle(self, ref_audio_32k): + refer = spectrogram_torch( + self.hann_window, + ref_audio_32k.float(), + self.hps.data.filter_length, + self.hps.data.sampling_rate, + self.hps.data.hop_length, + self.hps.data.win_length, + center=False, + ) + refer = refer.to(ref_audio_32k.dtype) + ref_audio_16k = 
resamplex(ref_audio_32k, 32000, 16000).to(ref_audio_32k.dtype).to(ref_audio_32k.device) + sv_emb = self.sv(ref_audio_16k) + return refer, sv_emb + + def extract_latent(self, ssl_content): + codes = self.vq_model.extract_latent(ssl_content) + return codes[0] + + def forward(self, pred_semantic, text_seq, refer, sv_emb=None): + return self.vq_model( + pred_semantic, text_seq, refer, speed=1.0, sv_emb=sv_emb + )[0, 0] + + +@torch.jit.script +def find_best_audio_offset_fast(reference_audio: Tensor, search_audio: Tensor): + ref_len = len(reference_audio) + search_len = len(search_audio) + + if search_len < ref_len: + raise ValueError( + f"搜索音频长度 ({search_len}) 必须大于等于参考音频长度 ({ref_len})" + ) + + # 使用F.conv1d计算原始互相关 + reference_flipped = reference_audio.unsqueeze(0).unsqueeze(0) + search_padded = search_audio.unsqueeze(0).unsqueeze(0) + + # 计算点积 + dot_products = F.conv1d(search_padded, reference_flipped).squeeze() + + if len(dot_products.shape) == 0: + dot_products = dot_products.unsqueeze(0) + + # 计算参考音频的平方和 + ref_squared_sum = torch.sum(reference_audio**2) + + # 计算搜索音频每个位置的平方和(滑动窗口) + search_squared = search_audio**2 + search_squared_padded = search_squared.unsqueeze(0).unsqueeze(0) + ones_kernel = torch.ones( + 1, 1, ref_len, dtype=search_audio.dtype, device=search_audio.device + ) + + segment_squared_sums = F.conv1d(search_squared_padded, ones_kernel).squeeze() + + if len(segment_squared_sums.shape) == 0: + segment_squared_sums = segment_squared_sums.unsqueeze(0) + + # 计算归一化因子 + ref_norm = torch.sqrt(ref_squared_sum) + segment_norms = torch.sqrt(segment_squared_sums) + + # 避免除零 + epsilon = 1e-8 + normalization_factor = ref_norm * segment_norms + epsilon + + # 归一化互相关 + correlation_scores = dot_products / normalization_factor + + best_offset = torch.argmax(correlation_scores).item() + + return best_offset, correlation_scores + + +import time + +def test_stream( + gpt_path, + vits_path, + version, + ref_audio_path, + ref_text, + output_path, + device="cpu", + 
is_half=True, +): + if export_torch_script.sv_cn_model == None: + init_sv_cn(device,is_half) + + ref_audio = torch.tensor([load_audio(ref_audio_path, 16000)]).float() + ssl = SSLModel() + + print(f"device: {device}") + + ref_seq_id, ref_bert_T, ref_norm_text = get_phones_and_bert( + ref_text, "all_zh", "v2" + ) + ref_seq = torch.LongTensor([ref_seq_id]).to(device) + ref_bert = ref_bert_T.T + if is_half: + ref_bert = ref_bert.half() + ref_bert = ref_bert.to(ref_seq.device) + + text_seq_id, text_bert_T, norm_text = get_phones_and_bert( + "这是一个简单的示例,真没想到这么简单就完成了,真的神奇,接下来我们说说狐狸,可能这就是狐狸吧.它有长长的尾巴,尖尖的耳朵,传说中还有九条尾巴。你觉得狐狸神奇吗?", "auto", "v2" + ) + text_seq = torch.LongTensor([text_seq_id]).to(device) + text_bert = text_bert_T.T + if is_half: + text_bert = text_bert.half() + text_bert = text_bert.to(text_seq.device) + + ssl_content = ssl(ref_audio) + if is_half: + ssl_content = ssl_content.half() + ssl_content = ssl_content.to(device) + + sv_model = ExportERes2NetV2(export_torch_script.sv_cn_model) + + # vits_path = "SoVITS_weights_v2/xw_e8_s216.pth" + vits = VitsModel(vits_path, version,is_half=is_half,device=device) + vits.eval() + + # gpt_path = "GPT_weights_v2/xw-e15.ckpt" + # dict_s1 = torch.load(gpt_path, map_location=device) + dict_s1 = torch.load(gpt_path, weights_only=False) + raw_t2s = get_raw_t2s_model(dict_s1).to(device) + print("#### get_raw_t2s_model ####") + print(raw_t2s.config) + if is_half: + raw_t2s = raw_t2s.half() + t2s_m = T2SModel(raw_t2s) + t2s_m.eval() + # t2s = torch.jit.script(t2s_m).to(device) + t2s = t2s_m + print("#### script t2s_m ####") + + print("vits.hps.data.sampling_rate:", vits.hps.data.sampling_rate) + + stream_t2s = StreamT2SModel(t2s).to(device) + stream_t2s = torch.jit.script(stream_t2s) + + ref_audio_sr = resamplex(ref_audio, 16000, 32000) + if is_half: + ref_audio_sr = ref_audio_sr.half() + ref_audio_sr = ref_audio_sr.to(device) + + top_k = 15 + + codes = vits.vq_model.extract_latent(ssl_content) + prompt_semantic = codes[0, 0] + 
prompts = prompt_semantic.unsqueeze(0) + + audio_16k = resamplex(ref_audio_sr, 32000, 16000).to(ref_audio_sr.dtype) + sv_emb = sv_model(audio_16k) + print("text_seq",text_seq.shape) + + refer = spectrogram_torch( + vits.hann_window, + ref_audio_sr, + vits.hps.data.filter_length, + vits.hps.data.sampling_rate, + vits.hps.data.hop_length, + vits.hps.data.win_length, + center=False, + ) + + st = time.time() + et = time.time() + + y_len, y, xy_pos, k_cache, v_cache = stream_t2s.pre_infer(prompts, ref_seq, text_seq, ref_bert, text_bert, top_k) + idx = 1 + last_idx = 0 + audios = [] + raw_audios = [] + last_audio_ret = None + offset_index = [] + full_audios = [] + print("y.shape:", y.shape) + cut_id = 0 + while True: + y, xy_pos, last_token, k_cache, v_cache = stream_t2s(idx, top_k, y_len, y, xy_pos, k_cache, v_cache) + # print("y.shape:", y.shape) + stop = last_token==t2s.EOS + print('idx:',idx , 'y.shape:', y.shape, y.shape[1]-idx) + + if last_token < 50 and idx-last_idx > (len(audios)+1) * 25 and idx > cut_id: + cut_id = idx + 7 + print('trigger:',idx, last_idx, y[:,-idx+last_idx:], y[:,-idx+last_idx:].shape) + # y = torch.cat([y, y[:,-1:]], dim=1) + # idx+=1 + + if stop : + idx -=1 + print('stop') + print(idx, y[:,-idx+last_idx:]) + print(idx,last_idx, y.shape) + print(y[:,-idx:-idx+20]) + + + # 玄学这档子事说不清楚 + if idx == cut_id or stop: + print(f"idx: {idx}, last_idx: {last_idx}, cut_id: {cut_id}, stop: {stop}") + audio = vits.vq_model(y[:,-idx:].unsqueeze(0), text_seq, refer, speed=1.0, sv_emb=sv_emb)[0, 0] + full_audios.append(audio) + if last_idx == 0: + last_audio_ret = audio[-1280*8:-1280*8+256] + audio = audio[:-1280*8] + raw_audios.append(audio) + et = time.time() + else: + if stop: + audio_ = audio[last_idx*1280 -1280*8:] + raw_audios.append(audio_) + i, x = find_best_audio_offset_fast(last_audio_ret, audio_[:1280]) + offset_index.append(i) + audio = audio_[i:] + else: + audio_ = audio[last_idx*1280 -1280*8:-1280*8] + raw_audios.append(audio_) + i, x = 
find_best_audio_offset_fast(last_audio_ret, audio_[:1280]) + offset_index.append(i) + last_audio_ret = audio[-1280*8:-1280*8+256] + audio = audio_[i:] + last_idx = idx + # print(f'write {output_path}/out_{audio_index}') + # soundfile.write(f"{output_path}/out_{audio_index}.wav", audio.float().detach().cpu().numpy(), 32000) + audios.append(audio) + # print(idx,'/',1500 , y.shape, y[0,-1].item(), stop) + if idx>1500: + break + + if stop: + break + + idx+=1 + + at = time.time() + + for (i,a) in enumerate(audios): + print(f'write {output_path}/out_{i}') + soundfile.write(f"{output_path}/out_{i}.wav", a.float().detach().cpu().numpy(), 32000) + + print(f"frist token: {et - st:.4f} seconds") + print(f"all token: {at - st:.4f} seconds") + audio = vits.vq_model(y[:,-idx:].unsqueeze(0), text_seq, refer, speed=1.0, sv_emb=sv_emb)[0, 0] + soundfile.write(f"{output_path}/out_final.wav", audio.float().detach().cpu().numpy(), 32000) + audio = torch.cat(audios, dim=0) + soundfile.write(f"{output_path}/out.wav", audio.float().detach().cpu().numpy(), 32000) + audio_raw = torch.cat(raw_audios, dim=0) + soundfile.write(f"{output_path}/out.raw.wav", audio_raw.float().detach().cpu().numpy(), 32000) + + + colors = ['red', 'green', 'blue', 'orange', 'purple', 'cyan', 'magenta', 'yellow'] + + max_duration = full_audios[-1].shape[0] + plt.xlim(0, max_duration) + + last_line = 0 + + for i,a in enumerate(full_audios): + plt.plot((a+2.0*i).float().detach().cpu().numpy(), color=colors[i], alpha=0.5, label=f"Audio {i}") + # plt.axvline(x=last_line, color=colors[i], linestyle='--') + last_line = a.shape[0]-8*1280 + plt.axvline(x=last_line, color=colors[i], linestyle='--') + + plt.plot((audio-2.0).float().detach().cpu().numpy(), color='black', label='Final Audio') + + plt.plot((audio_raw-4.0).float().detach().cpu().numpy(), color='cyan', label='Raw Audio') + + print("offset_index:", offset_index) + plt.show() + + +def export_prov2( + gpt_path, + vits_path, + version, + ref_audio_path, + ref_text, 
+ output_path, + device="cpu", + is_half=True, + lang="auto", +): + if export_torch_script.sv_cn_model == None: + init_sv_cn(device,is_half) + + ref_audio = torch.tensor([load_audio(ref_audio_path, 16000)]).float() + ssl = SSLModel() + + print(f"device: {device}") + + ref_seq_id, ref_bert_T, ref_norm_text = get_phones_and_bert( + ref_text, lang, "v2" + ) + ref_seq = torch.LongTensor([ref_seq_id]).to(device) + ref_bert = ref_bert_T.T + if is_half: + ref_bert = ref_bert.half() + ref_bert = ref_bert.to(ref_seq.device) + + text_seq_id, text_bert_T, norm_text = get_phones_and_bert( + "这是一个简单的示例,真没想到这么简单就完成了.The King and His Stories.Once there was a king.He likes to write stories, but his stories were not good.", "auto", "v2" + ) + text_seq = torch.LongTensor([text_seq_id]).to(device) + text_bert = text_bert_T.T + if is_half: + text_bert = text_bert.half() + text_bert = text_bert.to(text_seq.device) + + ssl_content = ssl(ref_audio) + if is_half: + ssl_content = ssl_content.half() + ssl_content = ssl_content.to(device) + + sv_model = ExportERes2NetV2(export_torch_script.sv_cn_model) + + # vits_path = "SoVITS_weights_v2/xw_e8_s216.pth" + vits = VitsModel(vits_path, version,is_half=is_half,device=device) + vits.eval() + vits = StepVitsModel(vits, sv_model) + + # gpt_path = "GPT_weights_v2/xw-e15.ckpt" + # dict_s1 = torch.load(gpt_path, map_location=device) + dict_s1 = torch.load(gpt_path, weights_only=False) + raw_t2s = get_raw_t2s_model(dict_s1).to(device) + print("#### get_raw_t2s_model ####") + print(raw_t2s.config) + if is_half: + raw_t2s = raw_t2s.half() + t2s_m = T2SModel(raw_t2s) + t2s_m.eval() + # t2s = torch.jit.script(t2s_m).to(device) + t2s = t2s_m + print("#### script t2s_m ####") + + print("vits.hps.data.sampling_rate:", vits.hps.data.sampling_rate) + + stream_t2s = StreamT2SModel(t2s).to(device) + stream_t2s = torch.jit.script(stream_t2s) + + ref_audio_sr = resamplex(ref_audio, 16000, 32000) + ref_audio_sr = ref_audio_sr.to(device) + if is_half: + ref_audio_sr 
= ref_audio_sr.half() + + top_k = 15 + + prompts = vits.extract_latent(ssl_content) + + audio_16k = resamplex(ref_audio_sr, 32000, 16000).to(ref_audio_sr.dtype) + sv_emb = sv_model(audio_16k) + print("text_seq",text_seq.shape) + # torch.jit.trace() + + refer,sv_emb = vits.ref_handle(ref_audio_sr) + + st = time.time() + et = time.time() + + y_len, y, xy_pos, k_cache, v_cache = stream_t2s.pre_infer(prompts, ref_seq, text_seq, ref_bert, text_bert, top_k) + idx = 1 + print("y.shape:", y.shape) + while True: + y, xy_pos, last_token, k_cache, v_cache = stream_t2s(idx, top_k, y_len, y, xy_pos, k_cache, v_cache) + # print("y.shape:", y.shape) + + idx+=1 + # print(idx,'/',1500 , y.shape, y[0,-1].item(), stop) + if idx>1500: + break + + if last_token == t2s.EOS: + break + + at = time.time() + print("EOS:",t2s.EOS) + + print(f"frist token: {et - st:.4f} seconds") + print(f"all token: {at - st:.4f} seconds") + print("sv_emb", sv_emb.shape) + print("refer",refer.shape) + y = y[:,-idx:].unsqueeze(0) + print("y", y.shape) + audio = vits(y, text_seq, refer, sv_emb) + soundfile.write(f"{output_path}/out_final.wav", audio.float().detach().cpu().numpy(), 32000) + + torch._dynamo.mark_dynamic(ssl_content, 2) + torch._dynamo.mark_dynamic(ref_audio_sr, 1) + torch._dynamo.mark_dynamic(ref_seq, 1) + torch._dynamo.mark_dynamic(text_seq, 1) + torch._dynamo.mark_dynamic(ref_bert, 0) + torch._dynamo.mark_dynamic(text_bert, 0) + torch._dynamo.mark_dynamic(refer, 2) + torch._dynamo.mark_dynamic(y, 2) + + inputs = { + "forward": (y, text_seq, refer, sv_emb), + "extract_latent": ssl_content, + "ref_handle": ref_audio_sr, + } + + stream_t2s.save(f"{output_path}/t2s.pt") + torch.jit.trace_module(vits, inputs=inputs, optimize=True).save(f"{output_path}/vits.pt") + torch.jit.script(find_best_audio_offset_fast, optimize=True).save(f"{output_path}/find_best_audio_offset_fast.pt") + +import argparse +import os + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="GPT-SoVITS 
Command Line Tool") + parser.add_argument("--gpt_model", required=True, help="Path to the GPT model file") + parser.add_argument( + "--sovits_model", required=True, help="Path to the SoVITS model file" + ) + parser.add_argument( + "--ref_audio", required=True, help="Path to the reference audio file" + ) + parser.add_argument( + "--ref_text", required=True, help="Path to the reference text file" + ) + parser.add_argument( + "--output_path", required=True, help="Path to the output directory" + ) + parser.add_argument("--device", help="Device to use", default="cuda" if torch.cuda.is_available() else "cpu") + parser.add_argument("--version", help="version of the model", default="v2Pro") + parser.add_argument("--no-half", action="store_true", help = "Do not use half precision for model weights") + parser.add_argument("--lang", default="auto", help="Language for text processing (default: auto)") + + args = parser.parse_args() + + if not os.path.exists(args.output_path): + os.makedirs(args.output_path) + + is_half = not args.no_half + with torch.no_grad(): + export_prov2( + gpt_path=args.gpt_model, + vits_path=args.sovits_model, + version=args.version, + ref_audio_path=args.ref_audio, + ref_text=args.ref_text, + output_path=args.output_path, + device=args.device, + is_half=is_half, + lang=args.lang, + ) From 60a4a214aff18057bb4ce76643d3b85de4bb67a4 Mon Sep 17 00:00:00 2001 From: wzy3650 <48899243+wzy3650@users.noreply.github.com> Date: Fri, 28 Nov 2025 21:57:13 +0800 Subject: [PATCH 06/20] vq distributed training support (#2577) Co-authored-by: wangzeyuan --- GPT_SoVITS/module/core_vq.py | 93 ++++++++++++----- GPT_SoVITS/module/ddp_utils.py | 181 +++++++++++++++++++++++++++++++++ GPT_SoVITS/module/distrib.py | 123 ++++++++++++++++++++++ 3 files changed, 373 insertions(+), 24 deletions(-) create mode 100644 GPT_SoVITS/module/ddp_utils.py create mode 100644 GPT_SoVITS/module/distrib.py diff --git a/GPT_SoVITS/module/core_vq.py b/GPT_SoVITS/module/core_vq.py index 
b7dab317..40745386 100644 --- a/GPT_SoVITS/module/core_vq.py +++ b/GPT_SoVITS/module/core_vq.py @@ -37,6 +37,10 @@ from einops import rearrange, repeat import torch from torch import nn import torch.nn.functional as F +import torch.distributed as dist + +from module.distrib import broadcast_tensors, is_distributed +from module.ddp_utils import SyncFunction from tqdm import tqdm @@ -69,27 +73,40 @@ def sample_vectors(samples, num: int): return samples[indices] -def kmeans(samples, num_clusters: int, num_iters: int = 10): - dim, dtype = samples.shape[-1], samples.dtype - max_kmeans_samples = 500 - samples = samples[:max_kmeans_samples, :] +def kmeans(samples, num_clusters: int, num_iters: int = 10, frames_to_use: int = 10_000, batch_size: int = 64): + N, D = samples.shape + dtype, device = samples.dtype, samples.device + + if frames_to_use < N: + indices = torch.randperm(N, device=device)[:frames_to_use] + samples = samples[indices] + means = sample_vectors(samples, num_clusters) print("kmeans start ... 
") for _ in tqdm(range(num_iters)): - diffs = rearrange(samples, "n d -> n () d") - rearrange(means, "c d -> () c d") - dists = -(diffs**2).sum(dim=-1) + # Store cluster assignments + all_assignments = [] - buckets = dists.max(dim=-1).indices + for i in range(0, samples.shape[0], batch_size): + batch = samples[i : i + batch_size] # [B, D] + dists = torch.cdist(batch, means, p=2) # [B, C] + assignments = dists.argmin(dim=1) # [B] + all_assignments.append(assignments) + + buckets = torch.cat(all_assignments, dim=0) # [N] bins = torch.bincount(buckets, minlength=num_clusters) zero_mask = bins == 0 bins_min_clamped = bins.masked_fill(zero_mask, 1) - new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype) - new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples) - new_means = new_means / bins_min_clamped[..., None] + # Compute new means + new_means = torch.zeros_like(means) + for i in range(num_clusters): + mask = buckets == i + if mask.any(): + new_means[i] = samples[mask].mean(dim=0) - means = torch.where(zero_mask[..., None], means, new_means) + means = torch.where(zero_mask[:, None], means, new_means) return means, bins @@ -141,13 +158,24 @@ class EuclideanCodebook(nn.Module): if self.inited: return - embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters) + if dist.is_available() and dist.is_initialized(): + # [B * T * world_size, D] + data = SyncFunction.apply(data) + + if dist.get_rank() == 0: + embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters) + else: + embed = torch.empty_like(self.embed) + cluster_size = torch.empty_like(self.cluster_size) + dist.broadcast(embed, src=0) + dist.broadcast(cluster_size, src=0) + self.embed.data.copy_(embed) self.embed_avg.data.copy_(embed.clone()) self.cluster_size.data.copy_(cluster_size) self.inited.data.copy_(torch.Tensor([True])) # Make sure all buffers across workers are in sync after initialization - # broadcast_tensors(self.buffers()) + 
broadcast_tensors(self.buffers()) def replace_(self, samples, mask): modified_codebook = torch.where(mask[..., None], sample_vectors(samples, self.codebook_size), self.embed) @@ -161,9 +189,17 @@ class EuclideanCodebook(nn.Module): if not torch.any(expired_codes): return - batch_samples = rearrange(batch_samples, "... d -> (...) d") - self.replace_(batch_samples, mask=expired_codes) - # broadcast_tensors(self.buffers()) + if is_distributed(): + # [B * T * world_size, D] + batch_samples = SyncFunction.apply(batch_samples) + + if dist.get_rank() == 0: + new_embeds = sample_vectors(batch_samples, expired_codes.sum()) + else: + new_embeds = torch.zeros(expired_codes.sum(), self.embed.size(1), device=self.embed.device) + dist.broadcast(new_embeds, src=0) + self.embed.data[expired_codes] = new_embeds + broadcast_tensors(self.buffers()) def preprocess(self, x): x = rearrange(x, "... d -> (...) d") @@ -208,17 +244,26 @@ class EuclideanCodebook(nn.Module): quantize = self.dequantize(embed_ind) if self.training: + ### Update codebook by EMA + embed_onehot_sum = embed_onehot.sum(0) # [cb-size,] + embed_sum = x.t() @ embed_onehot # [D, cb-size] + if is_distributed(): + dist.all_reduce(embed_onehot_sum) + dist.all_reduce(embed_sum) + # Update ema cluster count N_i^t, eq. (6) in vqvae paper + self.cluster_size.data.mul_(self.decay).add_(embed_onehot_sum, alpha=1 - self.decay) + # Update ema embed: eq. (7) in vqvae paper + self.embed_avg.data.mul_(self.decay).add_(embed_sum.t(), alpha=1 - self.decay) + # apply laplace smoothing + n = self.cluster_size.sum() + cluster_size = (self.cluster_size + self.epsilon) / (n + self.codebook_size * self.epsilon) * n + # Update ema embed: eq. (8) in vqvae paper + embed_normalized = self.embed_avg / cluster_size.unsqueeze(1) + self.embed.data.copy_(embed_normalized) + # We do the expiry of code at that point as buffers are in sync # and all the workers will take the same decision. 
self.expire_codes_(x) - ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay) - embed_sum = x.t() @ embed_onehot - ema_inplace(self.embed_avg, embed_sum.t(), self.decay) - cluster_size = ( - laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon) * self.cluster_size.sum() - ) - embed_normalized = self.embed_avg / cluster_size.unsqueeze(1) - self.embed.data.copy_(embed_normalized) return quantize, embed_ind diff --git a/GPT_SoVITS/module/ddp_utils.py b/GPT_SoVITS/module/ddp_utils.py new file mode 100644 index 00000000..af30dd3f --- /dev/null +++ b/GPT_SoVITS/module/ddp_utils.py @@ -0,0 +1,181 @@ +import torch +from torch.nn.parallel import DistributedDataParallel +from torch.nn.parallel.distributed import _find_tensors +from packaging import version + + +# from https://github.com/Lightning-AI/lightning-bolts/blob/5d61197cd2f491f69e238137a5edabe80ae14ad9/pl_bolts/models/self_supervised/simclr/simclr_module.py#L20 +class SyncFunction(torch.autograd.Function): + @staticmethod + # @torch.no_grad() + def forward(ctx, tensor): + world_size = torch.distributed.get_world_size() + + # Collect batch sizes from all processes + local_bs = torch.tensor([tensor.shape[0]], device=tensor.device) + batch_sizes = [torch.zeros_like(local_bs) for _ in range(world_size)] + torch.distributed.all_gather(batch_sizes, local_bs) + + # Convert to integer list and find the minimum + batch_sizes_int = [bs.item() for bs in batch_sizes] + min_bs = min(batch_sizes_int) + + # Crop the tensor to the minimum batch size if needed + cropped_tensor = tensor[:min_bs] if tensor.shape[0] > min_bs else tensor + + # Prepare for gathering + out_shape = (min_bs * world_size,) + tensor.shape[1:] + gathered_tensor = torch.zeros(out_shape, dtype=tensor.dtype, device=tensor.device) + + # Build tensor list for all_gather + tensor_list = list(torch.chunk(gathered_tensor, world_size)) + + # Perform all_gather using the cropped tensors + torch.distributed.all_gather(tensor_list, 
cropped_tensor) + + # Save for backward pass + ctx.min_bs = min_bs + ctx.world_size = world_size + ctx.orig_shape = tensor.shape + + return gathered_tensor + + @staticmethod + def backward(ctx, grad_output): + assert False + grad_input = grad_output.clone() + torch.distributed.all_reduce(grad_input, op=torch.distributed.ReduceOp.SUM, async_op=False) + + idx_from = torch.distributed.get_rank() * ctx.batch_size + idx_to = (torch.distributed.get_rank() + 1) * ctx.batch_size + return grad_input[idx_from:idx_to] + +class DDP(DistributedDataParallel): + """ + Override the forward call in lightning so it goes to training and validation step respectively + """ + + def forward(self, *inputs, **kwargs): # pragma: no cover + if version.parse(torch.__version__[:6]) < version.parse("1.11"): + self._sync_params() + inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids) + assert len(self.device_ids) == 1 + if self.module.training: + output = self.module.training_step(*inputs[0], **kwargs[0]) + elif self.module.testing: + output = self.module.test_step(*inputs[0], **kwargs[0]) + else: + output = self.module.validation_step(*inputs[0], **kwargs[0]) + if torch.is_grad_enabled(): + # We'll return the output object verbatim since it is a freeform + # object. We need to find any tensors in this object, though, + # because we need to figure out which parameters were used during + # this forward pass, to ensure we short circuit reduction for any + # unused parameters. Only if `find_unused_parameters` is set. 
+ if self.find_unused_parameters: + self.reducer.prepare_for_backward(list(_find_tensors(output))) + else: + self.reducer.prepare_for_backward([]) + else: + from torch.nn.parallel.distributed import ( + Join, + _DDPSink, + _tree_flatten_with_rref, + _tree_unflatten_with_rref, + ) + + with torch.autograd.profiler.record_function("DistributedDataParallel.forward"): + if torch.is_grad_enabled() and self.require_backward_grad_sync: + self.logger.set_runtime_stats_and_log() + self.num_iterations += 1 + self.reducer.prepare_for_forward() + + # Notify the join context that this process has not joined, if + # needed + work = Join.notify_join_context(self) + if work: + self.reducer._set_forward_pass_work_handle(work, self._divide_by_initial_world_size) + + # Calling _rebuild_buckets before forward compuation, + # It may allocate new buckets before deallocating old buckets + # inside _rebuild_buckets. To save peak memory usage, + # call _rebuild_buckets before the peak memory usage increases + # during forward computation. + # This should be called only once during whole training period. + if torch.is_grad_enabled() and self.reducer._rebuild_buckets(): + print("Reducer buckets have been rebuilt in this iteration.") + self._has_rebuilt_buckets = True + + # sync params according to location (before/after forward) user + # specified as part of hook, if hook was specified. + buffer_hook_registered = hasattr(self, "buffer_hook") + if self._check_sync_bufs_pre_fwd(): + self._sync_buffers() + + if self._join_config.enable: + # Notify joined ranks whether they should sync in backwards pass or not. 
+ self._check_global_requires_backward_grad_sync(is_joined_rank=False) + + inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids) + if self.module.training: + output = self.module.training_step(*inputs[0], **kwargs[0]) + elif self.module.testing: + output = self.module.test_step(*inputs[0], **kwargs[0]) + else: + output = self.module.validation_step(*inputs[0], **kwargs[0]) + + # sync params according to location (before/after forward) user + # specified as part of hook, if hook was specified. + if self._check_sync_bufs_post_fwd(): + self._sync_buffers() + + if torch.is_grad_enabled() and self.require_backward_grad_sync: + self.require_forward_param_sync = True + # We'll return the output object verbatim since it is a freeform + # object. We need to find any tensors in this object, though, + # because we need to figure out which parameters were used during + # this forward pass, to ensure we short circuit reduction for any + # unused parameters. Only if `find_unused_parameters` is set. + if self.find_unused_parameters and not self.static_graph: + # Do not need to populate this for static graph. + self.reducer.prepare_for_backward(list(_find_tensors(output))) + else: + self.reducer.prepare_for_backward([]) + else: + self.require_forward_param_sync = False + + # TODO: DDPSink is currently enabled for unused parameter detection and + # static graph training for first iteration. 
+ if (self.find_unused_parameters and not self.static_graph) or ( + self.static_graph and self.num_iterations == 1 + ): + state_dict = { + "static_graph": self.static_graph, + "num_iterations": self.num_iterations, + } + + output_tensor_list, treespec, output_is_rref = _tree_flatten_with_rref(output) + output_placeholders = [None for _ in range(len(output_tensor_list))] + # Do not touch tensors that have no grad_fn, which can cause issues + # such as https://github.com/pytorch/pytorch/issues/60733 + for i, output in enumerate(output_tensor_list): + if torch.is_tensor(output) and output.grad_fn is None: + output_placeholders[i] = output + + # When find_unused_parameters=True, makes tensors which require grad + # run through the DDPSink backward pass. When not all outputs are + # used in loss, this makes those corresponding tensors receive + # undefined gradient which the reducer then handles to ensure + # param.grad field is not touched and we don't error out. + passthrough_tensor_list = _DDPSink.apply( + self.reducer, + state_dict, + *output_tensor_list, + ) + for i in range(len(output_placeholders)): + if output_placeholders[i] is None: + output_placeholders[i] = passthrough_tensor_list[i] + + # Reconstruct output data structure. + output = _tree_unflatten_with_rref(output_placeholders, treespec, output_is_rref) + return output diff --git a/GPT_SoVITS/module/distrib.py b/GPT_SoVITS/module/distrib.py new file mode 100644 index 00000000..cabf8f8a --- /dev/null +++ b/GPT_SoVITS/module/distrib.py @@ -0,0 +1,123 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +"""Torch distributed utilities.""" + +import typing as tp + +import torch + + +def rank(): + if torch.distributed.is_initialized(): + return torch.distributed.get_rank() + else: + return 0 + + +def world_size(): + if torch.distributed.is_initialized(): + return torch.distributed.get_world_size() + else: + return 1 + + +def is_distributed(): + return world_size() > 1 + + +def all_reduce(tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM): + if is_distributed(): + return torch.distributed.all_reduce(tensor, op) + + +def _is_complex_or_float(tensor): + return torch.is_floating_point(tensor) or torch.is_complex(tensor) + + +def _check_number_of_params(params: tp.List[torch.Tensor]): + # utility function to check that the number of params in all workers is the same, + # and thus avoid a deadlock with distributed all reduce. + if not is_distributed() or not params: + return + # print('params[0].device ', params[0].device) + tensor = torch.tensor([len(params)], device=params[0].device, dtype=torch.long) + all_reduce(tensor) + if tensor.item() != len(params) * world_size(): + # If not all the workers have the same number, for at least one of them, + # this inequality will be verified. + raise RuntimeError( + f"Mismatch in number of params: ours is {len(params)}, at least one worker has a different one." + ) + + +def broadcast_tensors(tensors: tp.Iterable[torch.Tensor], src: int = 0): + """Broadcast the tensors from the given parameters to all workers. + This can be used to ensure that all workers have the same model to start with. + """ + if not is_distributed(): + return + tensors = [tensor for tensor in tensors if _is_complex_or_float(tensor)] + _check_number_of_params(tensors) + handles = [] + for tensor in tensors: + handle = torch.distributed.broadcast(tensor.data, src=src, async_op=True) + handles.append(handle) + for handle in handles: + handle.wait() + + +def sync_buffer(buffers, average=True): + """ + Sync grad for buffers. 
If average is False, broadcast instead of averaging. + """ + if not is_distributed(): + return + handles = [] + for buffer in buffers: + if torch.is_floating_point(buffer.data): + if average: + handle = torch.distributed.all_reduce(buffer.data, op=torch.distributed.ReduceOp.SUM, async_op=True) + else: + handle = torch.distributed.broadcast(buffer.data, src=0, async_op=True) + handles.append((buffer, handle)) + for buffer, handle in handles: + handle.wait() + if average: + buffer.data /= world_size + + +def sync_grad(params): + """ + Simpler alternative to DistributedDataParallel, that doesn't rely + on any black magic. For simple models it can also be as fast. + Just call this on your model parameters after the call to backward! + """ + if not is_distributed(): + return + handles = [] + for p in params: + if p.grad is not None: + handle = torch.distributed.all_reduce(p.grad.data, op=torch.distributed.ReduceOp.SUM, async_op=True) + handles.append((p, handle)) + for p, handle in handles: + handle.wait() + p.grad.data /= world_size() + + +def average_metrics(metrics: tp.Dict[str, float], count=1.0): + """Average a dictionary of metrics across all workers, using the optional + `count` as unormalized weight. 
+ """ + if not is_distributed(): + return metrics + keys, values = zip(*metrics.items()) + device = "cuda" if torch.cuda.is_available() else "cpu" + tensor = torch.tensor(list(values) + [1], device=device, dtype=torch.float32) + tensor *= count + all_reduce(tensor) + averaged = (tensor[:-1] / tensor[-1]).cpu().tolist() + return dict(zip(keys, averaged)) From cb00840c4e06ba6cc27e811d3319461fcc4e2bd1 Mon Sep 17 00:00:00 2001 From: RVC-Boss <129054828+RVC-Boss@users.noreply.github.com> Date: Fri, 28 Nov 2025 22:02:03 +0800 Subject: [PATCH 07/20] Add files via upload --- GPT_SoVITS/s2_train.py | 2 +- GPT_SoVITS/s2_train_v3.py | 4 ++-- GPT_SoVITS/s2_train_v3_lora.py | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/GPT_SoVITS/s2_train.py b/GPT_SoVITS/s2_train.py index 4b9f6488..333e6a05 100644 --- a/GPT_SoVITS/s2_train.py +++ b/GPT_SoVITS/s2_train.py @@ -124,7 +124,7 @@ def run(rank, n_gpus, hps): collate_fn=collate_fn, batch_sampler=train_sampler, persistent_workers=True, - prefetch_factor=4, + prefetch_factor=3, ) # if rank == 0: # eval_dataset = TextAudioSpeakerLoader(hps.data.validation_files, hps.data, val=True) diff --git a/GPT_SoVITS/s2_train_v3.py b/GPT_SoVITS/s2_train_v3.py index aa8dae7f..bcde98a8 100644 --- a/GPT_SoVITS/s2_train_v3.py +++ b/GPT_SoVITS/s2_train_v3.py @@ -118,13 +118,13 @@ def run(rank, n_gpus, hps): collate_fn = TextAudioSpeakerCollate() train_loader = DataLoader( train_dataset, - num_workers=6, + num_workers=5, shuffle=False, pin_memory=True, collate_fn=collate_fn, batch_sampler=train_sampler, persistent_workers=True, - prefetch_factor=4, + prefetch_factor=3, ) # if rank == 0: # eval_dataset = TextAudioSpeakerLoader(hps.data.validation_files, hps.data, val=True) diff --git a/GPT_SoVITS/s2_train_v3_lora.py b/GPT_SoVITS/s2_train_v3_lora.py index ba9e4ed4..ff62ccfe 100644 --- a/GPT_SoVITS/s2_train_v3_lora.py +++ b/GPT_SoVITS/s2_train_v3_lora.py @@ -120,13 +120,13 @@ def run(rank, n_gpus, hps): collate_fn = 
TextAudioSpeakerCollate() train_loader = DataLoader( train_dataset, - num_workers=6, + num_workers=5, shuffle=False, pin_memory=True, collate_fn=collate_fn, batch_sampler=train_sampler, persistent_workers=True, - prefetch_factor=4, + prefetch_factor=3, ) save_root = "%s/logs_s2_%s_lora_%s" % (hps.data.exp_dir, hps.model.version, hps.train.lora_rank) os.makedirs(save_root, exist_ok=True) From c85c54eca99a2fd01d6b574584217d0ecfbd90c1 Mon Sep 17 00:00:00 2001 From: XXXXRT666 <157766680+XXXXRT666@users.noreply.github.com> Date: Fri, 28 Nov 2025 14:10:49 +0000 Subject: [PATCH 08/20] Add ModelScope Snapshot Download For ASR (#2627) * Add ModelScope Snapshot Download For ASR * Typo Fix * Remove YUE in whisper * Remove HF ENDPOINT * Add FunASR Download --- requirements.txt | 4 +- tools/asr/config.py | 31 +++----------- tools/asr/fasterwhisper_asr.py | 76 ++++++++++++++++++++++------------ tools/asr/funasr_asr.py | 46 ++++++++++---------- tools/i18n/locale/en_US.json | 4 +- webui.py | 2 - 6 files changed, 81 insertions(+), 82 deletions(-) diff --git a/requirements.txt b/requirements.txt index 90e4957d..578bb87c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -16,7 +16,7 @@ pypinyin pyopenjtalk>=0.4.1 g2p_en torchaudio -modelscope==1.10.0 +modelscope sentencepiece transformers>=4.43,<=4.50 peft @@ -39,7 +39,5 @@ x_transformers torchmetrics<=1.5 pydantic<=2.10.6 ctranslate2>=4.0,<5 -huggingface_hub>=0.13 -tokenizers>=0.13,<1 av>=11 tqdm diff --git a/tools/asr/config.py b/tools/asr/config.py index 9c26a4f6..fdff7518 100644 --- a/tools/asr/config.py +++ b/tools/asr/config.py @@ -1,34 +1,13 @@ -import os - - -def check_fw_local_models(): - """ - 启动时检查本地是否有 Faster Whisper 模型. 
- """ - model_size_list = [ - "medium", - "medium.en", - "distil-large-v2", - "distil-large-v3", - "large-v1", - "large-v2", - "large-v3", - ] - for i, size in enumerate(model_size_list): - if os.path.exists(f"tools/asr/models/faster-whisper-{size}"): - model_size_list[i] = size + "-local" - return model_size_list - - def get_models(): model_size_list = [ "medium", "medium.en", - "distil-large-v2", - "distil-large-v3", - "large-v1", "large-v2", "large-v3", + "large-v3-turbo", + "distil-large-v2", + "distil-large-v3", + "distil-large-v3.5", ] return model_size_list @@ -36,7 +15,7 @@ def get_models(): asr_dict = { "达摩 ASR (中文)": {"lang": ["zh", "yue"], "size": ["large"], "path": "funasr_asr.py", "precision": ["float32"]}, "Faster Whisper (多语种)": { - "lang": ["auto", "zh", "en", "ja", "ko", "yue"], + "lang": ["auto", "en", "ja", "ko"], "size": get_models(), "path": "fasterwhisper_asr.py", "precision": ["float32", "float16", "int8"], diff --git a/tools/asr/fasterwhisper_asr.py b/tools/asr/fasterwhisper_asr.py index a2ebe975..72a4b82a 100644 --- a/tools/asr/fasterwhisper_asr.py +++ b/tools/asr/fasterwhisper_asr.py @@ -1,12 +1,12 @@ import argparse import os -import time import traceback +import requests import torch from faster_whisper import WhisperModel -from huggingface_hub import snapshot_download -from huggingface_hub.errors import LocalEntryNotFoundError +from huggingface_hub import snapshot_download as snapshot_download_hf +from modelscope import snapshot_download as snapshot_download_ms from tqdm import tqdm from tools.asr.config import get_models @@ -40,11 +40,35 @@ language_code_list = [ def download_model(model_size: str): - if "distil" in model_size: - repo_id = "Systran/faster-{}-whisper-{}".format(*model_size.split("-", maxsplit=1)) + url = "https://huggingface.co/api/models/gpt2" + try: + requests.get(url, timeout=3) + source = "HF" + except Exception: + source = "ModelScope" + + model_path = "" + if source == "HF": + if "distil" in model_size: + if "3.5" 
in model_size: + repo_id = "distil-whisper/distil-large-v3.5-ct2" + model_path = "tools/asr/models/faster-whisper-distil-large-v3.5" + else: + repo_id = "Systran/faster-{}-whisper-{}".format(*model_size.split("-", maxsplit=1)) + elif model_size == "large-v3-turbo": + repo_id = "mobiuslabsgmbh/faster-whisper-large-v3-turbo" + model_path = "tools/asr/models/faster-whisper-large-v3-turbo" + else: + repo_id = f"Systran/faster-whisper-{model_size}" + model_path = ( + model_path + or f"tools/asr/models/{repo_id.replace('Systran/', '').replace('distil-whisper/', '', 1)}".replace( + "distil-whisper", "whisper-distil" + ) + ) else: - repo_id = f"Systran/faster-whisper-{model_size}" - model_path = f"tools/asr/models/{repo_id.strip('Systran/')}" + repo_id = "XXXXRT/faster-whisper" + model_path = f"tools/asr/models/faster-whisper-{model_size}".replace("distil-whisper", "whisper-distil") files: list[str] = [ "config.json", @@ -58,26 +82,24 @@ def download_model(model_size: str): files.remove("vocabulary.txt") - for attempt in range(2): - try: - snapshot_download( - repo_id=repo_id, - allow_patterns=files, - local_dir=model_path, - ) - break - except LocalEntryNotFoundError: - if attempt < 1: - time.sleep(2) - else: - print("[ERROR] LocalEntryNotFoundError and no fallback.") - traceback.print_exc() - exit(1) - except Exception as e: - print(f"[ERROR] Unexpected error on attempt {attempt + 1}: {e}") - traceback.print_exc() - exit(1) + if source == "ModelScope": + files = [f"faster-whisper-{model_size}/{file}".replace("whisper-distil", "distil-whisper") for file in files] + if source == "HF": + print(f"Downloading model from HuggingFace: {repo_id} to {model_path}") + snapshot_download_hf( + repo_id, + local_dir=model_path, + local_dir_use_symlinks=False, + allow_patterns=files, + ) + else: + print(f"Downloading model from ModelScope: {repo_id} to {model_path}") + snapshot_download_ms( + repo_id, + local_dir=model_path, + allow_patterns=files, + ) return model_path @@ -106,7 +128,7 
@@ def execute_asr(input_folder, output_folder, model_path, language, precision): ) text = "" - if info.language == "zh": + if info.language in ["zh", "yue"]: print("检测为中文文本, 转 FunASR 处理") text = only_asr(file_path, language=info.language.lower()) diff --git a/tools/asr/funasr_asr.py b/tools/asr/funasr_asr.py index b0ffceb0..6a5c9989 100644 --- a/tools/asr/funasr_asr.py +++ b/tools/asr/funasr_asr.py @@ -4,9 +4,8 @@ import argparse import os import traceback -# from funasr.utils import version_checker -# version_checker.check_for_update = lambda: None from funasr import AutoModel +from modelscope import snapshot_download from tqdm import tqdm funasr_models = {} # 存储模型避免重复加载 @@ -16,40 +15,43 @@ def only_asr(input_file, language): try: model = create_model(language) text = model.generate(input=input_file)[0]["text"] - except: + except Exception: text = "" print(traceback.format_exc()) return text def create_model(language="zh"): - path_vad = "tools/asr/models/speech_fsmn_vad_zh-cn-16k-common-pytorch" - path_punc = "tools/asr/models/punc_ct-transformer_zh-cn-common-vocab272727-pytorch" - path_vad = path_vad if os.path.exists(path_vad) else "iic/speech_fsmn_vad_zh-cn-16k-common-pytorch" - path_punc = path_punc if os.path.exists(path_punc) else "iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch" - vad_model_revision = punc_model_revision = "v2.0.4" - if language == "zh": + path_vad = "tools/asr/models/speech_fsmn_vad_zh-cn-16k-common-pytorch" + path_punc = "tools/asr/models/punc_ct-transformer_zh-cn-common-vocab272727-pytorch" path_asr = "tools/asr/models/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" - path_asr = ( - path_asr - if os.path.exists(path_asr) - else "iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" + snapshot_download( + "iic/speech_fsmn_vad_zh-cn-16k-common-pytorch", + local_dir="tools/asr/models/speech_fsmn_vad_zh-cn-16k-common-pytorch", + ) + snapshot_download( + 
"iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch", + local_dir="tools/asr/models/punc_ct-transformer_zh-cn-common-vocab272727-pytorch", + ) + snapshot_download( + "iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch", + local_dir="tools/asr/models/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch", ) model_revision = "v2.0.4" elif language == "yue": path_asr = "tools/asr/models/speech_UniASR_asr_2pass-cantonese-CHS-16k-common-vocab1468-tensorflow1-online" - path_asr = ( - path_asr - if os.path.exists(path_asr) - else "iic/speech_UniASR_asr_2pass-cantonese-CHS-16k-common-vocab1468-tensorflow1-online" + snapshot_download( + "iic/speech_UniASR_asr_2pass-cantonese-CHS-16k-common-vocab1468-tensorflow1-online", + local_dir="tools/asr/models/speech_UniASR_asr_2pass-cantonese-CHS-16k-common-vocab1468-tensorflow1-online", ) - model_revision = "master" path_vad = path_punc = None - vad_model_revision = punc_model_revision = None - ###友情提示:粤语带VAD识别可能会有少量shape不对报错的,但是不带VAD可以.不带vad只能分阶段单独加标点。不过标点模型对粤语效果真的不行… + vad_model_revision = punc_model_revision = "" + model_revision = "master" else: - raise ValueError("FunASR 不支持该语言" + ": " + language) + raise ValueError(f"{language} is not supported") + + vad_model_revision = punc_model_revision = "v2.0.4" if language in funasr_models: return funasr_models[language] @@ -83,7 +85,7 @@ def execute_asr(input_folder, output_folder, model_size, language): file_path = os.path.join(input_folder, file_name) text = model.generate(input=file_path)[0]["text"] output.append(f"{file_path}|{output_file_name}|{language.upper()}|{text}") - except: + except Exception: print(traceback.format_exc()) output_folder = output_folder or "output/asr_opt" diff --git a/tools/i18n/locale/en_US.json b/tools/i18n/locale/en_US.json index 24d24de4..561d3bfd 100644 --- a/tools/i18n/locale/en_US.json +++ b/tools/i18n/locale/en_US.json @@ -38,7 +38,7 @@ "hop_size:怎么算音量曲线,越小精度越大计算量越高(不是精度越大效果越好)": "hop_size: FO hop size, the 
smaller the value, the higher the accuracy)", "max:归一化后最大值多少": "Loudness multiplier after normalized", "max_sil_kept:切完后静音最多留多长": "Maximum length for silence to be kept", - "min_interval:最短切割间隔": "Minumum interval for audio cutting", + "min_interval:最短切割间隔": "Minimum interval for audio cutting", "min_length:每段最小多长,如果第一段太短一直和后面段连起来直到超过这个值": "min_length: the minimum length of each segment. If the first segment is too short, it will be concatenated with the next segment until it exceeds this value", "temperature": "temperature", "threshold:音量小于这个值视作静音的备选切割点": "Noise gate threshold (loudness below this value will be treated as noise", @@ -176,7 +176,7 @@ "语音降噪": "Speech Denoising", "请上传3~10秒内参考音频,超过会报错!": "Please upload a reference audio within the 3-10 second range; if it exceeds this duration, it will raise errors.", "请上传参考音频": "Please Upload the Reference Audio", - "请填入推理文本": "Please Fill in the Terget Text", + "请填入推理文本": "Please Fill in the Target Text", "请填入正确的List路径": "Please Fill in the Correct List Path", "请填入正确的音频文件夹路径": "Please Fill in the Correct Audio Folder Path", "请输入有效文本": "Please enter valid text.", diff --git a/webui.py b/webui.py index cf5d8a3a..beb0963a 100644 --- a/webui.py +++ b/webui.py @@ -86,7 +86,6 @@ from config import ( from tools import my_utils from tools.my_utils import check_details, check_for_existance -os.environ["HF_ENDPOINT"] = "https://hf-mirror.com" os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" # os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1' # 当遇到mps不支持的步骤时使用cpu @@ -1980,4 +1979,3 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css server_port=webui_port_main, # quiet=True, ) - From 6fb441f65e4b0573d7f7b16d96dc1917d38eda64 Mon Sep 17 00:00:00 2001 From: ChasonJiang <46401978+ChasonJiang@users.noreply.github.com> Date: Fri, 28 Nov 2025 22:13:48 +0800 Subject: [PATCH 09/20] =?UTF-8?q?=E6=9B=B4=E5=8F=8B=E5=A5=BD=E7=9A=84?= =?UTF-8?q?=E6=B5=81=E6=A8=A1=E5=BC=8F=E9=80=89=E9=A1=B9=20(#2678)?= MIME-Version: 
1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- api_v2.py | 48 +++++++++++++++++++++++++++++++----------------- 1 file changed, 31 insertions(+), 17 deletions(-) diff --git a/api_v2.py b/api_v2.py index 5df2da66..8c83bb0f 100644 --- a/api_v2.py +++ b/api_v2.py @@ -41,11 +41,9 @@ POST: "repetition_penalty": 1.35, # float. repetition penalty for T2S model. "sample_steps": 32, # int. number of sampling steps for VITS model V3. "super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3. - "return_fragment": False, # bool. step by step return the audio fragment. (Best Quality, Slowest response speed. old version of streaming mode) - "streaming_mode": False, # bool. return audio chunk by chunk. (Medium quality, Slow response speed) + "streaming_mode": False, # bool or int. return audio chunk by chunk. The available options are: 0,1,2,3 or True/False (0/False: Disabled | 1/True: Best Quality, Slowest response speed (old version streaming_mode) | 2: Medium Quality, Slow response speed | 3: Lower Quality, Faster response speed ) "overlap_length": 2, # int. overlap length of semantic tokens for streaming mode. - "min_chunk_length": 16, # int. The minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size) - "fixed_length_chunk": False, # bool. When turned on, it can achieve faster streaming response, but with lower quality. (lower quality, faster response speed) + "min_chunk_length": 16, # int. The minimum chunk length of semantic tokens for streaming mode.
(affects audio chunk size) } ``` @@ -106,7 +104,7 @@ RESP: import os import sys import traceback -from typing import Generator +from typing import Generator, Union now_dir = os.getcwd() sys.path.append(now_dir) @@ -171,15 +169,13 @@ class TTS_Request(BaseModel): fragment_interval: float = 0.3 seed: int = -1 media_type: str = "wav" - streaming_mode: bool = False + streaming_mode: Union[bool, int] = False parallel_infer: bool = True repetition_penalty: float = 1.35 sample_steps: int = 32 super_sampling: bool = False overlap_length: int = 2 min_chunk_length: int = 16 - return_fragment: bool = False - fixed_length_chunk: bool = False def pack_ogg(io_buffer: BytesIO, data: np.ndarray, rate: int): @@ -373,11 +369,9 @@ async def tts_handle(req: dict): "repetition_penalty": 1.35, # float. repetition penalty for T2S model. "sample_steps": 32, # int. number of sampling steps for VITS model V3. "super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3. - "return_fragment": False, # bool. step by step return the audio fragment. (Best Quality, Slowest response speed. old version of streaming mode) - "streaming_mode": False, # bool. return audio chunk by chunk. (Medium quality, Slow response speed) + "streaming_mode": False, # bool or int. return audio chunk by chunk. The available options are: 0,1,2,3 or True/False (0/False: Disabled | 1/True: Best Quality, Slowest response speed (old version streaming_mode) | 2: Medium Quality, Slow response speed | 3: Lower Quality, Faster response speed ) "overlap_length": 2, # int. overlap length of semantic tokens for streaming mode. - "min_chunk_length": 16, # int. The minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size) - "fixed_length_chunk": False, # bool. When turned on, it can achieve faster streaming response, but with lower quality. (lower quality, faster response speed) + "min_chunk_length": 16, # int.
The minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size) } returns: StreamingResponse: audio stream response. @@ -390,9 +384,33 @@ async def tts_handle(req: dict): check_res = check_params(req) if check_res is not None: return check_res + + if streaming_mode == 0: + streaming_mode = False + return_fragment = False + fixed_length_chunk = False + elif streaming_mode == 1: + streaming_mode = False + return_fragment = True + fixed_length_chunk = False + elif streaming_mode == 2: + streaming_mode = True + return_fragment = False + fixed_length_chunk = False + elif streaming_mode == 3: + streaming_mode = True + return_fragment = False + fixed_length_chunk = True + + else: + return JSONResponse(status_code=400, content={"message": f"the value of streaming_mode must be 0, 1, 2, 3(int) or true/false(bool)"}) req["streaming_mode"] = streaming_mode req["return_fragment"] = return_fragment + req["fixed_length_chunk"] = fixed_length_chunk + + print(f"{streaming_mode} {return_fragment} {fixed_length_chunk}") + streaming_mode = streaming_mode or return_fragment @@ -457,11 +475,9 @@ async def tts_get_endpoint( repetition_penalty: float = 1.35, sample_steps: int = 32, super_sampling: bool = False, - return_fragment: bool = False, - streaming_mode: bool = False, + streaming_mode: Union[bool, int] = False, overlap_length: int = 2, min_chunk_length: int = 16, - fixed_length_chunk: bool = False, ): req = { "text": text, @@ -488,8 +504,6 @@ async def tts_get_endpoint( "super_sampling": super_sampling, "overlap_length": int(overlap_length), "min_chunk_length": int(min_chunk_length), - "return_fragment": return_fragment, - "fixed_length_chunk": fixed_length_chunk } return await tts_handle(req) From 92d2d337fd98673c126fd40727e067204e4523ae Mon Sep 17 00:00:00 2001 From: Spr_Aachen <51275522+Spr-Aachen@users.noreply.github.com> Date: Fri, 28 Nov 2025 22:53:43 +0800 Subject: [PATCH 10/20] Fix training error caused by float type of default_batch_size 
parameter (#2662) --- webui.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/webui.py b/webui.py index beb0963a..9c5ddd07 100644 --- a/webui.py +++ b/webui.py @@ -116,8 +116,8 @@ def set_default(): gpu_info = "\n".join(gpu_infos) if is_gpu_ok: minmem = min(mem) - default_batch_size = minmem // 2 if version not in v3v4set else minmem // 8 - default_batch_size_s1 = minmem // 2 + default_batch_size = int(minmem // 2 if version not in v3v4set else minmem // 8) + default_batch_size_s1 = int(minmem // 2) else: default_batch_size = default_batch_size_s1 = int(psutil.virtual_memory().total / 1024 / 1024 / 1024 / 4) if version not in v3v4set: From 857799276c3e8adcda7d662a55b07bf00bc1f01b Mon Sep 17 00:00:00 2001 From: XXXXRT666 <157766680+XXXXRT666@users.noreply.github.com> Date: Mon, 1 Dec 2025 03:13:15 +0000 Subject: [PATCH 11/20] Fix Modelscope (#2679) --- tools/asr/fasterwhisper_asr.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/tools/asr/fasterwhisper_asr.py b/tools/asr/fasterwhisper_asr.py index 72a4b82a..7e39fefb 100644 --- a/tools/asr/fasterwhisper_asr.py +++ b/tools/asr/fasterwhisper_asr.py @@ -52,7 +52,7 @@ def download_model(model_size: str): if "distil" in model_size: if "3.5" in model_size: repo_id = "distil-whisper/distil-large-v3.5-ct2" - model_path = "tools/asr/models/faster-whisper-distil-large-v3.5" + model_path = "tools/asr/models/faster-distil-whisper-large-v3.5" else: repo_id = "Systran/faster-{}-whisper-{}".format(*model_size.split("-", maxsplit=1)) elif model_size == "large-v3-turbo": @@ -61,14 +61,11 @@ def download_model(model_size: str): else: repo_id = f"Systran/faster-whisper-{model_size}" model_path = ( - model_path - or f"tools/asr/models/{repo_id.replace('Systran/', '').replace('distil-whisper/', '', 1)}".replace( - "distil-whisper", "whisper-distil" - ) + model_path or f"tools/asr/models/{repo_id.replace('Systran/', '').replace('distil-whisper/', '', 1)}" ) else: repo_id = 
"XXXXRT/faster-whisper" - model_path = f"tools/asr/models/faster-whisper-{model_size}".replace("distil-whisper", "whisper-distil") + model_path = "tools/asr/models" files: list[str] = [ "config.json", @@ -76,7 +73,7 @@ def download_model(model_size: str): "tokenizer.json", "vocabulary.txt", ] - if model_size == "large-v3" or "distil" in model_size: + if "large-v3" in model_size or "distil" in model_size: files.append("preprocessor_config.json") files.append("vocabulary.json") From fc533b6fb7d02e52ad297045ce436f3c8b1a8e53 Mon Sep 17 00:00:00 2001 From: RVC-Boss <129054828+RVC-Boss@users.noreply.github.com> Date: Mon, 1 Dec 2025 11:38:37 +0800 Subject: [PATCH 12/20] Update fasterwhisper_asr.py --- tools/asr/fasterwhisper_asr.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tools/asr/fasterwhisper_asr.py b/tools/asr/fasterwhisper_asr.py index 7e39fefb..927230d4 100644 --- a/tools/asr/fasterwhisper_asr.py +++ b/tools/asr/fasterwhisper_asr.py @@ -97,6 +97,7 @@ def download_model(model_size: str): local_dir=model_path, allow_patterns=files, ) + return model_path + f"/faster-whisper-{model_size}".replace("whisper-distil", "distil-whisper") return model_path From 9ec3a60f30d228719e5ec6cd6796c5b2d888dd1a Mon Sep 17 00:00:00 2001 From: RVC-Boss <129054828+RVC-Boss@users.noreply.github.com> Date: Mon, 1 Dec 2025 20:23:49 +0800 Subject: [PATCH 13/20] Update config.py --- tools/asr/config.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tools/asr/config.py b/tools/asr/config.py index fdff7518..90b2302a 100644 --- a/tools/asr/config.py +++ b/tools/asr/config.py @@ -5,9 +5,9 @@ def get_models(): "large-v2", "large-v3", "large-v3-turbo", - "distil-large-v2", - "distil-large-v3", - "distil-large-v3.5", + #"distil-large-v2", + #"distil-large-v3", + #"distil-large-v3.5", ] return model_size_list From 36b3231c6f7b1adabb1e93bb3c6449dbd85b7375 Mon Sep 17 00:00:00 2001 From: ChasonJiang <46401978+ChasonJiang@users.noreply.github.com> Date: Mon, 15 Dec 2025 
14:23:06 +0800 Subject: [PATCH 14/20] bug fix (#2689) --- GPT_SoVITS/TTS_infer_pack/TTS.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py index be3d3a19..2e130978 100644 --- a/GPT_SoVITS/TTS_infer_pack/TTS.py +++ b/GPT_SoVITS/TTS_infer_pack/TTS.py @@ -1332,7 +1332,7 @@ class TTS: audio_fragment = self.vits_model.decode( _pred_semantic, phones, refer_audio_spec, speed=speed_factor, sv_emb=sv_emb ).detach()[0, 0, :] - batch_audio_fragment.append(audio_fragment) ###试试重建不带上prompt部分 + batch_audio_fragment.append(audio_fragment) ###试试重建不带上prompt部分 else: if parallel_infer: print(f"{i18n('并行合成中')}...") From cc89c3660e55792d1fd5cc57597a8e661cca708e Mon Sep 17 00:00:00 2001 From: RVC-Boss <129054828+RVC-Boss@users.noreply.github.com> Date: Fri, 19 Dec 2025 15:54:54 +0800 Subject: [PATCH 15/20] Update requirements.txt --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 578bb87c..3b7cd898 100644 --- a/requirements.txt +++ b/requirements.txt @@ -19,7 +19,7 @@ torchaudio modelscope sentencepiece transformers>=4.43,<=4.50 -peft +peft<0.18.0 chardet PyYAML psutil From abe984395cb6d8ed2055f5496d0bb26007f30365 Mon Sep 17 00:00:00 2001 From: ChasonJiang <46401978+ChasonJiang@users.noreply.github.com> Date: Fri, 19 Dec 2025 16:05:36 +0800 Subject: [PATCH 16/20] =?UTF-8?q?=E5=AF=B9=E9=BD=90gpt=20topk=E9=BB=98?= =?UTF-8?q?=E8=AE=A4=E9=87=87=E6=A0=B7=E5=8F=82=E6=95=B0=20(#2696)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- GPT_SoVITS/TTS_infer_pack/TTS.py | 4 ++-- GPT_SoVITS/inference_webui_fast.py | 2 +- api_v2.py | 8 ++++---- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py index 2e130978..9c8344b0 100644 --- a/GPT_SoVITS/TTS_infer_pack/TTS.py +++ b/GPT_SoVITS/TTS_infer_pack/TTS.py @@ -1008,7 
+1008,7 @@ class TTS: "aux_ref_audio_paths": [], # list.(optional) auxiliary reference audio paths for multi-speaker tone fusion "prompt_text": "", # str.(optional) prompt text for the reference audio "prompt_lang": "", # str.(required) language of the prompt text for the reference audio - "top_k": 5, # int. top k sampling + "top_k": 15, # int. top k sampling "top_p": 1, # float. top p sampling "temperature": 1, # float. temperature for sampling "text_split_method": "cut1", # str. text split method, see text_segmentation_method.py for details. @@ -1039,7 +1039,7 @@ class TTS: aux_ref_audio_paths: list = inputs.get("aux_ref_audio_paths", []) prompt_text: str = inputs.get("prompt_text", "") prompt_lang: str = inputs.get("prompt_lang", "") - top_k: int = inputs.get("top_k", 5) + top_k: int = inputs.get("top_k", 15) top_p: float = inputs.get("top_p", 1) temperature: float = inputs.get("temperature", 1) text_split_method: str = inputs.get("text_split_method", "cut1") diff --git a/GPT_SoVITS/inference_webui_fast.py b/GPT_SoVITS/inference_webui_fast.py index 51a120f1..92d145b3 100644 --- a/GPT_SoVITS/inference_webui_fast.py +++ b/GPT_SoVITS/inference_webui_fast.py @@ -385,7 +385,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css minimum=0.6, maximum=1.65, step=0.05, label="语速", value=1.0, interactive=True ) with gr.Row(): - top_k = gr.Slider(minimum=1, maximum=100, step=1, label=i18n("top_k"), value=5, interactive=True) + top_k = gr.Slider(minimum=1, maximum=100, step=1, label=i18n("top_k"), value=15, interactive=True) top_p = gr.Slider(minimum=0, maximum=1, step=0.05, label=i18n("top_p"), value=1, interactive=True) with gr.Row(): temperature = gr.Slider( diff --git a/api_v2.py b/api_v2.py index 8c83bb0f..21511db3 100644 --- a/api_v2.py +++ b/api_v2.py @@ -27,7 +27,7 @@ POST: "aux_ref_audio_paths": [], # list.(optional) auxiliary reference audio paths for multi-speaker tone fusion "prompt_text": "", # str.(optional) prompt text for the 
reference audio "prompt_lang": "", # str.(required) language of the prompt text for the reference audio - "top_k": 5, # int. top k sampling + "top_k": 15, # int. top k sampling "top_p": 1, # float. top p sampling "temperature": 1, # float. temperature for sampling "text_split_method": "cut5", # str. text split method, see text_segmentation_method.py for details. @@ -158,7 +158,7 @@ class TTS_Request(BaseModel): aux_ref_audio_paths: list = None prompt_lang: str = None prompt_text: str = "" - top_k: int = 5 + top_k: int = 15 top_p: float = 1 temperature: float = 1 text_split_method: str = "cut5" @@ -355,7 +355,7 @@ async def tts_handle(req: dict): "aux_ref_audio_paths": [], # list.(optional) auxiliary reference audio paths for multi-speaker tone fusion "prompt_text": "", # str.(optional) prompt text for the reference audio "prompt_lang": "", # str.(required) language of the prompt text for the reference audio - "top_k": 5, # int. top k sampling + "top_k": 15, # int. top k sampling "top_p": 1, # float. top p sampling "temperature": 1, # float. temperature for sampling "text_split_method": "cut5", # str. text split method, see text_segmentation_method.py for details. 
@@ -460,7 +460,7 @@ async def tts_get_endpoint( aux_ref_audio_paths: list = None, prompt_lang: str = None, prompt_text: str = "", - top_k: int = 5, + top_k: int = 15, top_p: float = 1, temperature: float = 1, text_split_method: str = "cut5", From bfca0f6b2dd9f846c76366be807f01a8873140a0 Mon Sep 17 00:00:00 2001 From: ChasonJiang <46401978+ChasonJiang@users.noreply.github.com> Date: Fri, 19 Dec 2025 17:37:19 +0800 Subject: [PATCH 17/20] =?UTF-8?q?=E5=AF=B9=E9=BD=90naive=5Finfer=E7=9A=84?= =?UTF-8?q?=E8=A7=A3=E7=A0=81=E7=AD=96=E7=95=A5=EF=BC=8C=E9=98=B2=E6=AD=A2?= =?UTF-8?q?=E5=90=9E=E5=8F=A5=20(#2697)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- GPT_SoVITS/AR/models/t2s_model.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/GPT_SoVITS/AR/models/t2s_model.py b/GPT_SoVITS/AR/models/t2s_model.py index 0caadd04..81aad1ec 100644 --- a/GPT_SoVITS/AR/models/t2s_model.py +++ b/GPT_SoVITS/AR/models/t2s_model.py @@ -711,6 +711,9 @@ class Text2SemanticDecoder(nn.Module): else: attn_mask = F.pad(attn_mask, (0, 1), value=False) + if idx < 11: ###至少预测出10个token不然不给停止(0.4s) + logits = logits[:, :-1] + samples = sample( logits, y, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature )[0] From 51df9f738428d77377e6788afc8c988b8c475f4f Mon Sep 17 00:00:00 2001 From: sushistack Date: Thu, 25 Dec 2025 17:44:21 +0900 Subject: [PATCH 18/20] Fix model file name in README instructions (#2700) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 86d50ea2..923f9a0a 100644 --- a/README.md +++ b/README.md @@ -347,7 +347,7 @@ Use v4 from v1/v2/v3 environment: 2. Clone the latest codes from github. -3. Download v4 pretrained models (gsv-v4-pretrained/s2v4.ckpt, and gsv-v4-pretrained/vocoder.pth) from [huggingface](https://huggingface.co/lj1995/GPT-SoVITS/tree/main) and put them into `GPT_SoVITS/pretrained_models`. +3. 
Download v4 pretrained models (gsv-v4-pretrained/s2v4.pth, and gsv-v4-pretrained/vocoder.pth) from [huggingface](https://huggingface.co/lj1995/GPT-SoVITS/tree/main) and put them into `GPT_SoVITS/pretrained_models`. ## V2Pro Release Notes From 9080a967d5e64f4bfb5a9ea33afc7252136b0256 Mon Sep 17 00:00:00 2001 From: ChasonJiang <46401978+ChasonJiang@users.noreply.github.com> Date: Tue, 30 Dec 2025 15:21:03 +0800 Subject: [PATCH 19/20] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E9=87=87=E6=A0=B7?= =?UTF-8?q?=E9=94=99=E8=AF=AF=20(#2703)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- GPT_SoVITS/AR/models/t2s_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/GPT_SoVITS/AR/models/t2s_model.py b/GPT_SoVITS/AR/models/t2s_model.py index 81aad1ec..486f85a3 100644 --- a/GPT_SoVITS/AR/models/t2s_model.py +++ b/GPT_SoVITS/AR/models/t2s_model.py @@ -712,7 +712,7 @@ class Text2SemanticDecoder(nn.Module): attn_mask = F.pad(attn_mask, (0, 1), value=False) if idx < 11: ###至少预测出10个token不然不给停止(0.4s) - logits = logits[:, :-1] + logits[:, -1] = float("-inf") samples = sample( logits, y, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature From c767f0b83b998e996a4d230d86da575a03f54a3f Mon Sep 17 00:00:00 2001 From: ChasonJiang <46401978+ChasonJiang@users.noreply.github.com> Date: Tue, 30 Dec 2025 16:00:21 +0800 Subject: [PATCH 20/20] =?UTF-8?q?=E4=BF=AE=E5=A4=8Dbug=20(#2704)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 修复bug * fallbak and bug fix --- GPT_SoVITS/AR/models/t2s_model.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/GPT_SoVITS/AR/models/t2s_model.py b/GPT_SoVITS/AR/models/t2s_model.py index 486f85a3..ac905f4b 100644 --- a/GPT_SoVITS/AR/models/t2s_model.py +++ b/GPT_SoVITS/AR/models/t2s_model.py @@ -707,12 +707,11 @@ class Text2SemanticDecoder(nn.Module): if idx == 0: attn_mask = F.pad(attn_mask[:, :, 
-1].unsqueeze(-2), (0, 1), value=False) - logits = logits[:, :-1] else: attn_mask = F.pad(attn_mask, (0, 1), value=False) if idx < 11: ###至少预测出10个token不然不给停止(0.4s) - logits[:, -1] = float("-inf") + logits = logits[:, :-1] samples = sample( logits, y, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature