diff --git a/GPT_SoVITS/AR/models/t2s_model.py b/GPT_SoVITS/AR/models/t2s_model.py
index c8ad3d82..cecfcbc1 100644
--- a/GPT_SoVITS/AR/models/t2s_model.py
+++ b/GPT_SoVITS/AR/models/t2s_model.py
@@ -1,8 +1,11 @@
 # modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_model.py
 # reference: https://github.com/lifeiteng/vall-e
 import torch
-from tqdm import tqdm
+import random
+import numpy as np
+from tqdm import tqdm
+from typing import List

 from AR.models.utils import make_pad_mask
 from AR.models.utils import (
     topk_sampling,
@@ -10,7 +13,7 @@ from AR.models.utils import (
     logits_to_probs,
     multinomial_sample_one_no_sync,
     dpo_loss,
-    make_reject_y,
+    make_reject_y, get_batch_logps
 )
 from AR.modules.embedding import SinePositionalEmbedding
@@ -35,6 +38,139 @@ default_config = {
 }


+
+@torch.jit.script
+class T2SMLP:
+    def __init__(self, w1, b1, w2, b2):
+        self.w1 = w1
+        self.b1 = b1
+        self.w2 = w2
+        self.b2 = b2
+
+    def forward(self, x):
+        x = F.relu(F.linear(x, self.w1, self.b1))
+        x = F.linear(x, self.w2, self.b2)
+        return x
+
+
+@torch.jit.script
+class T2SBlock:
+    def __init__(
+        self,
+        num_heads,
+        hidden_dim: int,
+        mlp: T2SMLP,
+        qkv_w,
+        qkv_b,
+        out_w,
+        out_b,
+        norm_w1,
+        norm_b1,
+        norm_eps1,
+        norm_w2,
+        norm_b2,
+        norm_eps2,
+    ):
+        self.num_heads = num_heads
+        self.mlp = mlp
+        self.hidden_dim: int = hidden_dim
+        self.qkv_w = qkv_w
+        self.qkv_b = qkv_b
+        self.out_w = out_w
+        self.out_b = out_b
+        self.norm_w1 = norm_w1
+        self.norm_b1 = norm_b1
+        self.norm_eps1 = norm_eps1
+        self.norm_w2 = norm_w2
+        self.norm_b2 = norm_b2
+        self.norm_eps2 = norm_eps2
+
+    def process_prompt(self, x, attn_mask: torch.Tensor):
+        q, k, v = F.linear(x, self.qkv_w, self.qkv_b).chunk(3, dim=-1)
+
+        batch_size = q.shape[0]
+        q_len = q.shape[1]
+        kv_len = k.shape[1]
+
+        k_cache = k
+        v_cache = v
+
+        q = q.view(batch_size, q_len, self.num_heads, -1).transpose(1, 2)
+        k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
+        v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
+
+        attn = F.scaled_dot_product_attention(q, k, v, ~attn_mask)
+
+        attn = attn.permute(2, 0, 1, 3).reshape(batch_size, -1, self.hidden_dim)
+        attn = F.linear(attn, self.out_w, self.out_b)
+
+        x = F.layer_norm(
+            x + attn, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
+        )
+        x = F.layer_norm(
+            x + self.mlp.forward(x),
+            [self.hidden_dim],
+            self.norm_w2,
+            self.norm_b2,
+            self.norm_eps2,
+        )
+        return x, k_cache, v_cache
+
+    def decode_next_token(self, x, k_cache, v_cache):
+        q, k, v = F.linear(x, self.qkv_w, self.qkv_b).chunk(3, dim=-1)
+
+        k_cache = torch.cat([k_cache, k], dim=1)
+        v_cache = torch.cat([v_cache, v], dim=1)
+        kv_len = k_cache.shape[1]
+
+        batch_size = q.shape[0]
+        q_len = q.shape[1]
+
+        q = q.view(batch_size, q_len, self.num_heads, -1).transpose(1, 2)
+        k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
+        v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
+
+        attn = F.scaled_dot_product_attention(q, k, v)
+
+        attn = attn.permute(2, 0, 1, 3).reshape(batch_size, -1, self.hidden_dim)
+        attn = F.linear(attn, self.out_w, self.out_b)
+
+        x = F.layer_norm(
+            x + attn, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
+        )
+        x = F.layer_norm(
+            x + self.mlp.forward(x),
+            [self.hidden_dim],
+            self.norm_w2,
+            self.norm_b2,
+            self.norm_eps2,
+        )
+        return x, k_cache, v_cache
+
+
+@torch.jit.script
+class T2STransformer:
+    def __init__(self, num_blocks: int, blocks: List[T2SBlock]):
+        self.num_blocks: int = num_blocks
+        self.blocks = blocks
+
+    def process_prompt(
+        self, x, attn_mask: torch.Tensor):
+        k_cache: List[torch.Tensor] = []
+        v_cache: List[torch.Tensor] = []
+        for i in range(self.num_blocks):
+            x, k_cache_, v_cache_ = self.blocks[i].process_prompt(x, attn_mask)
+            k_cache.append(k_cache_)
+            v_cache.append(v_cache_)
+        return x, k_cache, v_cache
+
+    def decode_next_token(
+        self, x, k_cache: List[torch.Tensor], v_cache: List[torch.Tensor]
+    ):
+        for i in range(self.num_blocks):
+            x, k_cache[i], v_cache[i] = self.blocks[i].decode_next_token(x, k_cache[i], v_cache[i])
+        return x, k_cache, v_cache
+
+
 class Text2SemanticDecoder(nn.Module):
     def __init__(self, config, norm_first=False, top_k=3):
         super(Text2SemanticDecoder, self).__init__()
@@ -89,6 +225,37 @@ class Text2SemanticDecoder(nn.Module):
             ignore_index=self.EOS,
         )

