mirror of https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-10-07 15:19:59 +08:00
Add speaker feature injection and related handling
This commit is contained in:
parent 7303e108e1
commit 0341f13b92
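Note: the hunks below wire a per-speaker embedding through the whole pipeline: a Text2SemanticSpeakerDataset that assigns one speaker id per training sub-folder, a speaker_emb/speaker_proj pair in Text2SemanticDecoder whose output is added to the audio-token embeddings, a speaker_ids argument threaded through training and inference, and a speaker-id dropdown in the WebUI. The sketch below condenses the core injection step; it is illustrative only — the 100-entry, 4096-dim table and the linear projection match the diff, while the class name SpeakerInjection, the 512 decoder width, and the demo in __main__ are assumptions.

# Minimal, self-contained sketch (not the repository code) of the speaker-feature
# injection this commit introduces: a per-speaker vector is looked up from an
# nn.Embedding table, projected to the decoder width, and added to every
# audio-token embedding.
import torch
import torch.nn as nn


class SpeakerInjection(nn.Module):
    def __init__(self, num_speakers: int = 100, speaker_dim: int = 4096, embedding_dim: int = 512):
        super().__init__()
        self.speaker_emb = nn.Embedding(num_speakers, speaker_dim)  # per-speaker table
        self.speaker_proj = nn.Linear(speaker_dim, embedding_dim)   # project to decoder width

    def forward(self, y_emb: torch.Tensor, speaker_ids: torch.Tensor = None) -> torch.Tensor:
        # y_emb: (B, T, embedding_dim) audio-token embeddings; speaker_ids: (B,)
        if speaker_ids is None:
            return y_emb  # no injection when no speaker id is supplied
        spk = self.speaker_proj(self.speaker_emb(speaker_ids))      # (B, embedding_dim)
        return y_emb + spk.unsqueeze(1)                              # broadcast over time steps


if __name__ == "__main__":
    inj = SpeakerInjection()
    y_emb = torch.randn(2, 7, 512)
    out = inj(y_emb, speaker_ids=torch.tensor([0, 3]))
    print(out.shape)  # torch.Size([2, 7, 512])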
Changed files include .gitignore (vendored): 1 change
@@ -10,3 +10,4 @@ logs
 reference
 SoVITS_weights
 GPT_weights
+TEMP
@@ -1,7 +1,7 @@
 # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/data_module.py
 from pytorch_lightning import LightningDataModule
 from AR.data.bucket_sampler import DistributedBucketSampler
-from AR.data.dataset import Text2SemanticDataset
+from AR.data.dataset import Text2SemanticDataset, Text2SemanticSpeakerDataset
 from torch.utils.data import DataLoader
 
 
@@ -26,12 +26,20 @@ class Text2SemanticDataModule(LightningDataModule):
         pass
 
     def setup(self, stage=None, output_logs=False):
-        self._train_dataset = Text2SemanticDataset(
-            phoneme_path=self.train_phoneme_path,
-            semantic_path=self.train_semantic_path,
-            max_sec=self.config["data"]["max_sec"],
-            pad_val=self.config["data"]["pad_val"],
-        )
+        if 'train_dir' not in self.config:
+            self._train_dataset = Text2SemanticDataset(
+                phoneme_path=self.train_phoneme_path,
+                semantic_path=self.train_semantic_path,
+                max_sec=self.config["data"]["max_sec"],
+                pad_val=self.config["data"]["pad_val"],
+            )
+        else:
+            self._train_dataset = Text2SemanticSpeakerDataset(
+                self.config['speaker_dict_path'],
+                self.config['train_dir'],
+                max_sec=self.config["data"]["max_sec"],
+                pad_val=self.config["data"]["pad_val"],
+            )
         self._dev_dataset = self._train_dataset
         # self._dev_dataset = Text2SemanticDataset(
         #     phoneme_path=self.dev_phoneme_path,
@@ -293,6 +293,128 @@ class Text2SemanticDataset(Dataset):
             "bert_feature": bert_padded,
         }
 
+
+class Text2SemanticSpeakerDataset(Dataset):
+    def __init__(
+        self,
+        speaker_dict_path: str,
+        train_dir: str,
+        max_sample: int = None,
+        max_sec: int = 100,
+        pad_val: int = 1024,
+        # min value of phoneme/sec
+        min_ps_ratio: int = 3,
+        # max value of phoneme/sec
+        max_ps_ratio: int = 25,
+    ):
+        super().__init__()
+        self.speaker_dict = dict()
+        if os.path.exists(speaker_dict_path):
+            with open(speaker_dict_path, "r", encoding="utf-8") as f:
+                try:
+                    self.speaker_dict = json.load(f)
+                except:
+                    pass
+
+        self.speaker2dataset = {}
+        self.PAD = pad_val
+        for name in os.listdir(train_dir):
+            if name in ["logs_s1", os.path.basename(speaker_dict_path)]:
+                continue
+            opt_dir = os.path.join(train_dir, name)
+            speaker_id = self.speaker_dict.get(name, None)
+            if speaker_id is None:
+                speaker_id = len(self.speaker_dict)
+                self.speaker_dict[name] = speaker_id
+            self.speaker2dataset[speaker_id] = Text2SemanticDataset(
+                phoneme_path=os.path.join(opt_dir, "2-name2text.txt"),
+                semantic_path=os.path.join(opt_dir, "6-name2semantic.tsv"),
+                max_sample=max_sample,
+                max_sec=max_sec,
+                pad_val=pad_val,
+                min_ps_ratio=min_ps_ratio,
+                max_ps_ratio=max_ps_ratio,
+            )
+        with open(speaker_dict_path, "w", encoding="utf-8") as f:
+            json.dump(self.speaker_dict, f, ensure_ascii=False, indent=4)
+
+    def __len__(self):
+        result = 0
+        for _, dataset in self.speaker2dataset.items():
+            result += len(dataset)
+        return result
+
+    def __getitem__(self, idx):
+        for speaker_id, dataset in self.speaker2dataset.items():
+            if idx < len(dataset):
+                result = dataset[idx]
+                result["speaker_id"] = speaker_id
+                return result
+            else:
+                idx -= len(dataset)
+        raise IndexError("index out of range")
+
+    def get_sample_length(self, idx: int):
+        for speaker_id, dataset in self.speaker2dataset.items():
+            if idx < len(dataset):
+                semantic_ids = dataset.semantic_phoneme[idx][0]
+                sec = 1.0 * len(semantic_ids) / dataset.hz
+                return sec
+            else:
+                idx -= len(dataset)
+        raise IndexError("index out of range")
+
+    def collate(self, examples: List[Dict]) -> Dict:
+        sample_index: List[int] = []
+        phoneme_ids: List[torch.Tensor] = []
+        phoneme_ids_lens: List[int] = []
+        semantic_ids: List[torch.Tensor] = []
+        semantic_ids_lens: List[int] = []
+        speaker_ids: List[int] = []
+        # return
+
+        for item in examples:
+            sample_index.append(item["idx"])
+            phoneme_ids.append(np.array(item["phoneme_ids"], dtype=np.int64))
+            semantic_ids.append(np.array(item["semantic_ids"], dtype=np.int64))
+            phoneme_ids_lens.append(item["phoneme_ids_len"])
+            semantic_ids_lens.append(item["semantic_ids_len"])
+            speaker_ids.append(item["speaker_id"])
+
+        # pad 0
+        phoneme_ids = batch_sequences(phoneme_ids)
+        semantic_ids = batch_sequences(semantic_ids, pad_value=self.PAD)
+
+        # # convert each batch to torch.tensor
+        phoneme_ids = torch.tensor(phoneme_ids)
+        semantic_ids = torch.tensor(semantic_ids)
+        phoneme_ids_lens = torch.tensor(phoneme_ids_lens)
+        semantic_ids_lens = torch.tensor(semantic_ids_lens)
+        bert_padded = torch.FloatTensor(len(examples), 1024, max(phoneme_ids_lens))
+        bert_padded.zero_()
+        speaker_ids = torch.tensor(speaker_ids, dtype=torch.long)
+
+        for idx, item in enumerate(examples):
+            bert = item["bert_feature"]
+            if bert != None:
+                bert_padded[idx, :, : bert.shape[-1]] = bert
+
+        return {
+            # List[int]
+            "ids": sample_index,
+            # torch.Tensor (B, max_phoneme_length)
+            "phoneme_ids": phoneme_ids,
+            # torch.Tensor (B)
+            "phoneme_ids_len": phoneme_ids_lens,
+            # torch.Tensor (B, max_semantic_ids_length)
+            "semantic_ids": semantic_ids,
+            # torch.Tensor (B)
+            "semantic_ids_len": semantic_ids_lens,
+            # torch.Tensor (B, 1024, max_phoneme_length)
+            "bert_feature": bert_padded,
+            "speaker_ids": speaker_ids,
+        }
+
+
 if __name__ == "__main__":
     root_dir = "/data/docker/liujing04/gpt-vits/prepare/dump_mix/"
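For orientation, a hedged sketch of how the new dataset maps a training directory onto speaker ids, distilled from the __init__ above. Each sub-folder of train_dir except logs_s1 and the dict file itself becomes one speaker; ids are handed out in first-seen order and persisted to speaker_dict.json so they stay stable across runs. build_speaker_dict is an illustrative helper, not a repository function; each speaker folder is expected to hold the usual 2-name2text.txt / 6-name2semantic.tsv files, loaded through one Text2SemanticDataset per speaker.

# Illustrative helper (not in the repo): reproduce the speaker-id assignment used by
# Text2SemanticSpeakerDataset above.
import json
import os


def build_speaker_dict(train_dir: str, speaker_dict_path: str) -> dict:
    speaker_dict = {}
    if os.path.exists(speaker_dict_path):
        with open(speaker_dict_path, "r", encoding="utf-8") as f:
            speaker_dict = json.load(f)  # reuse previously assigned ids
    for name in os.listdir(train_dir):
        if name in ("logs_s1", os.path.basename(speaker_dict_path)):
            continue  # skip training logs and the dict file itself
        speaker_dict.setdefault(name, len(speaker_dict))  # next free id in first-seen order
    with open(speaker_dict_path, "w", encoding="utf-8") as f:
        json.dump(speaker_dict, f, ensure_ascii=False, indent=4)
    return speaker_dict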
@@ -34,10 +34,11 @@ class Text2SemanticLightningModule(LightningModule):
             self.eval_dir.mkdir(parents=True, exist_ok=True)
         for param in self.model.parameters():
             param.requires_grad = False
-        self.model.speaker_proj.weight.requires_grad = True
-        self.model.speaker_proj.bias.requires_grad = True
-        self.model.speaker_proj.train()
-        self.model.speaker_feat.requires_grad = True
+        self.model.speaker_proj.requires_grad_(True)
+        self.model.speaker_proj.weight.requires_grad_(True)
+        self.model.speaker_proj.bias.requires_grad_(True)
+        self.model.speaker_emb.requires_grad_(True)
+        self.model.speaker_emb.weight.requires_grad_(True)
 
     def training_step(self, batch: Dict, batch_idx: int):
         opt = self.optimizers()
@@ -48,13 +49,13 @@ class Text2SemanticLightningModule(LightningModule):
             batch["semantic_ids"],
             batch["semantic_ids_len"],
             batch["bert_feature"],
+            speaker_ids=batch["speaker_ids"] if "speaker_ids" in batch else None,
         )
         self.manual_backward(loss)
         if batch_idx > 0 and batch_idx % 4 == 0:
             opt.step()
             opt.zero_grad()
             scheduler.step()
-            torch.save(self.model.speaker_feat.data, "C:/Users/86150/Desktop/GPT-SoVITS/zyj.pt")
 
         self.log(
             "total_loss",
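The lightning-module hunk above implements speaker-only fine-tuning: freeze the whole pretrained model, then re-enable gradients for the speaker layers. A minimal sketch of that pattern, assuming a model that exposes speaker_proj and speaker_emb as in the patched Text2SemanticDecoder:

# Sketch of the freeze-all-but-speaker pattern used in the hunk above; only the
# speaker projection and the per-speaker embedding table receive gradient updates.
import torch.nn as nn


def freeze_all_but_speaker(model: nn.Module) -> None:
    for param in model.parameters():
        param.requires_grad = False            # freeze the pretrained decoder
    model.speaker_proj.requires_grad_(True)    # unfreeze the speaker projection ...
    model.speaker_emb.requires_grad_(True)     # ... and the per-speaker embedding table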
@@ -52,12 +52,8 @@ class Text2SemanticDecoder(nn.Module):
         # should be same as num of kmeans bin
         # assert self.EOS == 1024
         self.bert_proj = nn.Linear(1024, self.embedding_dim)
-        self.speaker_proj = nn.Linear(1024, self.embedding_dim)
-        self.path_speaker = "C:/Users/86150/Desktop/GPT-SoVITS/zyj.pt"
-        if not os.path.exists(self.path_speaker):
-            self.speaker_feat = nn.Parameter(torch.randn(1024) * 0.1)
-        else:
-            self.speaker_feat = nn.Parameter(torch.load(self.path_speaker, map_location="cpu"))
+        self.speaker_proj = nn.Linear(4096, self.embedding_dim)
+        self.speaker_emb = nn.Embedding(100, 4096)
         self.ar_text_embedding = TokenEmbedding(
             self.embedding_dim, self.phoneme_vocab_size, self.p_dropout
         )
@@ -95,7 +91,7 @@ class Text2SemanticDecoder(nn.Module):
             ignore_index=self.EOS,
         )
 
-    def make_input_data(self, x, x_lens, y, y_lens, bert_feature):
+    def make_input_data(self, x, x_lens, y, y_lens, bert_feature, speaker_ids=None):
         x = self.ar_text_embedding(x)
         x = x + self.bert_proj(bert_feature.transpose(1, 2))
         x = self.ar_text_position(x)
@@ -111,7 +107,8 @@ class Text2SemanticDecoder(nn.Module):
         x_len = x_lens.max()
         y_len = y_lens.max()
         y_emb = self.ar_audio_embedding(y)
-        y_emb = y_emb + self.speaker_proj(self.speaker_feat).view(1,1,-1)
+        if speaker_ids is not None:
+            y_emb = y_emb + self.speaker_proj(self.speaker_emb(speaker_ids)).unsqueeze(1)
         y_pos = self.ar_audio_position(y_emb)
 
         xy_padding_mask = torch.concat([x_mask, y_mask], dim=1)
@@ -149,7 +146,7 @@ class Text2SemanticDecoder(nn.Module):
 
         return xy_pos, xy_attn_mask, targets
 
-    def forward(self, x, x_lens, y, y_lens, bert_feature):
+    def forward(self, x, x_lens, y, y_lens, bert_feature, speaker_ids=None):
         """
         x: phoneme_ids
         y: semantic_ids
@@ -157,7 +154,7 @@ class Text2SemanticDecoder(nn.Module):
 
         reject_y, reject_y_lens = make_reject_y(y, y_lens)
 
-        xy_pos, xy_attn_mask, targets = self.make_input_data(x, x_lens, y, y_lens, bert_feature)
+        xy_pos, xy_attn_mask, targets = self.make_input_data(x, x_lens, y, y_lens, bert_feature, speaker_ids=speaker_ids)
 
         xy_dec, _ = self.h(
             (xy_pos, None),
@@ -167,7 +164,7 @@ class Text2SemanticDecoder(nn.Module):
         logits = self.ar_predict_layer(xy_dec[:, x_len:])
 
         ###### DPO #############
-        reject_xy_pos, reject_xy_attn_mask, reject_targets = self.make_input_data(x, x_lens, reject_y, reject_y_lens, bert_feature)
+        reject_xy_pos, reject_xy_attn_mask, reject_targets = self.make_input_data(x, x_lens, reject_y, reject_y_lens, bert_feature, speaker_ids=speaker_ids)
 
         reject_xy_dec, _ = self.h(
             (reject_xy_pos, None),
@@ -338,7 +335,7 @@ class Text2SemanticDecoder(nn.Module):
         top_p: int = 100,
         early_stop_num: int = -1,
         temperature: float = 1.0,
-        use_speaker_feat=True,
+        speaker_ids=None,
     ):
         x = self.ar_text_embedding(x)
         x = x + self.bert_proj(bert_feature.transpose(1, 2))
@@ -362,8 +359,8 @@ class Text2SemanticDecoder(nn.Module):
             "first_infer": 1,
             "stage": 0,
         }
-        if use_speaker_feat:
-            speaker_feat = self.speaker_proj(self.speaker_feat).view(1,1,-1)
+        if speaker_ids is not None:
+            speaker_feat = self.speaker_proj(self.speaker_emb(speaker_ids)).unsqueeze(1)
         else:
             speaker_feat = 0
         for idx in tqdm(range(1500)):
@@ -3,7 +3,7 @@ train:
   epochs: 20
   batch_size: 8
   save_every_n_epoch: 1
-  precision: 16-mixed
+  precision: "32"
   gradient_clip: 1.0
 optimizer:
   lr: 0.01
@@ -365,7 +365,10 @@ def merge_short_text_in_array(texts, threshold):
             result[len(result) - 1] += text
     return result
 
-def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, how_to_cut=i18n("不切"), top_k=20, top_p=0.6, temperature=0.6):
+def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, how_to_cut=i18n("不切"), top_k=20, top_p=0.6, temperature=0.6,speaker_id="-1"):
+    speaker_id = int(speaker_id)
+    if speaker_id == -1:
+        speaker_id = None
     t0 = ttime()
     prompt_language = dict_language[prompt_language]
     text_language = dict_language[text_language]
@@ -381,7 +384,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
     )
     with torch.no_grad():
         wav16k, sr = librosa.load(ref_wav_path, sr=16000)
-        if (wav16k.shape[0] > 160000 or wav16k.shape[0] < 48000):
+        if (wav16k.shape[0] > 160000 or wav16k.shape[0] < 0):
             raise OSError(i18n("参考音频在3~10秒范围外,请更换!"))
         wav16k = torch.from_numpy(wav16k)
         zero_wav_torch = torch.from_numpy(zero_wav)
@@ -448,6 +451,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
                 top_p=top_p,
                 temperature=temperature,
                 early_stop_num=hz * max_sec,
+                speaker_ids=torch.LongTensor([speaker_id]).to(device) if speaker_id is not None else None,
             )
         t3 = ttime()
         # print(pred_semantic.shape,idx)
@@ -601,6 +605,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
             GPT_dropdown = gr.Dropdown(label=i18n("GPT模型列表"), choices=sorted(GPT_names, key=custom_sort_key), value=gpt_path, interactive=True)
             SoVITS_dropdown = gr.Dropdown(label=i18n("SoVITS模型列表"), choices=sorted(SoVITS_names, key=custom_sort_key), value=sovits_path, interactive=True)
             refresh_button = gr.Button(i18n("刷新模型路径"), variant="primary")
+            speaker_id = gr.Dropdown(label="说话人idx", choices=[str(i) for i in range(-1, 100)], value="-1")
             refresh_button.click(fn=change_choices, inputs=[], outputs=[SoVITS_dropdown, GPT_dropdown])
         SoVITS_dropdown.change(change_sovits_weights, [SoVITS_dropdown], [])
         GPT_dropdown.change(change_gpt_weights, [GPT_dropdown], [])
@@ -632,7 +637,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
 
         inference_button.click(
             get_tts_wav,
-            [inp_ref, prompt_text, prompt_language, text, text_language, how_to_cut,top_k,top_p,temperature],
+            [inp_ref, prompt_text, prompt_language, text, text_language, how_to_cut,top_k,top_p,temperature,speaker_id],
             [output],
         )
 
@@ -98,7 +98,7 @@ for line in lines[int(i_part)::int(all_parts)]:
     try:
         # wav_name,text=line.split("\t")
         wav_name, spk_name, language, text = line.split("|")
-        if (inp_wav_dir !=None):
+        if (inp_wav_dir !=None and len(inp_wav_dir) > 0):
             wav_name = os.path.basename(wav_name)
             wav_path = "%s/%s"%(inp_wav_dir, wav_name)
 
@@ -24,24 +24,36 @@ def slice(inp,opt_root,threshold,min_length,min_interval,hop_size,max_sil_kept,_
     )
     _max=float(_max)
     alpha=float(alpha)
-    for inp_path in input[int(i_part)::int(all_part)]:
-        # print(inp_path)
+
+    def slice_audio(root_dir, inp_wav_path):
         try:
-            name = os.path.basename(inp_path)
-            audio = load_audio(inp_path, 32000)
+            name = os.path.basename(inp_wav_path)
+            audio = load_audio(inp_wav_path, 32000)
             # print(audio.shape)
             for chunk, start, end in slicer.slice(audio):  # start和end是帧数
                 tmp_max = np.abs(chunk).max()
                 if(tmp_max>1):chunk/=tmp_max
                 chunk = (chunk / tmp_max * (_max * alpha)) + (1 - alpha) * chunk
                 wavfile.write(
-                    "%s/%s_%010d_%010d.wav" % (opt_root, name, start, end),
+                    "%s/%s_%010d_%010d.wav" % (root_dir, name, start, end),
                     32000,
                     # chunk.astype(np.float32),
                     (chunk * 32767).astype(np.int16),
                 )
         except:
-            print(inp_path,"->fail->",traceback.format_exc())
+            print(inp_wav_path,"->fail->",traceback.format_exc())
 
+    for inp_path in input[int(i_part)::int(all_part)]:
+        # print(inp_path)
+        if os.path.isdir(inp_path):
+            inp_dir_name = os.path.basename(inp_path)
+            os.makedirs(os.path.join(opt_root, inp_dir_name), exist_ok=True)
+            for inp_name in os.listdir(inp_path):
+                inp_wav_path = os.path.join(inp_path, inp_name)
+                slice_audio(os.path.join(opt_root, inp_dir_name), inp_wav_path)
+        else:
+            inp_wav_path = inp_path
+            slice_audio(opt_root, inp_wav_path)
     return "执行完毕,请检查输出文件"
 
 print(slice(*sys.argv[1:]))
webui.py: 62 changes
@@ -195,20 +195,42 @@ def change_tts_inference(if_tts,bert_path,cnhubert_base_path,gpu_number,gpt_path
 from tools.asr.config import asr_dict
 def open_asr(asr_inp_dir, asr_opt_dir, asr_model, asr_model_size, asr_lang):
     global p_asr
-    if(p_asr==None):
-        asr_inp_dir=my_utils.clean_path(asr_inp_dir)
+    def run(asr_inp_dir, asr_opt_dir):
         cmd = f'"{python_exec}" tools/asr/{asr_dict[asr_model]["path"]}'
         cmd += f' -i "{asr_inp_dir}"'
         cmd += f' -o "{asr_opt_dir}"'
         cmd += f' -s {asr_model_size}'
         cmd += f' -l {asr_lang}'
         cmd += " -p %s"%("float16"if is_half==True else "float32")
 
-        yield "ASR任务开启:%s"%cmd,{"__type__":"update","visible":False},{"__type__":"update","visible":True}
         print(cmd)
         p_asr = Popen(cmd, shell=True)
         p_asr.wait()
         p_asr=None
+    if(p_asr==None):
+        # yield "ASR任务开启:%s"%cmd,{"__type__":"update","visible":False},{"__type__":"update","visible":True}
+        flag = False
+        for name in os.listdir(asr_inp_dir):
+            if os.path.isdir(os.path.join(asr_inp_dir, name)):
+                os.makedirs(os.path.join(asr_opt_dir, name), exist_ok=True)
+                run(os.path.join(asr_inp_dir, name), os.path.join(asr_opt_dir, name))
+            else:
+                break
+        else:
+            flag = True
+        if flag:
+            run(asr_inp_dir, asr_opt_dir)
+        # asr_inp_dir=my_utils.clean_path(asr_inp_dir)
+        # cmd = f'"{python_exec}" tools/asr/{asr_dict[asr_model]["path"]}'
+        # cmd += f' -i "{asr_inp_dir}"'
+        # cmd += f' -o "{asr_opt_dir}"'
+        # cmd += f' -s {asr_model_size}'
+        # cmd += f' -l {asr_lang}'
+        # cmd += " -p %s"%("float16"if is_half==True else "float32")
+
+        # print(cmd)
+        # p_asr = Popen(cmd, shell=True)
+        # p_asr.wait()
+        # p_asr=None
         yield f"ASR任务完成, 查看终端进行下一步",{"__type__":"update","visible":True},{"__type__":"update","visible":False}
     else:
         yield "已有正在进行的ASR任务,需先终止才能开启下一次任务",{"__type__":"update","visible":False},{"__type__":"update","visible":True}
@@ -273,7 +295,7 @@ def open1Bb(batch_size,total_epoch,exp_name,if_save_latest,if_save_every_weights
         data=f.read()
     data=yaml.load(data, Loader=yaml.FullLoader)
     s1_dir="%s/%s"%(exp_root,exp_name)
-    os.makedirs("%s/logs_s1"%(s1_dir),exist_ok=True)
+    # os.makedirs("%s/logs_s1"%(s1_dir),exist_ok=True)
     if(is_half==False):
         data["train"]["precision"]="32"
         batch_size = max(1, batch_size // 2)
@@ -287,6 +309,8 @@ def open1Bb(batch_size,total_epoch,exp_name,if_save_latest,if_save_every_weights
     data["train"]["exp_name"]=exp_name
     data["train_semantic_path"]="%s/6-name2semantic.tsv"%s1_dir
     data["train_phoneme_path"]="%s/2-name2text.txt"%s1_dir
+    data["train_dir"] = s1_dir
+    data["speaker_dict_path"] = "%s/speaker_dict.json"%s1_dir
     data["output_dir"]="%s/logs_s1"%s1_dir
 
     os.environ["_CUDA_VISIBLE_DEVICES"]=gpu_numbers.replace("-",",")
@@ -352,6 +376,13 @@ def close_slice():
 ps1a=[]
 def open1a(inp_text,inp_wav_dir,exp_name,gpu_numbers,bert_pretrained_dir):
     global ps1a
+    if not inp_text.endswith(".list"):
+        for name in os.listdir(inp_text):
+            p = os.path.join(inp_text, name)
+            l = os.path.join(p, name + ".list")
+            assert os.path.exists(l)
+            yield from open1a(l,inp_wav_dir,os.path.join(exp_name, name),gpu_numbers,bert_pretrained_dir)
+        return
     inp_text = my_utils.clean_path(inp_text)
     inp_wav_dir = my_utils.clean_path(inp_wav_dir)
     if (ps1a == []):
@@ -410,6 +441,13 @@ def close1a():
 ps1b=[]
 def open1b(inp_text,inp_wav_dir,exp_name,gpu_numbers,ssl_pretrained_dir):
     global ps1b
+    if not inp_text.endswith(".list"):
+        for name in os.listdir(inp_text):
+            p = os.path.join(inp_text, name)
+            l = os.path.join(p, name + ".list")
+            assert os.path.exists(l)
+            yield from open1b(l,inp_wav_dir,os.path.join(exp_name, name),gpu_numbers,ssl_pretrained_dir)
+        return
     inp_text = my_utils.clean_path(inp_text)
     inp_wav_dir = my_utils.clean_path(inp_wav_dir)
     if (ps1b == []):
@@ -458,6 +496,13 @@ def close1b():
 ps1c=[]
 def open1c(inp_text,exp_name,gpu_numbers,pretrained_s2G_path):
     global ps1c
+    if not inp_text.endswith(".list"):
+        for name in os.listdir(inp_text):
+            p = os.path.join(inp_text, name)
+            l = os.path.join(p, name + ".list")
+            assert os.path.exists(l)
+            yield from open1c(l,os.path.join(exp_name, name),gpu_numbers,pretrained_s2G_path)
+        return
     inp_text = my_utils.clean_path(inp_text)
     if (ps1c == []):
         opt_dir="%s/%s"%(exp_root,exp_name)
@@ -515,6 +560,13 @@ def close1c():
 ps1abc=[]
 def open1abc(inp_text,inp_wav_dir,exp_name,gpu_numbers1a,gpu_numbers1Ba,gpu_numbers1c,bert_pretrained_dir,ssl_pretrained_dir,pretrained_s2G_path):
     global ps1abc
+    if not inp_text.endswith(".list"):
+        for name in os.listdir(inp_text):
+            p = os.path.join(inp_text, name)
+            l = os.path.join(p, name + ".list")
+            assert os.path.exists(l)
+            yield from open1abc(l,inp_wav_dir,os.path.join(exp_name, name),gpu_numbers1a,gpu_numbers1Ba,gpu_numbers1c,bert_pretrained_dir,ssl_pretrained_dir,pretrained_s2G_path)
+        return
     inp_text = my_utils.clean_path(inp_text)
     inp_wav_dir = my_utils.clean_path(inp_wav_dir)
     if (ps1abc == []):
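The open1a/open1b/open1c/open1abc hunks above all add the same recursion: when the text input is a directory rather than a .list file, each sub-folder <name> is treated as one speaker and must contain <name>.list, and the original single-list step is replayed per folder with the experiment name extended by the folder name. A hedged sketch of that pattern; process_one is a placeholder callback, not a repository function:

# Generic sketch of the per-speaker recursion added to the preprocessing steps:
# walk the sub-folders of a directory input, require <name>/<name>.list in each,
# and re-run the single-speaker step per folder.
import os


def process_tree(inp_text: str, exp_name: str, process_one) -> None:
    if not inp_text.endswith(".list"):
        for name in os.listdir(inp_text):
            list_path = os.path.join(inp_text, name, name + ".list")
            assert os.path.exists(list_path), f"missing {list_path}"
            process_one(list_path, os.path.join(exp_name, name))  # one speaker per folder
        return
    process_one(inp_text, exp_name)  # plain single-list case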