Changed the padding strategy within a batch from padding on the right to padding on the left.

This commit is contained in:
ChasonJiang 2024-04-12 00:49:44 +08:00
parent 3706ad1b8b
commit 5fe9842069
8 changed files with 569 additions and 125 deletions
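
For context, the sketch below illustrates the difference between the two padding strategies on a toy batch. It is illustrative only: the helper name pad_batch, the pad value, and the example tensors are assumptions made for this note, not code from the repository.

import torch
import torch.nn.functional as F

# With right padding, pad tokens sit after each sequence; with left padding they
# sit before it, so every sequence in the batch ends at the same position and the
# autoregressive targets stay aligned at the right edge of the batch.
def pad_batch(seqs, pad_val=1024, on_left=True):
    max_len = max(s.shape[0] for s in seqs)
    padded = []
    for s in seqs:
        pad = max_len - s.shape[0]
        # F.pad takes (left, right) pad amounts for the last dimension
        padded.append(F.pad(s, (pad, 0) if on_left else (0, pad), value=pad_val))
    return torch.stack(padded, dim=0)

batch = [torch.tensor([3, 5, 7]), torch.tensor([2, 4])]
print(pad_batch(batch, on_left=True))   # rows: [3, 5, 7] and [1024, 2, 4]
print(pad_batch(batch, on_left=False))  # rows: [3, 5, 7] and [2, 4, 1024]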

View File

@@ -32,6 +32,7 @@ class Text2SemanticDataModule(LightningDataModule):
            semantic_path=self.train_semantic_path,
            max_sec=self.config["data"]["max_sec"],
            pad_val=self.config["data"]["pad_val"],
            padding_on_left=self.config["train"]["padding_on_left"],
        )
        self._dev_dataset = self._train_dataset
        # self._dev_dataset = Text2SemanticDataset(

View File

@@ -55,9 +55,10 @@ class Text2SemanticDataset(Dataset):
        min_ps_ratio: int = 3,
        # max value of phoneme/sec
        max_ps_ratio: int = 25,
        padding_on_left: bool = False,
    ) -> None:
        super().__init__()
        self.padding_on_left = padding_on_left

        self.semantic_data = pd.read_csv(
            semantic_path, delimiter="\t", encoding="utf-8"
        )
@@ -164,7 +165,9 @@ class Text2SemanticDataset(Dataset):
            # if len(semantic_ids) > 1000:###########3
            #     num_deleted_bigger += 1
            #     continue
            if (len(semantic_ids) + len(phoneme_ids)) > 1000:###########3
                num_deleted_bigger += 1
                continue

            ps_ratio = len(phoneme_ids) / (len(semantic_ids) / self.hz)

            if (
@@ -173,7 +176,8 @@ class Text2SemanticDataset(Dataset):
                num_deleted_ps += 1
                # print(item_name)
                continue
            idx_len = []
            self.semantic_phoneme.append((semantic_ids, phoneme_ids))
            idx += 1
            self.item_names.append(item_name)
@@ -253,46 +257,73 @@ class Text2SemanticDataset(Dataset):
        phoneme_ids_lens: List[int] = []
        semantic_ids: List[torch.Tensor] = []
        semantic_ids_lens: List[int] = []
        # return

        if not self.padding_on_left:
            for item in examples:
                sample_index.append(item["idx"])
                phoneme_ids.append(np.array(item["phoneme_ids"], dtype=np.int64))
                semantic_ids.append(np.array(item["semantic_ids"], dtype=np.int64))
                phoneme_ids_lens.append(item["phoneme_ids_len"])
                semantic_ids_lens.append(item["semantic_ids_len"])

            # pad 0
            phoneme_ids = batch_sequences(phoneme_ids)
            semantic_ids = batch_sequences(semantic_ids, pad_value=self.PAD)

            # # convert each batch to torch.tensor
            phoneme_ids = torch.tensor(phoneme_ids)
            semantic_ids = torch.tensor(semantic_ids)

            phoneme_ids_lens = torch.tensor(phoneme_ids_lens)
            semantic_ids_lens = torch.tensor(semantic_ids_lens)
            bert_padded = torch.FloatTensor(len(examples), 1024, max(phoneme_ids_lens))
            bert_padded.zero_()

            for idx, item in enumerate(examples):
                bert = item["bert_feature"]
                if bert != None:
                    bert_padded[idx, :, : bert.shape[-1]] = bert

            return {
                # List[int]
                "ids": sample_index,
                # torch.Tensor (B, max_phoneme_length)
                "phoneme_ids": phoneme_ids,
                # torch.Tensor (B)
                "phoneme_ids_len": phoneme_ids_lens,
                # torch.Tensor (B, max_semantic_ids_length)
                "semantic_ids": semantic_ids,
                # torch.Tensor (B)
                "semantic_ids_len": semantic_ids_lens,
                # torch.Tensor (B, 1024, max_phoneme_length)
                "bert_feature": bert_padded,
            }
        else:
            for item in examples:
                sample_index.append(item["idx"])
                phoneme_ids.append(torch.LongTensor(np.array(item["phoneme_ids"], dtype=np.int64)))
                semantic_ids.append(torch.LongTensor(np.array(item["semantic_ids"], dtype=np.int64)))
                phoneme_ids_lens.append(item["phoneme_ids_len"])
                semantic_ids_lens.append(item["semantic_ids_len"])

            phoneme_ids_lens = torch.tensor(phoneme_ids_lens)
            semantic_ids_lens = torch.tensor(semantic_ids_lens)

            bert_features: List[torch.Tensor] = [item["bert_feature"] for item in examples]

            return {
                # List[int]
                "ids": sample_index,
                # List[torch.Tensor] (B, max_phoneme_length)
                "phoneme_ids": phoneme_ids,
                # torch.Tensor (B)
                "phoneme_ids_len": phoneme_ids_lens,
                # List[torch.Tensor] (B, max_semantic_ids_length)
                "semantic_ids": semantic_ids,
                # torch.Tensor (B)
                "semantic_ids_len": semantic_ids_lens,
                # List[torch.Tensor] (B, 1024, max_phoneme_length)
                "bert_feature": bert_features,
            }


if __name__ == "__main__":

View File

@@ -14,7 +14,7 @@ from AR.modules.optim import ScaledAdam

class Text2SemanticLightningModule(LightningModule):
    def __init__(self, config, output_dir, is_train=True, flash_attn_enabled:bool = False):
        super(Text2SemanticLightningModule, self).__init__()
        self.config = config
        self.top_k = 3
        self.model = Text2SemanticDecoder(config=config, top_k=self.top_k, flash_attn_enabled=flash_attn_enabled)
@@ -35,7 +35,14 @@ class Text2SemanticLightningModule(LightningModule):
    def training_step(self, batch: Dict, batch_idx: int):
        opt = self.optimizers()
        scheduler = self.lr_schedulers()
        forward = None
        if self.config["train"].get("if_dpo", False):
            forward = self.model.forward
        elif self.config["train"].get("padding_on_left", False):
            forward = self.model.forward_old_padding_on_left
        else:
            forward = self.model.forward_old
        loss, acc = forward(
            batch["phoneme_ids"],
            batch["phoneme_ids_len"],
@@ -56,6 +63,7 @@ class Text2SemanticLightningModule(LightningModule):
            on_epoch=True,
            prog_bar=True,
            sync_dist=True,
            batch_size=batch["phoneme_ids_len"].shape[0],
        )
        self.log(
            "lr",
@@ -63,6 +71,7 @@ class Text2SemanticLightningModule(LightningModule):
            on_epoch=True,
            prog_bar=True,
            sync_dist=True,
            batch_size=batch["phoneme_ids_len"].shape[0],
        )
        self.log(
            f"top_{self.top_k}_acc",
@@ -71,7 +80,10 @@ class Text2SemanticLightningModule(LightningModule):
            on_epoch=True,
            prog_bar=True,
            sync_dist=True,
            batch_size=batch["phoneme_ids_len"].shape[0],
        )
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def validation_step(self, batch: Dict, batch_idx: int):
        return

View File

@@ -1,5 +1,6 @@
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_model.py
# reference: https://github.com/lifeiteng/vall-e
import math
import os, sys

now_dir = os.getcwd()
sys.path.append(now_dir)
@@ -38,7 +39,6 @@ default_config = {
    "EOS": 1024,
}

@torch.jit.script
class T2SMLP:
    def __init__(self, w1, b1, w2, b2):
@@ -362,7 +362,8 @@ class Text2SemanticDecoder(nn.Module):
        loss = loss_1 + loss_2
        return loss, acc

    # padding on right
    def forward_old(self, x, x_lens, y, y_lens, bert_feature):
        """
        x: phoneme_ids
@@ -424,6 +425,91 @@ class Text2SemanticDecoder(nn.Module):
        loss = F.cross_entropy(logits, targets, reduction="sum")
        acc = self.ar_accuracy_metric(logits.detach(), targets).item()
        return loss, acc
def forward_old_padding_on_left(self,
x:List[torch.Tensor],
x_lens:torch.LongTensor,
y:List[torch.Tensor],
y_lens:torch.LongTensor,
bert_feature:List[torch.Tensor]):
"""
x: phoneme_ids
y: semantic_ids
"""
device = x[0].device
x_len = x_lens.max()
y_len = y_lens.max()
batch_size = len(x)
xy_pos = torch.zeros((batch_size, x_len+y_len, self.embedding_dim)).to(device)
targets:List[torch.LongTensor] = []
xy_attn_mask_list = []
for i in range(batch_size):
padding_len = (x_len-x_lens[i])+(y_len-y_lens[i])
x_item=self.ar_text_embedding(x[i].unsqueeze(0))
if bert_feature[i] is not None:
x_item = x_item + self.bert_proj(bert_feature[i].transpose(0, 1).unsqueeze(0))
# x_item = F.pad(x_item, (0, 0, padding_len, 0), value=0)
x_item = self.ar_text_position(x_item).squeeze(0)
y_item = self.ar_audio_position(self.ar_audio_embedding(y[i].unsqueeze(0))).squeeze(0)
xy_pos[i, padding_len:padding_len+x_lens[i],:] = x_item
xy_pos[i, padding_len+x_lens[i]:,:] = y_item
target = torch.zeros(y_lens[i], dtype=torch.long).to(device)
target[:-1] = y[i][1:]
target[-1] = self.EOS
targets.append(target.unsqueeze(0))
x_attn_mask = torch.zeros((x_len+(y_len-y_lens[i]), x_len+y_len), dtype=torch.bool).to(device)
x_attn_mask[:, -y_lens[i]:] = True
y_attn_mask = F.pad(
torch.triu(
torch.ones(y_lens[i], y_lens[i], dtype=torch.bool).to(device),
diagonal=1,
),
(x_len+(y_len-y_lens[i]), 0),
value=False,
)
attn_mask = torch.concat([x_attn_mask, y_attn_mask], dim=0)
if padding_len>0:
attn_mask[:, :padding_len] = True
xy_attn_mask_list.append(attn_mask)
xy_attn_mask = torch.stack(xy_attn_mask_list, dim=0)
new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=xy_pos.dtype)
new_attn_mask.masked_fill_(xy_attn_mask, torch.finfo(xy_pos.dtype).min)
xy_attn_mask = new_attn_mask
xy_attn_mask = (xy_attn_mask.view(batch_size, 1, x_len+y_len, x_len+y_len)
.expand(-1, self.num_head, -1, -1)
.reshape(batch_size * self.num_head, x_len+y_len, x_len+y_len))
# x 和完整的 y 一次性输入模型
# xy_pos = torch.concat([x, y_pos], dim=1)
xy_dec, _ = self.h(
(xy_pos, None),
mask=xy_attn_mask,
)
logits = [self.ar_predict_layer(xy_dec[i, -y_lens[i]:, :].unsqueeze(0)).permute(0, 2, 1) for i in range(batch_size)]
# loss
# from feiteng: 每次 duration 越多, 梯度更新也应该更多, 所以用 sum
loss = None
acc = None
for i in range(batch_size):
if loss is None:
loss = F.cross_entropy(logits[i], targets[i], reduction="sum")
acc = self.ar_accuracy_metric(logits[i].detach(), targets[i].detach()).item()
else:
loss += F.cross_entropy(logits[i], targets[i], reduction="sum")
acc += self.ar_accuracy_metric(logits[i].detach(), targets[i].detach()).item()
acc /= batch_size
# loss = F.cross_entropy(logits, targets, reduction="sum")
# acc = self.ar_accuracy_metric(logits.detach(), targets).item()
return loss, acc
    # 需要看下这个函数和 forward 的区别以及没有 semantic 的时候 prompts 输入什么
    def infer(
@@ -512,85 +598,80 @@ class Text2SemanticDecoder(nn.Module):
        top_p: int = 100,
        early_stop_num: int = -1,
        temperature: float = 1.0,
        repetition_penalty: float = 1.35,
        dtype: torch.dtype = torch.float32,
    ):
        device = x[0].device
        x_len = x_lens.max()
        batch_size = len(x)

        # AR Decoder
        y = prompts

        stop = False

        k_cache = None
        v_cache = None
        ################### first step ##########################
        if y is None:
            y_emb = None
            y_len = 0
            prefix_len = 0
            y_pos = None
            y = torch.zeros(batch_size, 0, dtype=torch.int, device=device)
            ref_free = True
        else:
            y_emb = self.ar_audio_embedding(y)
            y_len = y_emb.shape[1]
            prefix_len = y.shape[1]
            y_pos = self.ar_audio_position(y_emb)
            ref_free = False

        xy_pos = torch.zeros((batch_size, x_len+y_len, self.embedding_dim), dtype=dtype).to(device)
        # ar_xy_padding_mask = torch.zeros((batch_size, x_len+y_len), device=device, dtype=torch.bool)
        xy_attn_mask_list = []
        for i in range(batch_size):
            padding_len = (x_len-x_lens[i])

            x_item = self.ar_text_embedding(x[i].unsqueeze(0))
            if bert_feature[i] is not None:
                x_item = x_item + self.bert_proj(bert_feature[i].transpose(0, 1).unsqueeze(0))
            # x_item = F.pad(x_item, (0, 0, padding_len, 0), value=0)
            x_item = self.ar_text_position(x_item).squeeze(0)

            xy_pos[i, padding_len:padding_len+x_lens[i], :] = x_item
            if not ref_free:
                xy_pos[i, padding_len+x_lens[i]:, :] = y_pos[i]

            x_attn_mask = torch.zeros((x_len, x_len+y_len), dtype=torch.bool).to(device)
            if not ref_free:
                x_attn_mask[:, -y_len:] = True
            y_attn_mask = F.pad(
                torch.triu(
                    torch.ones(y_len, y_len, dtype=torch.bool).to(device),
                    diagonal=1,
                ),
                (x_len, 0),
                value=False,
            )

            attn_mask = torch.concat([x_attn_mask, y_attn_mask], dim=0)
            if padding_len > 0:
                attn_mask[:, :padding_len] = True
            xy_attn_mask_list.append(attn_mask)

        xy_attn_mask = torch.stack(xy_attn_mask_list, dim=0)
        new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=xy_pos.dtype)
        new_attn_mask.masked_fill_(xy_attn_mask, torch.finfo(xy_pos.dtype).min)
        xy_attn_mask = new_attn_mask
        xy_attn_mask = (xy_attn_mask.view(batch_size, 1, x_len+y_len, x_len+y_len)
                        .expand(-1, self.num_head, -1, -1))

        ###### decode #####
        y_list = [None]*batch_size
        batch_idx_map = list(range(batch_size))
        idx_list = [None]*batch_size
        for idx in tqdm(range(1500)):
            if idx == 0:
                xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, xy_attn_mask)
@@ -606,7 +687,7 @@ class Text2SemanticDecoder(nn.Module):
                logits = logits[:, :-1]

            samples = sample(
                logits, y, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature
            )[0]

            y = torch.concat([y, samples], dim=1)
@@ -659,12 +740,12 @@ class Text2SemanticDecoder(nn.Module):
            xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[:, y_len + idx].to(dtype=y_emb.dtype, device=y_emb.device)

        if (None in idx_list):
            for i in range(batch_size):
                if idx_list[i] is None:
                    idx_list[i] = 1500-1  ###如果没有生成到EOS就用最大长度代替

        if ref_free:
            return y_list, [0]*batch_size
        return y_list, idx_list

    def infer_panel_batch_only(
@@ -677,6 +758,8 @@ class Text2SemanticDecoder(nn.Module):
        top_p: int = 100,
        early_stop_num: int = -1,
        temperature: float = 1.0,
        repetition_penalty: float = 1.35,
        **kwargs
    ):
        # 先对phones进行embedding、对bert_features进行project再pad到相同长度以缓解复读问题。可能还有其他因素导致复读
        max_len = 0
@@ -772,7 +855,7 @@ class Text2SemanticDecoder(nn.Module):
            if(idx==0):###第一次跑不能EOS否则没有了
                logits = logits[:, :-1]  ###刨除1024终止符号的概率
            samples = sample(
                logits, y, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature
            )[0]
            # 本次生成的 semantic_ids 和之前的 y 构成新的 y
            # print(samples.shape)#[1,1]#第一个1是bs
@@ -854,4 +937,298 @@ class Text2SemanticDecoder(nn.Module):
        if ref_free:
            return y_list, [0]*x.shape[0]
        return y_list, idx_list
# padding on right
def infer_panel_batch_infer_with_flash_attn_old(
self,
x:List[torch.LongTensor], #####全部文本token
x_lens:torch.LongTensor,
prompts:torch.LongTensor, ####参考音频token
bert_feature:List[torch.LongTensor],
top_k: int = -100,
top_p: int = 100,
early_stop_num: int = -1,
temperature: float = 1.0,
repetition_penalty: float = 1.35,
**kwargs
):
# 先对phones进行embedding、对bert_features进行project再pad到相同长度以缓解复读问题。可能还有其他因素导致复读
max_len = 0
for x_item, bert_item in zip(x, bert_feature):
max_len = max(max_len, x_item.shape[0], bert_item.shape[1])
x_list = [self.ar_text_embedding(item) for item in x]
x_list = [F.pad(item,(0,0,0,max_len-item.shape[0]),value=0) if item.shape[0]<max_len else item for item in x_list]
x = torch.stack(x_list, dim=0)
bert_features_list = [self.bert_proj(item.transpose(0, 1)) for item in bert_feature]
bert_features_list = [F.pad(item,(0,0,0,max_len-item.shape[0]), value=0) if item.shape[0]<max_len else item for item in bert_features_list]
bert_feature = torch.stack(bert_features_list, dim=0)
# bert_feature = self.bert_proj(bert_feature.transpose(1, 2))
# x = self.ar_text_embedding(x)
x = x + bert_feature
x = self.ar_text_position(x)
# AR Decoder
y = prompts
x_len = x.shape[1]
x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
stop = False
# print(1111111,self.num_layers)
k_cache = None
v_cache = None
################### first step ##########################
if y is not None:
y_emb = self.ar_audio_embedding(y)
y_len = y_emb.shape[1]
prefix_len = y.shape[1]
y_pos = self.ar_audio_position(y_emb)
xy_pos = torch.concat([x, y_pos], dim=1)
ref_free = False
else:
y_emb = None
y_len = 0
prefix_len = 0
y_pos = None
xy_pos = x
y = torch.zeros(x.shape[0], 0, dtype=torch.int, device=x.device)
ref_free = True
##### create mask #####
bsz = x.shape[0]
src_len = x_len + y_len
y_lens = torch.LongTensor([y_len]*bsz).to(x.device)
y_mask = make_pad_mask(y_lens)
x_mask = make_pad_mask(x_lens)
# (bsz, x_len + y_len)
xy_padding_mask = torch.concat([x_mask, y_mask], dim=1)
x_mask = F.pad(
x_attn_mask,
(0, y_len), ###xx的纯0扩展到xx纯0+xy纯1(x,x+y)
value=True,
)
y_mask = F.pad( ###yy的右上1扩展到左边xy的0,(y,x+y)
torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
(x_len, 0),
value=False,
)
xy_mask = torch.concat([x_mask, y_mask], dim=0).view(1 , src_len, src_len).expand(bsz, -1, -1).to(x.device)
# xy_mask = torch.triu(torch.ones(src_len, src_len, dtype=torch.bool, device=x.device), diagonal=1)
xy_padding_mask = xy_padding_mask.view(bsz, 1, src_len).expand(-1, src_len, src_len)
xy_attn_mask = xy_mask.logical_or(xy_padding_mask)
xy_attn_mask = xy_attn_mask.unsqueeze(1).expand(-1, self.num_head, -1, -1)
new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype)
xy_attn_mask = new_attn_mask.masked_fill(xy_attn_mask, float("-inf"))
###### decode #####
y_list = [None]*y.shape[0]
batch_idx_map = list(range(y.shape[0]))
idx_list = [None]*y.shape[0]
for idx in tqdm(range(1500)):
if idx == 0:
xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, xy_attn_mask)
else:
xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache)
logits = self.ar_predict_layer(
xy_dec[:, -1]
)
if idx == 0:
xy_attn_mask = None
logits = logits[:, :-1]
samples = sample(
logits, y, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature
)[0]
y = torch.concat([y, samples], dim=1)
####### 移除batch中已经生成完毕的序列,进一步优化计算量
reserved_idx_of_batch_for_y = None
if (self.EOS in samples[:, 0]) or \
(self.EOS in torch.argmax(logits, dim=-1)): ###如果生成到EOS则停止
l = samples[:, 0]==self.EOS
removed_idx_of_batch_for_y = torch.where(l==True)[0].tolist()
reserved_idx_of_batch_for_y = torch.where(l==False)[0]
# batch_indexs = torch.tensor(batch_idx_map, device=y.device)[removed_idx_of_batch_for_y]
for i in removed_idx_of_batch_for_y:
batch_index = batch_idx_map[i]
idx_list[batch_index] = idx - 1
y_list[batch_index] = y[i, :-1]
batch_idx_map = [batch_idx_map[i] for i in reserved_idx_of_batch_for_y.tolist()]
# 只保留batch中未生成完毕的序列
if reserved_idx_of_batch_for_y is not None:
# index = torch.LongTensor(batch_idx_map).to(y.device)
y = torch.index_select(y, dim=0, index=reserved_idx_of_batch_for_y)
if k_cache is not None :
for i in range(len(k_cache)):
k_cache[i] = torch.index_select(k_cache[i], dim=0, index=reserved_idx_of_batch_for_y)
v_cache[i] = torch.index_select(v_cache[i], dim=0, index=reserved_idx_of_batch_for_y)
if (early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num) or idx==1499:
print("use early stop num:", early_stop_num)
stop = True
for i, batch_index in enumerate(batch_idx_map):
batch_index = batch_idx_map[i]
idx_list[batch_index] = idx
y_list[batch_index] = y[i, :-1]
if not (None in idx_list):
stop = True
if stop:
if y.shape[1]==0:
y = torch.concat([y, torch.zeros_like(samples)], dim=1)
print("bad zero prediction")
print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
break
####################### update next step ###################################
y_emb = self.ar_audio_embedding(y[:, -1:])
xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[:, y_len + idx].to( dtype= y_emb.dtype,device=y_emb.device)
if (None in idx_list):
for i in range(x.shape[0]):
if idx_list[i] is None:
idx_list[i] = 1500-1 ###如果没有生成到EOS就用最大长度代替
if ref_free:
return y_list, [0]*x.shape[0]
return y_list, idx_list
def infer_panel_old(
self,
x, #####全部文本token
x_lens,
prompts, ####参考音频token
bert_feature,
top_k: int = -100,
top_p: int = 100,
early_stop_num: int = -1,
temperature: float = 1.0,
):
x = self.ar_text_embedding(x)
x = x + self.bert_proj(bert_feature.transpose(1, 2))
x = self.ar_text_position(x)
# AR Decoder
y = prompts
x_len = x.shape[1]
x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
stop = False
# print(1111111,self.num_layers)
cache = {
"all_stage": self.num_layers,
"k": [None] * self.num_layers, ###根据配置自己手写
"v": [None] * self.num_layers,
# "xy_pos":None,##y_pos位置编码每次都不一样的没法缓存每次都要重新拼xy_pos.主要还是写法原因,其实是可以历史统一一样的,但也没啥计算量就不管了
"y_emb": None, ##只需要对最新的samples求emb再拼历史的就行
# "logits":None,###原版就已经只对结尾求再拼接了,不用管
# "xy_dec":None,###不需要本来只需要最后一个做logits
"first_infer": 1,
"stage": 0,
}
################### first step ##########################
if y is not None:
y_emb = self.ar_audio_embedding(y)
y_len = y_emb.shape[1]
prefix_len = y.shape[1]
y_pos = self.ar_audio_position(y_emb)
xy_pos = torch.concat([x, y_pos], dim=1)
cache["y_emb"] = y_emb
ref_free = False
else:
y_emb = None
y_len = 0
prefix_len = 0
y_pos = None
xy_pos = x
y = torch.zeros(x.shape[0], 0, dtype=torch.int, device=x.device)
ref_free = True
x_attn_mask_pad = F.pad(
x_attn_mask,
(0, y_len), ###xx的纯0扩展到xx纯0+xy纯1(x,x+y)
value=True,
)
y_attn_mask = F.pad( ###yy的右上1扩展到左边xy的0,(y,x+y)
torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
(x_len, 0),
value=False,
)
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(
x.device
)
for idx in tqdm(range(1500)):
xy_dec, _ = self.h((xy_pos, None), mask=xy_attn_mask, cache=cache)
logits = self.ar_predict_layer(
xy_dec[:, -1]
) ##不用改如果用了cache的默认就是只有一帧取最后一帧一样的
# samples = topk_sampling(logits, top_k=top_k, top_p=1.0, temperature=temperature)
if(idx==0):###第一次跑不能EOS否则没有了
logits = logits[:, :-1] ###刨除1024终止符号的概率
samples = sample(
logits[0], y, top_k=top_k, top_p=top_p, repetition_penalty=1.35, temperature=temperature
)[0].unsqueeze(0)
# 本次生成的 semantic_ids 和之前的 y 构成新的 y
# print(samples.shape)#[1,1]#第一个1是bs
y = torch.concat([y, samples], dim=1)
if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
print("use early stop num:", early_stop_num)
stop = True
if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
# print(torch.argmax(logits, dim=-1)[0] == self.EOS, samples[0, 0] == self.EOS)
stop = True
if stop:
# if prompts.shape[1] == y.shape[1]:
# y = torch.concat([y, torch.zeros_like(samples)], dim=1)
# print("bad zero prediction")
if y.shape[1]==0:
y = torch.concat([y, torch.zeros_like(samples)], dim=1)
print("bad zero prediction")
print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
break
####################### update next step ###################################
cache["first_infer"] = 0
if cache["y_emb"] is not None:
y_emb = torch.cat(
[cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], dim = 1
)
cache["y_emb"] = y_emb
y_pos = self.ar_audio_position(y_emb)
xy_pos = y_pos[:, -1:]
else:
y_emb = self.ar_audio_embedding(y[:, -1:])
cache["y_emb"] = y_emb
y_pos = self.ar_audio_position(y_emb)
xy_pos = y_pos
y_len = y_pos.shape[1]
###最右边一列(是错的)
# xy_attn_mask=torch.ones((1, x_len+y_len), dtype=torch.bool,device=xy_pos.device)
# xy_attn_mask[:,-1]=False
###最下面一行(是对的)
xy_attn_mask = torch.zeros(
(1, x_len + y_len), dtype=torch.bool, device=xy_pos.device
)
if ref_free:
return y[:, :-1], 0
return y[:, :-1], idx-1

View File

@@ -11,7 +11,7 @@ def sequence_mask(length, max_length=None):
    return x.unsqueeze(0) < length.unsqueeze(1)


def make_pad_mask(lengths: torch.Tensor, max_len: int = 0, padding_left: bool = False) -> torch.Tensor:
    """
    Args:
      lengths:
@@ -35,8 +35,10 @@ def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    n = lengths.size(0)
    seq_range = torch.arange(0, max_len, device=lengths.device)
    expaned_lengths = seq_range.unsqueeze(0).expand(n, max_len)
    if padding_left:
        return expaned_lengths < (max_len - lengths.unsqueeze(-1))
    else:
        return expaned_lengths >= lengths.unsqueeze(-1)


# https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py

View File

@@ -63,7 +63,7 @@ def set_seed(seed:int):
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)
            torch.backends.cudnn.deterministic = True
            # torch.backends.cudnn.benchmark = False
            # torch.backends.cudnn.enabled = True
    except:
@@ -435,8 +435,7 @@ class TTS:
        device:torch.device=torch.device("cpu"),
        precision:torch.dtype=torch.float32,
    ):
        # 但是这里不能套,反而会负优化
        # with torch.no_grad():
        _data:list = []
        index_and_len_list = []
        for idx, item in enumerate(data):
@@ -484,8 +483,7 @@ class TTS:
            norm_text_batch = []
            bert_max_len = 0
            phones_max_len = 0
            # 但是这里也不能套,反而会负优化
            # with torch.no_grad():
            for item in item_list:
                if prompt_data is not None:
                    all_bert_features = torch.cat([prompt_data["bert_features"], item["bert_features"]], 1)\
@@ -533,6 +531,12 @@ class TTS:
            # all_bert_features_list = [F.pad(item,(0,0,0,max_len-item.shape[0]), value=0) for item in all_bert_features_list]
            # all_bert_features_batch = torch.stack(all_bert_features_list, dim=0)

            #### padding on left
            # all_phones_list = [F.pad(item,(max_len-item.shape[0],0),value=0) for item in all_phones_list]
            # all_phones_batch = torch.stack(all_phones_list, dim=0)
            # all_bert_features_list = [F.pad(item,(max_len-item.shape[1],0,0,0), value=0) for item in all_bert_features_list]
            # all_bert_features_batch = torch.stack(all_bert_features_list, dim=0)

            batch = {
                "phones": phones_batch,
                "phones_len": torch.LongTensor(phones_len_list).to(device),
@@ -569,7 +573,6 @@ class TTS:
        '''
        self.stop_flag = True

    # 使用装饰器
    @torch.no_grad()
    def run(self, inputs:dict):
        """
@@ -586,9 +589,10 @@ class TTS:
                    "top_k": 5, # int. top k sampling
                    "top_p": 1, # float. top p sampling
                    "temperature": 1, # float. temperature for sampling
                    "repetition_penalty": 1.35, # float. repetition penalty for sampling of T2S model.
                    "text_split_method": "cut0", # str. text split method, see text_segmentation_method.py for details.
                    "batch_size": 1, # int. batch size for inference
                    "batch_threshold": 1, # float. threshold for batch splitting.
                    "split_bucket: True, # bool. whether to split the batch into multiple buckets.
                    "return_fragment": False, # bool. step by step return the audio fragment.
                    "speed_factor":1.0, # float. control the speed of the synthesized audio.
@@ -608,6 +612,7 @@ class TTS:
        top_k:int = inputs.get("top_k", 5)
        top_p:float = inputs.get("top_p", 1)
        temperature:float = inputs.get("temperature", 1)
        repetition_penalty: float = inputs.get("repetition_penalty", 1.35)
        text_split_method:str = inputs.get("text_split_method", "cut0")
        batch_size = inputs.get("batch_size", 1)
        batch_threshold = inputs.get("batch_threshold", 0.75)
@@ -618,9 +623,16 @@ class TTS:
        seed = inputs.get("seed", -1)
        seed = -1 if seed in ["", None] else seed
        actual_seed = set_seed(seed)

        padding_on_left = inputs.get("padding_on_left", False)
        if padding_on_left:
            print("padding on left")
            self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_batch_infer_with_flash_attn
        else:
            print("padding on right")
            self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_batch_infer_with_flash_attn_old

        if return_fragment:
            # split_bucket = False
            print(i18n("分段返回模式已开启"))
            if split_bucket:
                split_bucket = False
@@ -745,8 +757,7 @@ class TTS:
                        prompt = None
                    else:
                        prompt = self.prompt_cache["prompt_semantic"].expand(len(all_phoneme_ids), -1).to(self.configs.device)

                    pred_semantic_list, idx_list = self.t2s_model.model.infer_panel(
                        all_phoneme_ids,
                        all_phoneme_lens,
@@ -756,7 +767,9 @@ class TTS:
                        top_k=top_k,
                        top_p=top_p,
                        temperature=temperature,
                        repetition_penalty=repetition_penalty,
                        early_stop_num=self.configs.hz * self.configs.max_sec,
                        dtype=self.precision,
                    )
                    t4 = ttime()
                    t_34 += t4 - t3

View File

@@ -89,12 +89,14 @@ sovits_path = tts_config.vits_weights_path
def inference(text, text_lang,
              ref_audio_path, prompt_text,
              prompt_lang, top_k,
              top_p, temperature, repetition_penalty,
              text_split_method, batch_size,
              speed_factor, ref_text_free,
              split_bucket, fragment_interval,
              seed, keep_random, padding_on_left
              ):
    if keep_random:
        seed = random.randrange(1 << 32)
    actual_seed = seed if seed not in [-1, "", None] else random.randrange(1 << 32)
    inputs={
        "text": text,
@@ -105,6 +107,7 @@ def inference(text, text_lang,
        "top_k": top_k,
        "top_p": top_p,
        "temperature": temperature,
        "repetition_penalty": repetition_penalty,
        "text_split_method": cut_method[text_split_method],
        "batch_size":int(batch_size),
        "speed_factor":float(speed_factor),
@@ -112,6 +115,7 @@ def inference(text, text_lang,
        "return_fragment":False,
        "fragment_interval":fragment_interval,
        "seed":actual_seed,
        "padding_on_left":padding_on_left
    }
    for item in tts_pipeline.run(inputs):
        yield item, actual_seed
@@ -197,6 +201,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
                    top_k = gr.Slider(minimum=1,maximum=100,step=1,label=i18n("top_k"),value=5,interactive=True)
                    top_p = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("top_p"),value=1,interactive=True)
                    temperature = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("temperature"),value=1,interactive=True)
                    repetition_penalty = gr.Slider(minimum=0.5,maximum=4.0,step=0.01,label=i18n("repetition_penalty"),value=1.35,interactive=True)
                with gr.Column():
                    how_to_cut = gr.Radio(
                        label=i18n("怎么切"),
@@ -207,6 +212,8 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
                with gr.Row():
                    split_bucket = gr.Checkbox(label=i18n("数据分桶(可能会降低一点计算量,选就对了)"), value=True, interactive=True, show_label=True)
                    seed = gr.Number(label=i18n("随机种子"),value=-1)
                    keep_random = gr.Checkbox(label=i18n("保持随机"), value=True, interactive=True, show_label=True)
                    padding_on_left = gr.Checkbox(label=i18n("左侧补齐"), value=True, interactive=True, show_label=True)
                # with gr.Column():
                output = gr.Audio(label=i18n("输出的语音"))
                with gr.Row():
@@ -219,11 +226,11 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
                [
                    text,text_language, inp_ref,
                    prompt_text, prompt_language,
                    top_k, top_p, temperature, repetition_penalty,
                    how_to_cut, batch_size,
                    speed_factor, ref_text_free,
                    split_bucket,fragment_interval,
                    seed, keep_random, padding_on_left
                ],
                [output, seed],
            )

View File

@@ -126,6 +126,7 @@ def main(args):
        benchmark=False,
        fast_dev_run=False,
        strategy = DDPStrategy(
            find_unused_parameters=True,
            process_group_backend="nccl" if platform.system() != "Windows" else "gloo"
        ) if torch.cuda.is_available() else "auto",
        precision=config["train"]["precision"],