+        blocks = []
+
+        for i in range(self.num_layers):
+            layer = self.h.layers[i]
+            t2smlp = T2SMLP(
+                layer.linear1.weight,
+                layer.linear1.bias,
+                layer.linear2.weight,
+                layer.linear2.bias
+            )
+            # (layer.self_attn.in_proj_weight, layer.self_attn.in_proj_bias)
+            block = T2SBlock(
+                self.num_head,
+                self.model_dim,
+                t2smlp,
+                layer.self_attn.in_proj_weight,
+                layer.self_attn.in_proj_bias,
+                layer.self_attn.out_proj.weight,
+                layer.self_attn.out_proj.bias,
+                layer.norm1.weight,
+                layer.norm1.bias,
+                layer.norm1.eps,
+                layer.norm2.weight,
+                layer.norm2.bias,
+                layer.norm2.eps
+            )
+
+            blocks.append(block)
+
+        self.t2s_transformer = T2STransformer(self.num_layers, blocks)
+
     def make_input_data(self, x, x_lens, y, y_lens, bert_feature):
         x = self.ar_text_embedding(x)
         x = x + self.bert_proj(bert_feature.transpose(1, 2))
@@ -116,7 +283,7 @@ class Text2SemanticDecoder(nn.Module):
             (0, y_len),
             value=True,
         )
