Mirror of https://github.com/RVC-Boss/GPT-SoVITS.git, synced 2025-06-23 21:05:22 +08:00
Merge pull request #2449 from KamioRinn/maga
support v4 v2Pro v2ProPlus for api & optimize LangSegmenter
commit cd6de7398e
@@ -159,6 +159,10 @@ class TextPreprocessor:
                     textlist.append(tmp["text"])
             else:
                 for tmp in LangSegmenter.getTexts(text):
+                    if langlist:
+                        if (tmp["lang"] == "en" and langlist[-1] == "en") or (tmp["lang"] != "en" and langlist[-1] != "en"):
+                            textlist[-1] += tmp["text"]
+                            continue
                     if tmp["lang"] == "en":
                         langlist.append(tmp["lang"])
                     else:
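The four added lines merge consecutive LangSegmenter segments of the same class (English vs. non-English) into the previous chunk instead of opening a new one, so mixed-language text is phonemized in longer runs; the same change is repeated in the get_phones_and_bert hunks below. A minimal standalone sketch of the merge rule, assuming a plain list of {"lang", "text"} dicts in place of LangSegmenter.getTexts(text) and omitting the per-language handling of the real else-branch:

    def merge_segments(segments):
        # segments: [{"lang": ..., "text": ...}, ...] as LangSegmenter would yield them
        langlist, textlist = [], []
        for seg in segments:
            if langlist and (seg["lang"] == "en") == (langlist[-1] == "en"):
                textlist[-1] += seg["text"]   # same class as the previous chunk: extend it
                continue
            langlist.append(seg["lang"])      # different class: open a new chunk
            textlist.append(seg["text"])
        return langlist, textlist

    print(merge_segments([
        {"lang": "en", "text": "GPT"},
        {"lang": "en", "text": "-SoVITS "},
        {"lang": "zh", "text": "语音合成"},
        {"lang": "ja", "text": "テスト"},
    ]))
    # -> (['en', 'zh'], ['GPT-SoVITS ', '语音合成テスト'])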
@@ -623,6 +623,10 @@ def get_phones_and_bert(text, language, version, final=False):
                 textlist.append(tmp["text"])
         else:
             for tmp in LangSegmenter.getTexts(text):
+                if langlist:
+                    if (tmp["lang"] == "en" and langlist[-1] == "en") or (tmp["lang"] != "en" and langlist[-1] != "en"):
+                        textlist[-1] += tmp["text"]
+                        continue
                 if tmp["lang"] == "en":
                     langlist.append(tmp["lang"])
                 else:
api.py (264 changed lines)
@@ -163,7 +163,7 @@ from transformers import AutoModelForMaskedLM, AutoTokenizer
 import numpy as np
 from feature_extractor import cnhubert
 from io import BytesIO
-from module.models import SynthesizerTrn, SynthesizerTrnV3
+from module.models import Generator, SynthesizerTrn, SynthesizerTrnV3
 from peft import LoraConfig, get_peft_model
 from AR.models.t2s_lightning_module import Text2SemanticLightningModule
 from text import cleaned_text_to_sequence
@@ -198,8 +198,38 @@ def is_full(*items):  # 任意一项为空返回False
     return True
 
 
-def init_bigvgan():
+bigvgan_model = hifigan_model = sv_cn_model = None
+def clean_hifigan_model():
+    global hifigan_model
+    if hifigan_model:
+        hifigan_model = hifigan_model.cpu()
+        hifigan_model = None
+        try:
+            torch.cuda.empty_cache()
+        except:
+            pass
+def clean_bigvgan_model():
     global bigvgan_model
+    if bigvgan_model:
+        bigvgan_model = bigvgan_model.cpu()
+        bigvgan_model = None
+        try:
+            torch.cuda.empty_cache()
+        except:
+            pass
+def clean_sv_cn_model():
+    global sv_cn_model
+    if sv_cn_model:
+        sv_cn_model.embedding_model = sv_cn_model.embedding_model.cpu()
+        sv_cn_model = None
+        try:
+            torch.cuda.empty_cache()
+        except:
+            pass
+
+
+def init_bigvgan():
+    global bigvgan_model, hifigan_model, sv_cn_model
     from BigVGAN import bigvgan
 
     bigvgan_model = bigvgan.BigVGAN.from_pretrained(
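api.py now keeps three optional models at module level (BigVGAN for v3, the v4 HiFi-GAN Generator, and the v2Pro speaker-verification model), and the clean_* helpers give later code a uniform way to drop whichever ones a newly loaded checkpoint does not need: move the module to CPU, release the reference, then try to empty the CUDA cache. A hedged restatement of that release pattern as one generic helper (illustrative only; the patch keeps a dedicated function per global):

    import torch

    def release_model(model):
        """Drop a torch module from GPU memory, mirroring the clean_* helpers above."""
        if model is not None:
            model = model.cpu()            # pull the weights off the GPU
            model = None                   # release the Python reference
            try:
                torch.cuda.empty_cache()   # hand cached blocks back to the allocator
            except Exception:
                pass
        return None                        # callers rebind their global to None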
@@ -209,20 +239,53 @@ def init_bigvgan():
     # remove weight norm in the model and set to eval mode
     bigvgan_model.remove_weight_norm()
     bigvgan_model = bigvgan_model.eval()
 
     if is_half == True:
         bigvgan_model = bigvgan_model.half().to(device)
     else:
         bigvgan_model = bigvgan_model.to(device)
 
 
-resample_transform_dict = {}
-
-def resample(audio_tensor, sr0):
+def init_hifigan():
+    global hifigan_model, bigvgan_model, sv_cn_model
+    hifigan_model = Generator(
+        initial_channel=100,
+        resblock="1",
+        resblock_kernel_sizes=[3, 7, 11],
+        resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+        upsample_rates=[10, 6, 2, 2, 2],
+        upsample_initial_channel=512,
+        upsample_kernel_sizes=[20, 12, 4, 4, 4],
+        gin_channels=0,
+        is_bias=True,
+    )
+    hifigan_model.eval()
+    hifigan_model.remove_weight_norm()
+    state_dict_g = torch.load(
+        "%s/GPT_SoVITS/pretrained_models/gsv-v4-pretrained/vocoder.pth" % (now_dir,), map_location="cpu", weights_only=False
+    )
+    print("loading vocoder", hifigan_model.load_state_dict(state_dict_g))
+    if is_half == True:
+        hifigan_model = hifigan_model.half().to(device)
+    else:
+        hifigan_model = hifigan_model.to(device)
+
+
+from sv import SV
+def init_sv_cn():
+    global hifigan_model, bigvgan_model, sv_cn_model
+    sv_cn_model = SV(device, is_half)
+
+
+resample_transform_dict = {}
+def resample(audio_tensor, sr0, sr1, device):
     global resample_transform_dict
-    if sr0 not in resample_transform_dict:
-        resample_transform_dict[sr0] = torchaudio.transforms.Resample(sr0, 24000).to(device)
-    return resample_transform_dict[sr0](audio_tensor)
+    key = "%s-%s-%s" % (sr0, sr1, str(device))
+    if key not in resample_transform_dict:
+        resample_transform_dict[key] = torchaudio.transforms.Resample(
+            sr0, sr1
+        ).to(device)
+    return resample_transform_dict[key](audio_tensor)
 
 
 from module.mel_processing import mel_spectrogram_torch
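resample previously memoized one torchaudio Resample kernel per source rate and always produced 24 kHz; it now takes the target rate and device explicitly and keys the cache on "sr0-sr1-device", which matters because the API resamples to 24 kHz for v3 reference mels, 32 kHz for v4 reference mels, and 16 kHz for v2Pro speaker-verification input. A small self-contained version of the same memoization, with illustrative names:

    import torch
    import torchaudio

    _resamplers = {}  # "srIn-srOut-device" -> torchaudio.transforms.Resample

    def resample_cached(audio, sr_in, sr_out, device):
        key = "%s-%s-%s" % (sr_in, sr_out, device)
        if key not in _resamplers:
            _resamplers[key] = torchaudio.transforms.Resample(sr_in, sr_out).to(device)
        return _resamplers[key](audio)

    # One second at 44.1 kHz comes back as one second at 32 kHz.
    wav = torch.zeros(1, 44100)
    print(resample_cached(wav, 44100, 32000, "cpu").shape)  # torch.Size([1, 32000])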
@@ -252,6 +315,19 @@ mel_fn = lambda x: mel_spectrogram_torch(
         "center": False,
     },
 )
+mel_fn_v4 = lambda x: mel_spectrogram_torch(
+    x,
+    **{
+        "n_fft": 1280,
+        "win_size": 1280,
+        "hop_size": 320,
+        "num_mels": 100,
+        "sampling_rate": 32000,
+        "fmin": 0,
+        "fmax": None,
+        "center": False,
+    },
+)
 
 
 sr_model = None
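mel_fn_v4 analyses 32 kHz audio with a 320-sample hop, i.e. 100 mel frames per second, and the v4 Generator built in init_hifigan upsamples each frame by 10 x 6 x 2 x 2 x 2 = 480 samples, which is why v4 output lands at 48 kHz later in this patch even though the mel input stays at 32 kHz. The arithmetic, spelled out in plain Python (not project code):

    sampling_rate = 32000                          # mel_fn_v4 input rate
    hop_size = 320
    frames_per_second = sampling_rate // hop_size  # 100 mel frames per second

    upsample_rates = [10, 6, 2, 2, 2]              # v4 Generator config from init_hifigan
    samples_per_frame = 1
    for r in upsample_rates:
        samples_per_frame *= r                     # 480 output samples per mel frame

    print(frames_per_second * samples_per_frame)   # 48000 -> native v4 output rate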
@@ -293,12 +369,18 @@ from process_ckpt import get_sovits_version_from_path_fast, load_sovits_new
 
 
 def get_sovits_weights(sovits_path):
-    path_sovits_v3 = "GPT_SoVITS/pretrained_models/s2Gv3.pth"
+    from config import pretrained_sovits_name
+    path_sovits_v3 = pretrained_sovits_name["v3"]
+    path_sovits_v4 = pretrained_sovits_name["v4"]
     is_exist_s2gv3 = os.path.exists(path_sovits_v3)
+    is_exist_s2gv4 = os.path.exists(path_sovits_v4)
 
     version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(sovits_path)
-    if if_lora_v3 == True and is_exist_s2gv3 == False:
-        logger.info("SoVITS V3 底模缺失,无法加载相应 LoRA 权重")
+    is_exist = is_exist_s2gv3 if model_version == "v3" else is_exist_s2gv4
+    path_sovits = path_sovits_v3 if model_version == "v3" else path_sovits_v4
+
+    if if_lora_v3 == True and is_exist == False:
+        logger.info("SoVITS %s 底模缺失,无法加载相应 LoRA 权重" % model_version)
 
     dict_s2 = load_sovits_new(sovits_path)
     hps = dict_s2["config"]
@@ -311,11 +393,13 @@ def get_sovits_weights(sovits_path):
     else:
         hps.model.version = "v2"
 
-    if model_version == "v3":
-        hps.model.version = "v3"
 
     model_params_dict = vars(hps.model)
-    if model_version != "v3":
+    if model_version not in {"v3", "v4"}:
+        if "Pro" in model_version:
+            hps.model.version = model_version
+            if sv_cn_model == None:
+                init_sv_cn()
         vq_model = SynthesizerTrn(
             hps.data.filter_length // 2 + 1,
             hps.train.segment_size // hps.data.hop_length,
@@ -323,13 +407,18 @@ def get_sovits_weights(sovits_path):
             **model_params_dict,
         )
     else:
+        hps.model.version = model_version
         vq_model = SynthesizerTrnV3(
             hps.data.filter_length // 2 + 1,
             hps.train.segment_size // hps.data.hop_length,
             n_speakers=hps.data.n_speakers,
             **model_params_dict,
         )
-        init_bigvgan()
+        if model_version == "v3":
+            init_bigvgan()
+        if model_version == "v4":
+            init_hifigan()
+
     model_version = hps.model.version
     logger.info(f"模型版本: {model_version}")
     if "pretrained" not in sovits_path:
@@ -345,7 +434,8 @@ def get_sovits_weights(sovits_path):
     if if_lora_v3 == False:
         vq_model.load_state_dict(dict_s2["weight"], strict=False)
     else:
-        vq_model.load_state_dict(load_sovits_new(path_sovits_v3)["weight"], strict=False)
+        path_sovits = path_sovits_v3 if model_version == "v3" else path_sovits_v4
+        vq_model.load_state_dict(load_sovits_new(path_sovits)["weight"], strict=False)
         lora_rank = dict_s2["lora_rank"]
         lora_config = LoraConfig(
             target_modules=["to_k", "to_q", "to_v", "to_out.0"],
@@ -479,6 +569,10 @@ def get_phones_and_bert(text, language, version, final=False):
                 textlist.append(tmp["text"])
         else:
             for tmp in LangSegmenter.getTexts(text):
+                if langlist:
+                    if (tmp["lang"] == "en" and langlist[-1] == "en") or (tmp["lang"] != "en" and langlist[-1] != "en"):
+                        textlist[-1] += tmp["text"]
+                        continue
                 if tmp["lang"] == "en":
                     langlist.append(tmp["lang"])
                 else:
@@ -533,23 +627,32 @@ class DictToAttrRecursive(dict):
         raise AttributeError(f"Attribute {item} not found")
 
 
-def get_spepc(hps, filename):
-    audio, _ = librosa.load(filename, sr=int(hps.data.sampling_rate))
-    audio = torch.FloatTensor(audio)
+def get_spepc(hps, filename, dtype, device, is_v2pro=False):
+    sr1 = int(hps.data.sampling_rate)
+    audio, sr0 = torchaudio.load(filename)
+    if sr0 != sr1:
+        audio = audio.to(device)
+        if audio.shape[0] == 2: audio = audio.mean(0).unsqueeze(0)
+        audio = resample(audio, sr0, sr1, device)
+    else:
+        audio = audio.to(device)
+        if audio.shape[0] == 2: audio = audio.mean(0).unsqueeze(0)
+
     maxx = audio.abs().max()
     if maxx > 1:
         audio /= min(2, maxx)
-    audio_norm = audio
-    audio_norm = audio_norm.unsqueeze(0)
     spec = spectrogram_torch(
-        audio_norm,
+        audio,
         hps.data.filter_length,
         hps.data.sampling_rate,
         hps.data.hop_length,
         hps.data.win_length,
         center=False,
     )
-    return spec
+    spec = spec.to(dtype)
+    if is_v2pro == True:
+        audio = resample(audio, sr1, 16000, device).to(dtype)
+    return spec, audio
 
 
 def pack_audio(audio_bytes, data, rate):
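get_spepc now loads the reference with torchaudio instead of librosa, downmixes stereo, resamples on the target device, casts the spectrogram to the caller's dtype, and always returns a (spec, audio) pair; with is_v2pro=True the audio half is additionally resampled to 16 kHz so it can feed sv_cn_model.compute_embedding3. A hedged sketch of the loading half only (an illustrative helper, not the api.py function):

    import torch
    import torchaudio

    def load_mono(path, target_sr, device="cpu"):
        """Load a wav, downmix stereo, match target_sr, and tame clipped input."""
        audio, sr = torchaudio.load(path)       # (channels, samples)
        audio = audio.to(device)
        if audio.shape[0] == 2:
            audio = audio.mean(0).unsqueeze(0)  # stereo -> mono, keep a channel dim
        if sr != target_sr:
            audio = torchaudio.functional.resample(audio, sr, target_sr)
        maxx = audio.abs().max()
        if maxx > 1:
            audio = audio / min(2, maxx)        # same soft normalisation as get_spepc
        return audio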
@@ -736,6 +839,16 @@ def get_tts_wav(
     t2s_model = infer_gpt.t2s_model
     max_sec = infer_gpt.max_sec
 
+    if version == "v3":
+        if sample_steps not in [4, 8, 16, 32, 64, 128]:
+            sample_steps = 32
+    elif version == "v4":
+        if sample_steps not in [4, 8, 16, 32]:
+            sample_steps = 8
+
+    if if_sr and version != "v3":
+        if_sr = False
+
     t0 = ttime()
     prompt_text = prompt_text.strip("\n")
     if prompt_text[-1] not in splits:
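sample_steps is now validated per version before synthesis: v3 accepts 4 through 128 flow-matching steps and falls back to 32, v4 accepts 4 through 32 and falls back to 8, and super-resolution (if_sr) is only honoured for v3; the old blanket clamp in handle() is removed further down. Restated as a small helper with the same sets (illustrative, not api.py code):

    ALLOWED_STEPS = {"v3": {4, 8, 16, 32, 64, 128}, "v4": {4, 8, 16, 32}}
    DEFAULT_STEPS = {"v3": 32, "v4": 8}

    def normalize_sample_steps(version, sample_steps):
        """Fall back to a per-version default when the requested step count is unsupported."""
        if version in ALLOWED_STEPS and sample_steps not in ALLOWED_STEPS[version]:
            return DEFAULT_STEPS[version]
        return sample_steps

    print(normalize_sample_steps("v4", 64))  # 8, since 64 steps is not a v4 option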
@@ -759,19 +872,29 @@ def get_tts_wav(
     prompt_semantic = codes[0, 0]
     prompt = prompt_semantic.unsqueeze(0).to(device)
 
-    if version != "v3":
+    is_v2pro = version in {"v2Pro", "v2ProPlus"}
+    if version not in {"v3", "v4"}:
         refers = []
+        if is_v2pro:
+            sv_emb = []
+            if sv_cn_model == None:
+                init_sv_cn()
         if inp_refs:
             for path in inp_refs:
-                try:
-                    refer = get_spepc(hps, path).to(dtype).to(device)
+                try:  #####这里加上提取sv的逻辑,要么一堆sv一堆refer,要么单个sv单个refer
+                    refer, audio_tensor = get_spepc(hps, path.name, dtype, device, is_v2pro)
                     refers.append(refer)
+                    if is_v2pro:
+                        sv_emb.append(sv_cn_model.compute_embedding3(audio_tensor))
                 except Exception as e:
                     logger.error(e)
         if len(refers) == 0:
-            refers = [get_spepc(hps, ref_wav_path).to(dtype).to(device)]
+            refers, audio_tensor = get_spepc(hps, ref_wav_path, dtype, device, is_v2pro)
+            refers = [refers]
+            if is_v2pro:
+                sv_emb = [sv_cn_model.compute_embedding3(audio_tensor)]
     else:
-        refer = get_spepc(hps, ref_wav_path).to(device).to(dtype)
+        refer, audio_tensor = get_spepc(hps, ref_wav_path, dtype, device)
 
     t1 = ttime()
     # os.environ['version'] = version
@@ -811,41 +934,48 @@ def get_tts_wav(
         pred_semantic = pred_semantic[:, -idx:].unsqueeze(0)
         t3 = ttime()
 
-        if version != "v3":
-            audio = (
-                vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers, speed=speed)
-                .detach()
-                .cpu()
-                .numpy()[0, 0]
-            )  ###试试重建不带上prompt部分
+        if version not in {"v3", "v4"}:
+            if is_v2pro:
+                audio = (
+                    vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers, speed=speed, sv_emb=sv_emb)
+                    .detach()
+                    .cpu()
+                    .numpy()[0, 0]
+                )
+            else:
+                audio = (
+                    vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers, speed=speed)
+                    .detach()
+                    .cpu()
+                    .numpy()[0, 0]
+                )
         else:
             phoneme_ids0 = torch.LongTensor(phones1).to(device).unsqueeze(0)
             phoneme_ids1 = torch.LongTensor(phones2).to(device).unsqueeze(0)
-            # print(11111111, phoneme_ids0, phoneme_ids1)
             fea_ref, ge = vq_model.decode_encp(prompt.unsqueeze(0), phoneme_ids0, refer)
             ref_audio, sr = torchaudio.load(ref_wav_path)
             ref_audio = ref_audio.to(device).float()
             if ref_audio.shape[0] == 2:
                 ref_audio = ref_audio.mean(0).unsqueeze(0)
-            if sr != 24000:
-                ref_audio = resample(ref_audio, sr)
-            # print("ref_audio",ref_audio.abs().mean())
-            mel2 = mel_fn(ref_audio)
+            tgt_sr = 24000 if version == "v3" else 32000
+            if sr != tgt_sr:
+                ref_audio = resample(ref_audio, sr, tgt_sr, device)
+            mel2 = mel_fn(ref_audio) if version == "v3" else mel_fn_v4(ref_audio)
             mel2 = norm_spec(mel2)
             T_min = min(mel2.shape[2], fea_ref.shape[2])
             mel2 = mel2[:, :, :T_min]
             fea_ref = fea_ref[:, :, :T_min]
-            if T_min > 468:
-                mel2 = mel2[:, :, -468:]
-                fea_ref = fea_ref[:, :, -468:]
-                T_min = 468
-            chunk_len = 934 - T_min
-            # print("fea_ref",fea_ref,fea_ref.shape)
-            # print("mel2",mel2)
+            Tref = 468 if version == "v3" else 500
+            Tchunk = 934 if version == "v3" else 1000
+            if T_min > Tref:
+                mel2 = mel2[:, :, -Tref:]
+                fea_ref = fea_ref[:, :, -Tref:]
+                T_min = Tref
+            chunk_len = Tchunk - T_min
             mel2 = mel2.to(dtype)
             fea_todo, ge = vq_model.decode_encp(pred_semantic, phoneme_ids1, refer, ge, speed)
-            # print("fea_todo",fea_todo)
-            # print("ge",ge.abs().mean())
             cfm_resss = []
             idx = 0
             while 1:
@@ -854,22 +984,24 @@ def get_tts_wav(
                     break
                 idx += chunk_len
                 fea = torch.cat([fea_ref, fea_todo_chunk], 2).transpose(2, 1)
-                # set_seed(123)
                 cfm_res = vq_model.cfm.inference(
                     fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps, inference_cfg_rate=0
                 )
                 cfm_res = cfm_res[:, :, mel2.shape[2] :]
                 mel2 = cfm_res[:, :, -T_min:]
-                # print("fea", fea)
-                # print("mel2in", mel2)
                 fea_ref = fea_todo_chunk[:, :, -T_min:]
                 cfm_resss.append(cfm_res)
-            cmf_res = torch.cat(cfm_resss, 2)
-            cmf_res = denorm_spec(cmf_res)
-            if bigvgan_model == None:
-                init_bigvgan()
+            cfm_res = torch.cat(cfm_resss, 2)
+            cfm_res = denorm_spec(cfm_res)
+            if version == "v3":
+                if bigvgan_model == None:
+                    init_bigvgan()
+            else:  # v4
+                if hifigan_model == None:
+                    init_hifigan()
+            vocoder_model = bigvgan_model if version == "v3" else hifigan_model
             with torch.inference_mode():
-                wav_gen = bigvgan_model(cmf_res)
+                wav_gen = vocoder_model(cfm_res)
                 audio = wav_gen[0][0].cpu().detach().numpy()
 
         max_audio = np.abs(audio).max()
@@ -880,7 +1012,13 @@ def get_tts_wav(
     audio_opt = np.concatenate(audio_opt, 0)
     t4 = ttime()
 
-    sr = hps.data.sampling_rate if version != "v3" else 24000
+    if version in {"v1", "v2", "v2Pro", "v2ProPlus"}:
+        sr = 32000
+    elif version == "v3":
+        sr = 24000
+    else:
+        sr = 48000  # v4
+
     if if_sr and sr == 24000:
         audio_opt = torch.from_numpy(audio_opt).float().to(device)
         audio_opt, sr = audio_sr(audio_opt.unsqueeze(0), sr)
|
|||||||
|
|
||||||
if not stream_mode == "normal":
|
if not stream_mode == "normal":
|
||||||
if media_type == "wav":
|
if media_type == "wav":
|
||||||
sr = 48000 if if_sr else 24000
|
if version in {"v1", "v2", "v2Pro", "v2ProPlus"}:
|
||||||
sr = hps.data.sampling_rate if version != "v3" else sr
|
sr = 32000
|
||||||
|
elif version == "v3":
|
||||||
|
sr = 48000 if if_sr else 24000
|
||||||
|
else:
|
||||||
|
sr = 48000 # v4
|
||||||
audio_bytes = pack_wav(audio_bytes, sr)
|
audio_bytes = pack_wav(audio_bytes, sr)
|
||||||
yield audio_bytes.getvalue()
|
yield audio_bytes.getvalue()
|
||||||
|
|
||||||
@@ -966,8 +1108,6 @@ def handle(
     if not default_refer.is_ready():
         return JSONResponse({"code": 400, "message": "未指定参考音频且接口无预设"}, status_code=400)
 
-    if sample_steps not in [4, 8, 16, 32]:
-        sample_steps = 32
 
     if cut_punc == None:
         text = cut_text(text, default_cut_punc)
|
Loading…
x
Reference in New Issue
Block a user