diff --git a/GPT_SoVITS/export_torch_script.py b/GPT_SoVITS/export_torch_script.py
index e4406f28..786e22d1 100644
--- a/GPT_SoVITS/export_torch_script.py
+++ b/GPT_SoVITS/export_torch_script.py
@@ -261,41 +261,21 @@ class T2SBlock:
         attn = F.scaled_dot_product_attention(q, k, v, ~attn_mask)

-        attn = attn.permute(2, 0, 1, 3).reshape(batch_size * q_len, self.hidden_dim)
-        attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
+        # attn = attn.permute(2, 0, 1, 3).reshape(batch_size * q_len, self.hidden_dim)
+        # attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
+        attn = attn.transpose(1, 2).reshape(batch_size, q_len, -1)
         attn = F.linear(self.to_mask(attn, padding_mask), self.out_w, self.out_b)

-        if padding_mask is not None:
-            for i in range(batch_size):
-                # mask = padding_mask[i,:,0]
-                if self.false.device != padding_mask.device:
-                    self.false = self.false.to(padding_mask.device)
-                idx = torch.where(padding_mask[i, :, 0] == self.false)[0]
-                x_item = x[i, idx, :].unsqueeze(0)
-                attn_item = attn[i, idx, :].unsqueeze(0)
-                x_item = x_item + attn_item
-                x_item = F.layer_norm(x_item, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1)
-                x_item = x_item + self.mlp.forward(x_item)
-                x_item = F.layer_norm(
-                    x_item,
-                    [self.hidden_dim],
-                    self.norm_w2,
-                    self.norm_b2,
-                    self.norm_eps2,
-                )
-                x[i, idx, :] = x_item.squeeze(0)
-            x = self.to_mask(x, padding_mask)
-        else:
-            x = x + attn
-            x = F.layer_norm(x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1)
-            x = x + self.mlp.forward(x)
-            x = F.layer_norm(
-                x,
-                [self.hidden_dim],
-                self.norm_w2,
-                self.norm_b2,
-                self.norm_eps2,
-            )
+        x = x + attn
+        x = F.layer_norm(x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1)
+        x = x + self.mlp.forward(x)
+        x = F.layer_norm(
+            x,
+            [self.hidden_dim],
+            self.norm_w2,
+            self.norm_b2,
+            self.norm_eps2,
+        )

         return x, k_cache, v_cache

     def decode_next_token(self, x: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor):
diff --git a/GPT_SoVITS/stream_v2pro.py b/GPT_SoVITS/stream_v2pro.py
new file mode 100644
index 00000000..0a7712ad
--- /dev/null
+++ b/GPT_SoVITS/stream_v2pro.py
@@ -0,0 +1,607 @@
+# This is an experimental implementation meant to explore the feasibility of streaming inference. (a toy written for fun)
+from typing import List
+from export_torch_script import ExportERes2NetV2, SSLModel, T2SModel, VitsModel, get_raw_t2s_model, init_sv_cn, resamplex, sample, spectrogram_torch
+import export_torch_script
+from my_utils import load_audio
+import torch
+from torch import LongTensor, Tensor, nn
+from torch.nn import functional as F
+
+import soundfile
+from inference_webui import get_phones_and_bert
+import matplotlib.pyplot as plt
+
+
+class StreamT2SModel(nn.Module):
+    def __init__(self, t2s: T2SModel):
+        super(StreamT2SModel, self).__init__()
+        self.t2s = t2s
+
+    @torch.jit.export
+    def pre_infer(
+        self,
+        prompts: LongTensor,
+        ref_seq: LongTensor,
+        text_seq: LongTensor,
+        ref_bert: torch.Tensor,
+        text_bert: torch.Tensor,
+        top_k: int,
+    ) -> tuple[int, Tensor, Tensor, List[Tensor], List[Tensor]]:
+        bert = torch.cat([ref_bert.T, text_bert.T], 1)
+        all_phoneme_ids = torch.cat([ref_seq, text_seq], 1)
+        bert = bert.unsqueeze(0)
+
+        x = self.t2s.ar_text_embedding(all_phoneme_ids)
+        x = x + self.t2s.bert_proj(bert.transpose(1, 2))
+        x: torch.Tensor = self.t2s.ar_text_position(x)
+
+        # [1,N,512] [1,N]
+        # y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts)
+        y = prompts
+        # x_example = x[:,:,0] * 0.0
+
+        x_len = x.shape[1]
+        x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
+
+        y_emb = self.t2s.ar_audio_embedding(y)
+        y_len: int = y_emb.shape[1]
+        prefix_len = y.shape[1]
+        y_pos = self.t2s.ar_audio_position(y_emb)
+        xy_pos = torch.concat([x, y_pos], dim=1)
+
+        bsz = x.shape[0]
+        src_len = x_len + y_len
+        x_attn_mask_pad = F.pad(
+            x_attn_mask,
+            (0, y_len),  # extend the all-False (x, x) text mask to (x, x + y); the appended (x, y) block is all True
+            value=True,
+        )
+        y_attn_mask = F.pad(  # causal (y, y) mask, left-padded with an all-False (y, x) block so audio attends to all text
+            torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
+            (x_len, 0),
+            value=False,
+        )
+        xy_attn_mask = (
+            torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
+            .unsqueeze(0)
+            .expand(bsz * self.t2s.num_head, -1, -1)
+            .view(bsz, self.t2s.num_head, src_len, src_len)
+            .to(device=x.device, dtype=torch.bool)
+        )
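+        # Illustration (annotation added for readability, not computed by the code):
+        # with x_len = 2 text tokens and y_len = 3 audio tokens, True meaning
+        # "do not attend", the (x+y, x+y) mask rows are:
+        #     text  row 0: [F F | T T T]
+        #     text  row 1: [F F | T T T]
+        #     audio row 0: [F F | F T T]
+        #     audio row 1: [F F | F F T]
+        #     audio row 2: [F F | F F F]
+        # Text tokens attend only to text; audio tokens attend to all text plus
+        # their causal audio prefix.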
+
+        xy_dec, k_cache, v_cache = self.t2s.t2s_transformer.process_prompt(
+            xy_pos, xy_attn_mask, None
+        )
+
+        logits = self.t2s.ar_predict_layer(xy_dec[:, -1])
+        logits = logits[:, :-1]
+        samples = sample(
+            logits, y, top_k=top_k, top_p=1, repetition_penalty=1.35, temperature=1.0
+        )[0]
+        y = torch.concat([y, samples], dim=1)
+        y_emb: Tensor = self.t2s.ar_audio_embedding(y[:, -1:])
+        xy_pos: Tensor = (
+            y_emb * self.t2s.ar_audio_position.x_scale
+            + self.t2s.ar_audio_position.alpha
+            * self.t2s.ar_audio_position.pe[:, y_len].to(
+                dtype=y_emb.dtype, device=y_emb.device
+            )
+        )
+
+        return y_len, y, xy_pos, k_cache, v_cache
+
+    @torch.jit.export
+    def decode_next_token(
+        self,
+        idx: int,  # note: idx starts at 1 and runs up to 1500
+        top_k: int,
+        y_len: int,
+        y: Tensor,
+        xy_pos: Tensor,
+        k_cache: List[Tensor],
+        v_cache: List[Tensor],
+    ) -> tuple[Tensor, Tensor, int, List[Tensor], List[Tensor]]:
+        # [1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N]
+        # y, k, v, y_emb, logits, samples = self.stage_decoder(y, k, v, y_emb, x_example)
+        xy_dec, k_cache, v_cache = self.t2s.t2s_transformer.decode_next_token(
+            xy_pos, k_cache, v_cache
+        )
+        logits = self.t2s.ar_predict_layer(xy_dec[:, -1])
+
+        if idx < 11:  # require at least 10 predicted tokens (~0.4s) before EOS is allowed
+            logits = logits[:, :-1]
+
+        samples = sample(
+            logits, y, top_k=top_k, top_p=1, repetition_penalty=1.35, temperature=1.0
+        )[0]
+
+        y = torch.concat([y, samples], dim=1)
+        last_token = int(samples[0, 0])
+
+        # if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
+        #     stop = True
+        if torch.argmax(logits, dim=-1)[0] == self.t2s.EOS or samples[0, 0] == self.t2s.EOS:
+            return y[:, :-1], xy_pos, self.t2s.EOS, k_cache, v_cache
+
+        # if stop:
+        #     if y.shape[1] == 0:
+        #         y = torch.concat([y, torch.zeros_like(samples)], dim=1)
+        #     break
+
+        y_emb = self.t2s.ar_audio_embedding(y[:, -1:])
+        xy_pos = (
+            y_emb * self.t2s.ar_audio_position.x_scale
+            + self.t2s.ar_audio_position.alpha
+            * self.t2s.ar_audio_position.pe[:, y_len + idx].to(
+                dtype=y_emb.dtype, device=y_emb.device
+            )
+        )
+        return y, xy_pos, last_token, k_cache, v_cache
+
+    def forward(
+        self,
+        idx: int,  # note: idx starts at 1 and runs up to 1500
+        top_k: int,
+        y_len: int,
+        y: Tensor,
+        xy_pos: Tensor,
+        k_cache: List[Tensor],
+        v_cache: List[Tensor],
+    ):
+        return self.decode_next_token(idx, top_k, y_len, y, xy_pos, k_cache, v_cache)
+
+
+class StepVitsModel(nn.Module):
+    def __init__(self, vits: VitsModel, sv_model: ExportERes2NetV2):
+        super().__init__()
+        self.hps = vits.hps
+        self.vq_model = vits.vq_model
+        self.hann_window = vits.hann_window
+        self.sv = sv_model
+
+    def ref_handle(self, ref_audio_32k):
+        refer = spectrogram_torch(
+            self.hann_window,
+            ref_audio_32k,
+            self.hps.data.filter_length,
+            self.hps.data.sampling_rate,
+            self.hps.data.hop_length,
+            self.hps.data.win_length,
+            center=False,
+        )
+        ref_audio_16k = resamplex(ref_audio_32k, 32000, 16000).to(ref_audio_32k.dtype).to(ref_audio_32k.device)
+        sv_emb = self.sv(ref_audio_16k)
+        return refer, sv_emb
+
+    def extract_latent(self, ssl_content):
+        codes = self.vq_model.extract_latent(ssl_content)
+        return codes[0]
+
+    def forward(self, pred_semantic, text_seq, refer, sv_emb=None):
+        return self.vq_model(
+            pred_semantic, text_seq, refer, speed=1.0, sv_emb=sv_emb
+        )[0, 0]
+
+
+@torch.jit.script
+def find_best_audio_offset_fast(reference_audio: Tensor, search_audio: Tensor):
+    ref_len = len(reference_audio)
+    search_len = len(search_audio)
+
+    if search_len < ref_len:
+        raise ValueError(
+            f"search audio length ({search_len}) must be >= reference audio length ({ref_len})"
+        )
+
+    # raw cross-correlation via F.conv1d
+    reference_flipped = reference_audio.unsqueeze(0).unsqueeze(0)
+    search_padded = search_audio.unsqueeze(0).unsqueeze(0)
+
+    # dot product of the reference against every window of the search audio
+    dot_products = F.conv1d(search_padded, reference_flipped).squeeze()
+
+    if len(dot_products.shape) == 0:
+        dot_products = dot_products.unsqueeze(0)
+
+    # squared sum of the reference audio
+    ref_squared_sum = torch.sum(reference_audio**2)
+
+    # squared sum at every search position (sliding window)
+    search_squared = search_audio**2
+    search_squared_padded = search_squared.unsqueeze(0).unsqueeze(0)
+    ones_kernel = torch.ones(
+        1, 1, ref_len, dtype=search_audio.dtype, device=search_audio.device
+    )
+
+    segment_squared_sums = F.conv1d(search_squared_padded, ones_kernel).squeeze()
+
+    if len(segment_squared_sums.shape) == 0:
+        segment_squared_sums = segment_squared_sums.unsqueeze(0)
+
+    # normalization factors
+    ref_norm = torch.sqrt(ref_squared_sum)
+    segment_norms = torch.sqrt(segment_squared_sums)
+
+    # avoid division by zero
+    epsilon = 1e-8
+    normalization_factor = ref_norm * segment_norms + epsilon
+
+    # normalized cross-correlation
+    correlation_scores = dot_products / normalization_factor
+
+    best_offset = torch.argmax(correlation_scores).item()
+
+    return best_offset, correlation_scores
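+# Usage sketch (illustrative; the shapes are examples, not values from the code
+# below): embedding an exact copy of the reference inside the search signal makes
+# the normalized cross-correlation peak at the insertion index.
+#     ref = torch.randn(256)
+#     search = torch.cat([torch.randn(100), ref, torch.randn(924)])
+#     offset, scores = find_best_audio_offset_fast(ref, search)
+#     # offset == 100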
+
+
+import time
+
+
+def test_stream(
+    gpt_path,
+    vits_path,
+    version,
+    ref_audio_path,
+    ref_text,
+    output_path,
+    device="cpu",
+    is_half=True,
+):
+    if export_torch_script.sv_cn_model is None:
+        init_sv_cn(device, is_half)
+
+    ref_audio = torch.tensor([load_audio(ref_audio_path, 16000)]).float()
+    ssl = SSLModel()
+
+    print(f"device: {device}")
+
+    ref_seq_id, ref_bert_T, ref_norm_text = get_phones_and_bert(
+        ref_text, "all_zh", "v2"
+    )
+    ref_seq = torch.LongTensor([ref_seq_id]).to(device)
+    ref_bert = ref_bert_T.T
+    if is_half:
+        ref_bert = ref_bert.half()
+    ref_bert = ref_bert.to(ref_seq.device)
+
+    text_seq_id, text_bert_T, norm_text = get_phones_and_bert(
+        "这是一个简单的示例,真没想到这么简单就完成了,真的神奇,接下来我们说说狐狸,可能这就是狐狸吧.它有长长的尾巴,尖尖的耳朵,传说中还有九条尾巴。你觉得狐狸神奇吗?", "auto", "v2"
+    )
+    text_seq = torch.LongTensor([text_seq_id]).to(device)
+    text_bert = text_bert_T.T
+    if is_half:
+        text_bert = text_bert.half()
+    text_bert = text_bert.to(text_seq.device)
+
+    ssl_content = ssl(ref_audio)
+    if is_half:
+        ssl_content = ssl_content.half()
+    ssl_content = ssl_content.to(device)
+
+    sv_model = ExportERes2NetV2(export_torch_script.sv_cn_model)
+
+    # vits_path = "SoVITS_weights_v2/xw_e8_s216.pth"
+    vits = VitsModel(vits_path, version, is_half=is_half, device=device)
+    vits.eval()
+
+    # gpt_path = "GPT_weights_v2/xw-e15.ckpt"
+    # dict_s1 = torch.load(gpt_path, map_location=device)
+    dict_s1 = torch.load(gpt_path, weights_only=False)
+    raw_t2s = get_raw_t2s_model(dict_s1).to(device)
+    print("#### get_raw_t2s_model ####")
+    print(raw_t2s.config)
+    if is_half:
+        raw_t2s = raw_t2s.half()
+    t2s_m = T2SModel(raw_t2s)
+    t2s_m.eval()
+    # t2s = torch.jit.script(t2s_m).to(device)
+    t2s = t2s_m
+    print("#### script t2s_m ####")
+
+    print("vits.hps.data.sampling_rate:", vits.hps.data.sampling_rate)
+
+    stream_t2s = StreamT2SModel(t2s).to(device)
+    stream_t2s = torch.jit.script(stream_t2s)
+
+    ref_audio_sr = resamplex(ref_audio, 16000, 32000)
+    if is_half:
+        ref_audio_sr = ref_audio_sr.half()
+    ref_audio_sr = ref_audio_sr.to(device)
+
+    top_k = 15
+
+    codes = vits.vq_model.extract_latent(ssl_content)
+    prompt_semantic = codes[0, 0]
+    prompts = prompt_semantic.unsqueeze(0)
+
+    audio_16k = resamplex(ref_audio_sr, 32000, 16000).to(ref_audio_sr.dtype)
+    sv_emb = sv_model(audio_16k)
+    print("text_seq", text_seq.shape)
+
+    refer = spectrogram_torch(
+        vits.hann_window,
+        ref_audio_sr,
+        vits.hps.data.filter_length,
+        vits.hps.data.sampling_rate,
+        vits.hps.data.hop_length,
+        vits.hps.data.win_length,
+        center=False,
+    )
+
+    st = time.time()
+    et = time.time()
+
+    y_len, y, xy_pos, k_cache, v_cache = stream_t2s.pre_infer(prompts, ref_seq, text_seq, ref_bert, text_bert, top_k)
+    idx = 1
+    last_idx = 0
+    audios = []
+    raw_audios = []
+    last_audio_ret = None
+    offset_index = []
+    full_audios = []
+    print("y.shape:", y.shape)
+    cut_id = 0
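+    # Streaming scheme (a reading of the constants below, which look empirical):
+    # each semantic token decodes to 1280 samples at 32 kHz (40 ms). Every partial
+    # decode holds back the last 8 tokens' worth of audio (8*1280 samples),
+    # presumably because the vocoder tail is unstable, and keeps a 256-sample
+    # fingerprint (last_audio_ret) from just before the held-back region; the next
+    # chunk is then spliced at the offset where that fingerprint best matches.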
+    while True:
+        y, xy_pos, last_token, k_cache, v_cache = stream_t2s(idx, top_k, y_len, y, xy_pos, k_cache, v_cache)
+        # print("y.shape:", y.shape)
+        stop = last_token == t2s.EOS
+        print('idx:', idx, 'y.shape:', y.shape, y.shape[1] - idx)
+
+        if last_token < 50 and idx - last_idx > (len(audios) + 1) * 25 and idx > cut_id:
+            cut_id = idx + 7
+            print('trigger:', idx, last_idx, y[:, -idx + last_idx:], y[:, -idx + last_idx:].shape)
+            # y = torch.cat([y, y[:,-1:]], dim=1)
+            # idx += 1
+
+        if stop:
+            idx -= 1
+            print('stop')
+            print(idx, y[:, -idx + last_idx:])
+            print(idx, last_idx, y.shape)
+            print(y[:, -idx:-idx + 20])
+
+        # this part is still black magic; hard to explain cleanly
+        if idx == cut_id or stop:
+            print(f"idx: {idx}, last_idx: {last_idx}, cut_id: {cut_id}, stop: {stop}")
+            audio = vits.vq_model(y[:, -idx:].unsqueeze(0), text_seq, refer, speed=1.0, sv_emb=sv_emb)[0, 0]
+            full_audios.append(audio)
+            if last_idx == 0:
+                last_audio_ret = audio[-1280 * 8:-1280 * 8 + 256]
+                audio = audio[:-1280 * 8]
+                raw_audios.append(audio)
+                et = time.time()
+            else:
+                if stop:
+                    audio_ = audio[last_idx * 1280 - 1280 * 8:]
+                    raw_audios.append(audio_)
+                    i, x = find_best_audio_offset_fast(last_audio_ret, audio_[:1280])
+                    offset_index.append(i)
+                    audio = audio_[i:]
+                else:
+                    audio_ = audio[last_idx * 1280 - 1280 * 8:-1280 * 8]
+                    raw_audios.append(audio_)
+                    i, x = find_best_audio_offset_fast(last_audio_ret, audio_[:1280])
+                    offset_index.append(i)
+                    last_audio_ret = audio[-1280 * 8:-1280 * 8 + 256]
+                    audio = audio_[i:]
+            last_idx = idx
+            # print(f'write {output_path}/out_{audio_index}')
+            # soundfile.write(f"{output_path}/out_{audio_index}.wav", audio.float().detach().cpu().numpy(), 32000)
+            audios.append(audio)
+        # print(idx, '/', 1500, y.shape, y[0, -1].item(), stop)
+        if idx > 1500:
+            break
+
+        if stop:
+            break
+
+        idx += 1
+
+    at = time.time()
+
+    for i, a in enumerate(audios):
+        print(f'write {output_path}/out_{i}')
+        soundfile.write(f"{output_path}/out_{i}.wav", a.float().detach().cpu().numpy(), 32000)
+
+    print(f"first token: {et - st:.4f} seconds")
+    print(f"all tokens: {at - st:.4f} seconds")
+    audio = vits.vq_model(y[:, -idx:].unsqueeze(0), text_seq, refer, speed=1.0, sv_emb=sv_emb)[0, 0]
+    soundfile.write(f"{output_path}/out_final.wav", audio.float().detach().cpu().numpy(), 32000)
+    audio = torch.cat(audios, dim=0)
+    soundfile.write(f"{output_path}/out.wav", audio.float().detach().cpu().numpy(), 32000)
+    audio_raw = torch.cat(raw_audios, dim=0)
+    soundfile.write(f"{output_path}/out.raw.wav", audio_raw.float().detach().cpu().numpy(), 32000)
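+    # Debug plot: each partial decode is drawn with a vertical offset and a dashed
+    # line at its held-back boundary; the stitched result (black) and the
+    # unstitched concatenation (cyan) are drawn below for comparison.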
+
+    colors = ['red', 'green', 'blue', 'orange', 'purple', 'cyan', 'magenta', 'yellow']
+
+    max_duration = full_audios[-1].shape[0]
+    plt.xlim(0, max_duration)
+
+    last_line = 0
+
+    for i, a in enumerate(full_audios):
+        plt.plot((a + 2.0 * i).float().detach().cpu().numpy(), color=colors[i % len(colors)], alpha=0.5, label=f"Audio {i}")
+        # plt.axvline(x=last_line, color=colors[i], linestyle='--')
+        last_line = a.shape[0] - 8 * 1280
+        plt.axvline(x=last_line, color=colors[i % len(colors)], linestyle='--')
+
+    plt.plot((audio - 2.0).float().detach().cpu().numpy(), color='black', label='Final Audio')
+
+    plt.plot((audio_raw - 4.0).float().detach().cpu().numpy(), color='cyan', label='Raw Audio')
+
+    print("offset_index:", offset_index)
+    plt.show()
+
+
+def export_prov2(
+    gpt_path,
+    vits_path,
+    version,
+    ref_audio_path,
+    ref_text,
+    output_path,
+    device="cpu",
+    is_half=True,
+):
+    if export_torch_script.sv_cn_model is None:
+        init_sv_cn(device, is_half)
+
+    ref_audio = torch.tensor([load_audio(ref_audio_path, 16000)]).float()
+    ssl = SSLModel()
+
+    print(f"device: {device}")
+
+    ref_seq_id, ref_bert_T, ref_norm_text = get_phones_and_bert(
+        ref_text, "all_zh", "v2"
+    )
+    ref_seq = torch.LongTensor([ref_seq_id]).to(device)
+    ref_bert = ref_bert_T.T
+    if is_half:
+        ref_bert = ref_bert.half()
+    ref_bert = ref_bert.to(ref_seq.device)
+
+    text_seq_id, text_bert_T, norm_text = get_phones_and_bert(
+        "这是一个简单的示例,真没想到这么简单就完成了.The King and His Stories.Once there was a king.He likes to write stories, but his stories were not good.", "auto", "v2"
+    )
+    text_seq = torch.LongTensor([text_seq_id]).to(device)
+    text_bert = text_bert_T.T
+    if is_half:
+        text_bert = text_bert.half()
+    text_bert = text_bert.to(text_seq.device)
+
+    ssl_content = ssl(ref_audio)
+    if is_half:
+        ssl_content = ssl_content.half()
+    ssl_content = ssl_content.to(device)
+
+    sv_model = ExportERes2NetV2(export_torch_script.sv_cn_model)
+
+    # vits_path = "SoVITS_weights_v2/xw_e8_s216.pth"
+    vits = VitsModel(vits_path, version, is_half=is_half, device=device)
+    vits.eval()
+    vits = StepVitsModel(vits, sv_model)
+
+    # gpt_path = "GPT_weights_v2/xw-e15.ckpt"
+    # dict_s1 = torch.load(gpt_path, map_location=device)
+    dict_s1 = torch.load(gpt_path, weights_only=False)
+    raw_t2s = get_raw_t2s_model(dict_s1).to(device)
+    print("#### get_raw_t2s_model ####")
+    print(raw_t2s.config)
+    if is_half:
+        raw_t2s = raw_t2s.half()
+    t2s_m = T2SModel(raw_t2s)
+    t2s_m.eval()
+    # t2s = torch.jit.script(t2s_m).to(device)
+    t2s = t2s_m
+    print("#### script t2s_m ####")
+
+    print("vits.hps.data.sampling_rate:", vits.hps.data.sampling_rate)
+
+    stream_t2s = StreamT2SModel(t2s).to(device)
+    stream_t2s = torch.jit.script(stream_t2s)
+
+    ref_audio_sr = resamplex(ref_audio, 16000, 32000)
+    if is_half:
+        ref_audio_sr = ref_audio_sr.half()
+    ref_audio_sr = ref_audio_sr.to(device)
+
+    top_k = 15
+
+    prompts = vits.extract_latent(ssl_content)
+
+    audio_16k = resamplex(ref_audio_sr, 32000, 16000).to(ref_audio_sr.dtype)
+    sv_emb = sv_model(audio_16k)
+    print("text_seq", text_seq.shape)
+    # torch.jit.trace()
+
+    refer, sv_emb = vits.ref_handle(ref_audio_sr)
+
+    st = time.time()
+    et = time.time()
+
+    y_len, y, xy_pos, k_cache, v_cache = stream_t2s.pre_infer(prompts, ref_seq, text_seq, ref_bert, text_bert, top_k)
+    idx = 1
+    print("y.shape:", y.shape)
+    while True:
+        y, xy_pos, last_token, k_cache, v_cache = stream_t2s(idx, top_k, y_len, y, xy_pos, k_cache, v_cache)
+        # print("y.shape:", y.shape)
+
+        idx += 1
+        # print(idx, '/', 1500, y.shape, y[0, -1].item(), stop)
+        if idx > 1500:
+            break
+
+        if last_token == t2s.EOS:
+            break
+
+    at = time.time()
+    print("EOS:", t2s.EOS)
+
+    print(f"first token: {et - st:.4f} seconds")
+    print(f"all tokens: {at - st:.4f} seconds")
+    print("sv_emb", sv_emb.shape)
+    print("refer", refer.shape)
+    y = y[:, -idx:].unsqueeze(0)
+    print("y", y.shape)
+    audio = vits(y, text_seq, refer, sv_emb)
+    soundfile.write(f"{output_path}/out_final.wav", audio.float().detach().cpu().numpy(), 32000)
+
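+    # mark_dynamic flags these dimensions as variable-length so the export is not
+    # specialized to this example's shapes. (Note: this call is documented for
+    # torch.compile; how much it influences torch.jit.trace is uncertain.)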
+    torch._dynamo.mark_dynamic(ssl_content, 2)
+    torch._dynamo.mark_dynamic(ref_audio_sr, 1)
+    torch._dynamo.mark_dynamic(ref_seq, 1)
+    torch._dynamo.mark_dynamic(text_seq, 1)
+    torch._dynamo.mark_dynamic(ref_bert, 0)
+    torch._dynamo.mark_dynamic(text_bert, 0)
+    torch._dynamo.mark_dynamic(refer, 2)
+    torch._dynamo.mark_dynamic(y, 2)
+
+    inputs = {
+        "forward": (y, text_seq, refer, sv_emb),
+        "extract_latent": ssl_content,
+        "ref_handle": ref_audio_sr,
+    }
+
+    stream_t2s.save(f"{output_path}/t2s.pt")
+    torch.jit.trace_module(vits, inputs=inputs, optimize=True).save(f"{output_path}/vits.pt")
+    torch.jit.script(find_best_audio_offset_fast, optimize=True).save(f"{output_path}/find_best_audio_offset_fast.pt")
+
+
+import argparse
+import os
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="GPT-SoVITS Command Line Tool")
+    parser.add_argument("--gpt_model", required=True, help="Path to the GPT model file")
+    parser.add_argument(
+        "--sovits_model", required=True, help="Path to the SoVITS model file"
+    )
+    parser.add_argument(
+        "--ref_audio", required=True, help="Path to the reference audio file"
+    )
+    parser.add_argument(
+        "--ref_text", required=True, help="Transcript of the reference audio"
+    )
+    parser.add_argument(
+        "--output_path", required=True, help="Path to the output directory"
+    )
+    parser.add_argument("--device", help="Device to use", default="cuda" if torch.cuda.is_available() else "cpu")
+    parser.add_argument("--version", help="version of the model", default="v2Pro")
+    parser.add_argument("--no-half", action="store_true", help="Do not use half precision for model weights")
+
+    args = parser.parse_args()
+
+    if not os.path.exists(args.output_path):
+        os.makedirs(args.output_path)
+
+    is_half = not args.no_half
+    with torch.no_grad():
+        export_prov2(
+            gpt_path=args.gpt_model,
+            vits_path=args.sovits_model,
+            version=args.version,
+            ref_audio_path=args.ref_audio,
+            ref_text=args.ref_text,
+            output_path=args.output_path,
+            device=args.device,
+            is_half=is_half,
+        )
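+# Example invocation (model paths are placeholders; run from the GPT_SoVITS
+# directory so the local imports resolve):
+#   python stream_v2pro.py \
+#       --gpt_model GPT_weights_v2/your_model.ckpt \
+#       --sovits_model SoVITS_weights_v2/your_model.pth \
+#       --ref_audio ref.wav --ref_text "<transcript of ref.wav>" \
+#       --output_path output --version v2Pro
+# This writes out_final.wav plus the exported t2s.pt, vits.pt and
+# find_best_audio_offset_fast.pt into --output_path.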