-        
+
         y_attn_mask = F.pad(
             torch.triu(
                 torch.ones(y_len, y_len, dtype=torch.bool, device=x.device),
@@ -177,7 +344,7 @@ class Text2SemanticDecoder(nn.Module):

         A_logits, R_logits = get_batch_logps(logits, reject_logits, targets, reject_targets)
         loss_2, _, _ = dpo_loss(A_logits, R_logits, 0, 0, 0.2, reference_free=True)
-        
+
         loss = loss_1 + loss_2

         return loss, acc
@@ -246,14 +413,14 @@ class Text2SemanticDecoder(nn.Module):

     # 需要看下这个函数和 forward 的区别以及没有 semantic 的时候 prompts 输入什么
     def infer(
-            self,
-            x,
-            x_lens,
-            prompts,
-            bert_feature,
-            top_k: int = -100,
-            early_stop_num: int = -1,
-            temperature: float = 1.0,
+        self,
+        x,
+        x_lens,
+        prompts,
+        bert_feature,
+        top_k: int = -100,
+        early_stop_num: int = -1,
+        temperature: float = 1.0,
     ):
         x = self.ar_text_embedding(x)
         x = x + self.bert_proj(bert_feature.transpose(1, 2))
@@ -322,15 +489,15 @@ class Text2SemanticDecoder(nn.Module):
         return targets[:, :-1], targets[:, 1:]

     def infer_panel(
-            self,
-            x, #####全部文本token
-            x_lens,
-            prompts, ####参考音频token
-            bert_feature,
-            top_k: int = -100,
-            top_p: int = 100,
-            early_stop_num: int = -1,
-            temperature: float = 1.0,
+        self,
+        x, #####全部文本token
+        x_lens,
+        prompts, ####参考音频token
+        bert_feature,
+        top_k: int = -100,
+        top_p: int = 100,
+        early_stop_num: int = -1,
+        temperature: float = 1.0,
     ):
         x = self.ar_text_embedding(x)
         x = x + self.bert_proj(bert_feature.transpose(1, 2))
@@ -338,22 +505,14 @@ class Text2SemanticDecoder(nn.Module):

         # AR Decoder
         y = prompts
-        
+
         x_len = x.shape[1]
         x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
         stop = False
         # print(1111111,self.num_layers)
-        cache = {
-            "all_stage": self.num_layers,
-            "k": [None] * self.num_layers, ###根据配置自己手写
-            "v": [None] * self.num_layers,
-            # "xy_pos":None,##y_pos位置编码每次都不一样的没法缓存,每次都要重新拼xy_pos.主要还是写法原因,其实是可以历史统一一样的,但也没啥计算量就不管了
-            "y_emb": None, ##只需要对最新的samples求emb,再拼历史的就行
-            # "logits":None,###原版就已经只对结尾求再拼接了,不用管
-            # "xy_dec":None,###不需要,本来只需要最后一个做logits
-            "first_infer": 1,
-            "stage": 0,
-        }
+
+        k_cache = None
+        v_cache = None
         ################### first step ##########################
         if y is not None:
             y_emb = self.ar_audio_embedding(y)
@@ -361,7 +520,6 @@ class Text2SemanticDecoder(nn.Module):
             prefix_len = y.shape[1]
             y_pos = self.ar_audio_position(y_emb)
             xy_pos = torch.concat([x, y_pos], dim=1)
-            cache["y_emb"] = y_emb
             ref_free = False
         else:
             y_emb = None
@@ -370,13 +528,14 @@ class Text2SemanticDecoder(nn.Module):
             y_pos = None
             xy_pos = x
             y = torch.zeros(x.shape[0], 0, dtype=torch.int, device=x.device)
+            prompts = y
             ref_free = True

         x_attn_mask_pad = F.pad(
-                x_attn_mask,
-                (0, y_len), ###xx的纯0扩展到xx纯0+xy纯1,(x,x+y)
-                value=True,
-            )
+            x_attn_mask,
+            (0, y_len), ###xx的纯0扩展到xx纯0+xy纯1,(x,x+y)
+            value=True,
+        )
         y_attn_mask = F.pad( ###yy的右上1扩展到左边xy的0,(y,x+y)
             torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
             (x_len, 0),
@@ -385,64 +544,43 @@ class Text2SemanticDecoder(nn.Module):
         xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(
             x.device
         )
-
         for idx in tqdm(range(1500)):
-
-            xy_dec, _ = self.h((xy_pos, None), mask=xy_attn_mask, cache=cache)
+            if xy_attn_mask is not None:
+                xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, xy_attn_mask)
+            else:
+                xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache)
+
             logits = self.ar_predict_layer(
                 xy_dec[:, -1]
-            ) ##不用改,如果用了cache的默认就是只有一帧,取最后一帧一样的
-            # samples = topk_sampling(logits, top_k=top_k, top_p=1.0, temperature=temperature)
-            if(idx==0):###第一次跑不能EOS否则没有了
-                logits = logits[:, :-1] ###刨除1024终止符号的概率
+            )
+
+            if idx == 0:
+                xy_attn_mask = None
+                logits = logits[:, :-1]
             samples = sample(
                 logits[0], y, top_k=top_k, top_p=top_p, repetition_penalty=1.35, temperature=temperature
             )[0].unsqueeze(0)
-            # 本次生成的 semantic_ids 和之前的 y 构成新的 y
-            # print(samples.shape)#[1,1]#第一个1是bs
-            y = torch.concat([y, samples], dim=1)
+
+            y = torch.concat([y, samples], dim=1)

             if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
                 print("use early stop num:", early_stop_num)
                 stop = True

             if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
-                # print(torch.argmax(logits, dim=-1)[0] == self.EOS, samples[0, 0] == self.EOS)
                 stop = True
             if stop:
-                # if prompts.shape[1] == y.shape[1]:
-                #     y = torch.concat([y, torch.zeros_like(samples)], dim=1)
-                #     print("bad zero prediction")
-                if y.shape[1]==0:
+                if y.shape[1] == 0:
                     y = torch.concat([y, torch.zeros_like(samples)], dim=1)
                     print("bad zero prediction")
                 print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
                 break
-
-            ####################### update next step ###################################
-            cache["first_infer"] = 0
-            if cache["y_emb"] is not None:
-                y_emb = torch.cat(
-                    [cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], dim = 1
-                )
-                cache["y_emb"] = y_emb
-                y_pos = self.ar_audio_position(y_emb)
-                xy_pos = y_pos[:, -1:]
-            else:
-                y_emb = self.ar_audio_embedding(y[:, -1:])
-                cache["y_emb"] = y_emb
-                y_pos = self.ar_audio_position(y_emb)
-                xy_pos = y_pos
-            y_len = y_pos.shape[1]
-            ###最右边一列(是错的)
-            # xy_attn_mask=torch.ones((1, x_len+y_len), dtype=torch.bool,device=xy_pos.device)
-            # xy_attn_mask[:,-1]=False
-            ###最下面一行(是对的)
-            xy_attn_mask = torch.zeros(
-                (1, x_len + y_len), dtype=torch.bool, device=xy_pos.device
-            )
+
+            ####################### update next step ###################################
+            y_emb = self.ar_audio_embedding(y[:, -1:])
+            xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[:, y_len + idx].to(dtype=y_emb.dtype,device=y_emb.device)
+
         if ref_free:
             return y[:, :-1], 0
-        return y[:, :-1], idx-1
+        return y[:, :-1], idx - 1
\ No newline at end of file
diff --git a/GPT_SoVITS/inference_webui.py b/GPT_SoVITS/inference_webui.py
index dc061ffb..4b694319 100644
--- a/GPT_SoVITS/inference_webui.py
+++ b/GPT_SoVITS/inference_webui.py
@@ -65,7 +65,7 @@ from text import cleaned_text_to_sequence
 from text.cleaner import clean_text
 from time import time as ttime
 from module.mel_processing import spectrogram_torch
-from my_utils import load_audio
+from tools.my_utils import load_audio
 from tools.i18n.i18n import I18nAuto

 i18n = I18nAuto()
@@ -331,27 +331,29 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
         int(hps.data.sampling_rate * 0.3),
         dtype=np.float16 if is_half == True else np.float32,
     )
-    with torch.no_grad():
-        wav16k, sr = librosa.load(ref_wav_path, sr=16000)
-        if (wav16k.shape[0] > 160000 or wav16k.shape[0] < 48000):
-            raise OSError(i18n("参考音频在3~10秒范围外,请更换!"))
-        wav16k = torch.from_numpy(wav16k)
-        zero_wav_torch = torch.from_numpy(zero_wav)
-        if is_half == True:
-            wav16k = wav16k.half().to(device)
-            zero_wav_torch = zero_wav_torch.half().to(device)
-        else:
-            wav16k = wav16k.to(device)
-            zero_wav_torch = zero_wav_torch.to(device)
-        wav16k = torch.cat([wav16k, zero_wav_torch])
-        ssl_content = ssl_model.model(wav16k.unsqueeze(0))[
-            "last_hidden_state"
-        ].transpose(
-            1, 2
-        )  # .float()
-        codes = vq_model.extract_latent(ssl_content)
-
-        prompt_semantic = codes[0, 0]
+    if not ref_free:
+        with torch.no_grad():
+            wav16k, sr = librosa.load(ref_wav_path, sr=16000)
+            if (wav16k.shape[0] > 160000 or wav16k.shape[0] < 48000):
+                raise OSError(i18n("参考音频在3~10秒范围外,请更换!"))
+            wav16k = torch.from_numpy(wav16k)
+            zero_wav_torch = torch.from_numpy(zero_wav)
+            if is_half == True:
+                wav16k = wav16k.half().to(device)
+                zero_wav_torch = zero_wav_torch.half().to(device)
+            else:
+                wav16k = wav16k.to(device)
+                zero_wav_torch = zero_wav_torch.to(device)
+            wav16k = torch.cat([wav16k, zero_wav_torch])
+            ssl_content = ssl_model.model(wav16k.unsqueeze(0))[
+                "last_hidden_state"
+            ].transpose(
+                1, 2
+            )  # .float()
+            codes = vq_model.extract_latent(ssl_content)
+            prompt_semantic = codes[0, 0]
+            prompt = prompt_semantic.unsqueeze(0).to(device)
+
     t1 = ttime()

     if (how_to_cut == i18n("凑四句一切")):
@@ -391,7 +393,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
         bert = bert.to(device).unsqueeze(0)
         all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
-        prompt = prompt_semantic.unsqueeze(0).to(device)
+
         t2 = ttime()
         with torch.no_grad():
             # pred_semantic = t2s_model.model.infer(
@@ -510,16 +512,26 @@ def cut4(inp):

 # contributed by https://github.com/AI-Hobbyist/GPT-SoVITS/blob/main/GPT_SoVITS/inference_webui.py
 def cut5(inp):
-    # if not re.search(r'[^\w\s]', inp[-1]):
-    #     inp += '。'
     inp = inp.strip("\n")
-    punds = r'[,.;?!、,。?!;:…]'
-    items = re.split(f'({punds})', inp)
-    mergeitems = ["".join(group) for group in zip(items[::2], items[1::2])]
-    # 在句子不存在符号或句尾无符号的时候保证文本完整
-    if len(items)%2 == 1:
-        mergeitems.append(items[-1])
-    opt = [item for item in mergeitems if not set(item).issubset(punctuation)]
+    punds = {',', '.', ';', '?', '!', '、', ',', '。', '?', '!', ';', ':', '…'}
+    mergeitems = []
+    items = []
+
+    for i, char in enumerate(inp):
+        if char in punds:
+            if char == '.' and i > 0 and i < len(inp) - 1 and inp[i - 1].isdigit() and inp[i + 1].isdigit():
+                items.append(char)
+            else:
+                items.append(char)
+                mergeitems.append("".join(items))
+                items = []
+        else:
+            items.append(char)
+
+    if items:
+        mergeitems.append("".join(items))
+
+    opt = [item for item in mergeitems if not set(item).issubset(punds)]
     return "\n".join(opt)
diff --git a/GPT_SoVITS/module/data_utils.py b/GPT_SoVITS/module/data_utils.py
index ff4c4f43..72c80555 100644
--- a/GPT_SoVITS/module/data_utils.py
+++ b/GPT_SoVITS/module/data_utils.py
@@ -17,7 +17,7 @@ from functools import lru_cache
 import requests
 from scipy.io import wavfile
 from io import BytesIO
-from my_utils import load_audio
+from tools.my_utils import load_audio

 # ZeroDivisionError fixed by Tybost (https://github.com/RVC-Boss/GPT-SoVITS/issues/79)
 class TextAudioSpeakerLoader(torch.utils.data.Dataset):
diff --git a/GPT_SoVITS/my_utils.py b/GPT_SoVITS/my_utils.py
deleted file mode 100644
index 776939dd..00000000
--- a/GPT_SoVITS/my_utils.py
+++ /dev/null
@@ -1,21 +0,0 @@
-import ffmpeg
-import numpy as np
-
-
-def load_audio(file, sr):
-    try:
-        # https://github.com/openai/whisper/blob/main/whisper/audio.py#L26
-        # This launches a subprocess to decode audio while down-mixing and resampling as necessary.
-        # Requires the ffmpeg CLI and `ffmpeg-python` package to be installed.
-        file = (
-            file.strip(" ").strip('"').strip("\n").strip('"').strip(" ")
-        )  # 防止小白拷路径头尾带了空格和"和回车
-        out, _ = (
-            ffmpeg.input(file, threads=0)
-            .output("-", format="f32le", acodec="pcm_f32le", ac=1, ar=sr)
-            .run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True)
-        )
-    except Exception as e:
-        raise RuntimeError(f"Failed to load audio: {e}")
-
-    return np.frombuffer(out, np.float32).flatten()
diff --git a/GPT_SoVITS/onnx_export.py b/GPT_SoVITS/onnx_export.py
index b82e987f..ab457d75 100644
--- a/GPT_SoVITS/onnx_export.py
+++ b/GPT_SoVITS/onnx_export.py
@@ -9,7 +9,7 @@ cnhubert.cnhubert_base_path=cnhubert_base_path
 ssl_model = cnhubert.get_model()
 from text import cleaned_text_to_sequence
 import soundfile
-from my_utils import load_audio
+from tools.my_utils import load_audio
 import os
 import json

diff --git a/GPT_SoVITS/prepare_datasets/2-get-hubert-wav32k.py b/GPT_SoVITS/prepare_datasets/2-get-hubert-wav32k.py
index 61c933a4..17394ee4 100644
--- a/GPT_SoVITS/prepare_datasets/2-get-hubert-wav32k.py
+++ b/GPT_SoVITS/prepare_datasets/2-get-hubert-wav32k.py
@@ -17,7 +17,7 @@ from scipy.io import wavfile
 import librosa,torch
 now_dir = os.getcwd()
 sys.path.append(now_dir)
-from my_utils import load_audio
+from tools.my_utils import load_audio

 # from config import cnhubert_base_path
 # cnhubert.cnhubert_base_path=cnhubert_base_path
diff --git a/GPT_SoVITS/s1_train.py b/GPT_SoVITS/s1_train.py
index 43cfa19a..ece295d3 100644
--- a/GPT_SoVITS/s1_train.py
+++ b/GPT_SoVITS/s1_train.py
@@ -79,15 +79,17 @@ class my_model_ckpt(ModelCheckpoint):
                     to_save_od["config"] = self.config
                     to_save_od["info"] = "GPT-e%s" % (trainer.current_epoch + 1)
                     # torch.save(
-                    my_save(
-                        to_save_od,
-                        "%s/%s-e%s.ckpt"
-                        % (
-                            self.half_weights_save_dir,
-                            self.exp_name,
-                            trainer.current_epoch + 1,
-                        ),
-                    )
+                    # print(os.environ)
+                    if(os.environ.get("LOCAL_RANK","0")=="0"):
+                        my_save(
+                            to_save_od,
+                            "%s/%s-e%s.ckpt"
+                            % (
+                                self.half_weights_save_dir,
+                                self.exp_name,
+                                trainer.current_epoch + 1,
+                            ),
+                        )
         self._save_last_checkpoint(trainer, monitor_candidates)

diff --git a/GPT_SoVITS/text/cleaner.py b/GPT_SoVITS/text/cleaner.py
index 0d82d850..c42264ab 100644
--- a/GPT_SoVITS/text/cleaner.py
+++ b/GPT_SoVITS/text/cleaner.py
@@ -22,6 +22,11 @@ def clean_text(text, language):
         phones, word2ph = language_module.g2p(norm_text)
         assert len(phones) == sum(word2ph)
         assert len(norm_text) == len(word2ph)
+    elif language == "en":
+        phones = language_module.g2p(norm_text)
+        if len(phones) < 4:
+            phones = [','] * (4 - len(phones)) + phones
+        word2ph = None
     else:
         phones = language_module.g2p(norm_text)
         word2ph = None
diff --git a/GPT_SoVITS/text/engdict-hot.rep b/GPT_SoVITS/text/engdict-hot.rep
index 120e5ef6..b9268303 100644
--- a/GPT_SoVITS/text/engdict-hot.rep
+++ b/GPT_SoVITS/text/engdict-hot.rep
@@ -1,2 +1,3 @@
 CHATGPT CH AE1 T JH IY1 P IY1 T IY1
-JSON JH EY1 S AH0 N
\ No newline at end of file
+JSON JH EY1 S AH0 N
+CONDA K AA1 N D AH0
\ No newline at end of file
diff --git a/api.py b/api.py
index b5340715..aa822ca7 100644
--- a/api.py
+++ b/api.py
@@ -143,7 +143,7 @@ from AR.models.t2s_lightning_module import Text2SemanticLightningModule
 from text import cleaned_text_to_sequence
 from text.cleaner import clean_text
 from module.mel_processing import spectrogram_torch
-from my_utils import load_audio
+from tools.my_utils import load_audio
 import config as global_config
 import logging
 import subprocess
diff --git a/docs/cn/Changelog_CN.md b/docs/cn/Changelog_CN.md
index abd7263f..114761e8 100644
--- a/docs/cn/Changelog_CN.md
+++ b/docs/cn/Changelog_CN.md
@@ -183,13 +183,36 @@

 4-修复了webui的GPT中文微调没读到bert导致和推理不一致,训练太多可能效果还会变差的问题。如果大量数据微调的建议重新微调模型得到质量优化 [#99f09c8](https://github.com/RVC-Boss/GPT-SoVITS/commit/99f09c8bdc155c1f4272b511940717705509582a)

+### 20240706
+
+小问题修复:
+
+1-修正CPU推理默认bs小数 https://github.com/RVC-Boss/GPT-SoVITS/commit/db50670598f0236613eefa6f2d5a23a271d82041
+
+2-修复降噪、asr中途遇到异常跳出所有需处理的音频文件的问题 https://github.com/RVC-Boss/GPT-SoVITS/pull/1258 https://github.com/RVC-Boss/GPT-SoVITS/pull/1265 https://github.com/RVC-Boss/GPT-SoVITS/pull/1267
+
+3-修复按标点符号切分时小数会被切分 https://github.com/RVC-Boss/GPT-SoVITS/pull/1253
+
+4-多卡训练多进程保存逻辑修复
+
+https://github.com/RVC-Boss/GPT-SoVITS/commit/a208698e775155efc95b187b746d153d0f2847ca
+
+5-移除冗余my_utils https://github.com/RVC-Boss/GPT-SoVITS/pull/1251
+
+重点:
+
+6-倍速推理代码经过验证后推理效果和base完全一致,合并进main。使用的代码:https://github.com/RVC-Boss/GPT-SoVITS/pull/672。支持无参考文本模式也倍速。
+
+后面会逐渐验证快速推理分支的推理改动的一致性
+
+
 todolist:

 1-中文多音字推理优化(有没有人来测试的,欢迎把测试结果写在pr评论区里) https://github.com/RVC-Boss/GPT-SoVITS/pull/488 (v2底模训练已经合了,下个版本发布就要合了)

-2-正在尝试解决低音质参考音频导致音质较差的问题,v2再试试如果能解决就发了,节点暂定高考后吧
+2-正在尝试解决低音质参考音频导致音质较差的问题,v2再试试如果能解决就发了,节点暂定7月吧
diff --git a/tools/asr/fasterwhisper_asr.py b/tools/asr/fasterwhisper_asr.py
index 669ac3aa..e9fc6a47 100644
--- a/tools/asr/fasterwhisper_asr.py
+++ b/tools/asr/fasterwhisper_asr.py
@@ -78,7 +78,7 @@ def execute_asr(input_folder, output_folder, model_size, language, precision):
                 text += segment.text
             output.append(f"{file_path}|{output_file_name}|{info.language.upper()}|{text}")
         except:
-            return print(traceback.format_exc())
+            print(traceback.format_exc())

     output_folder = output_folder or "output/asr_opt"
     os.makedirs(output_folder, exist_ok=True)
diff --git a/tools/cmd-denoise.py b/tools/cmd-denoise.py
index 69b51e66..1fdcab6d 100644
--- a/tools/cmd-denoise.py
+++ b/tools/cmd-denoise.py
@@ -1,4 +1,5 @@
 import os,argparse
+import traceback

 from modelscope.pipelines import pipeline
 from modelscope.utils.constant import Tasks
@@ -12,7 +13,10 @@ def execute_denoise(input_folder,output_folder):
     # print(input_folder)
     # print(list(os.listdir(input_folder).sort()))
     for name in tqdm(os.listdir(input_folder)):
-        ans("%s/%s"%(input_folder,name),output_path='%s/%s'%(output_folder,name))
+        try:
+            ans("%s/%s"%(input_folder,name),output_path='%s/%s'%(output_folder,name))
+        except:
+            traceback.print_exc()

 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
diff --git a/tools/slice_audio.py b/tools/slice_audio.py
index 46ee408a..8a06292d 100644
--- a/tools/slice_audio.py
+++ b/tools/slice_audio.py
@@ -3,7 +3,7 @@ import traceback
 from scipy.io import wavfile
 # parent_directory = os.path.dirname(os.path.abspath(__file__))
 # sys.path.append(parent_directory)
-from my_utils import load_audio
+from tools.my_utils import load_audio
 from slicer2 import Slicer

 def slice(inp,opt_root,threshold,min_length,min_interval,hop_size,max_sil_kept,_max,alpha,i_part,all_part):