mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-04-06 03:57:44 +08:00
Add files via upload
This commit is contained in:
parent
48da30c4af
commit
ea62d6e0cf
@ -115,7 +115,6 @@ vq_model.eval()
|
|||||||
print(vq_model.load_state_dict(dict_s2["weight"], strict=False))
|
print(vq_model.load_state_dict(dict_s2["weight"], strict=False))
|
||||||
hz = 50
|
hz = 50
|
||||||
max_sec = config["data"]["max_sec"]
|
max_sec = config["data"]["max_sec"]
|
||||||
# t2s_model = Text2SemanticLightningModule.load_from_checkpoint(checkpoint_path=gpt_path, config=config, map_location="cpu")#########todo
|
|
||||||
t2s_model = Text2SemanticLightningModule(config, "ojbk", is_train=False)
|
t2s_model = Text2SemanticLightningModule(config, "ojbk", is_train=False)
|
||||||
t2s_model.load_state_dict(dict_s1["weight"])
|
t2s_model.load_state_dict(dict_s1["weight"])
|
||||||
if is_half == True:
|
if is_half == True:
|
||||||
@ -149,13 +148,21 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language)
|
|||||||
t0 = ttime()
|
t0 = ttime()
|
||||||
prompt_text = prompt_text.strip("\n")
|
prompt_text = prompt_text.strip("\n")
|
||||||
prompt_language, text = prompt_language, text.strip("\n")
|
prompt_language, text = prompt_language, text.strip("\n")
|
||||||
|
zero_wav = np.zeros(
|
||||||
|
int(hps.data.sampling_rate * 0.3),
|
||||||
|
dtype=np.float16 if is_half == True else np.float32,
|
||||||
|
)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
wav16k, sr = librosa.load(ref_wav_path, sr=16000) # 派蒙
|
wav16k, sr = librosa.load(ref_wav_path, sr=16000)
|
||||||
wav16k = torch.from_numpy(wav16k)
|
wav16k = torch.from_numpy(wav16k)
|
||||||
|
zero_wav_torch = torch.from_numpy(zero_wav)
|
||||||
if is_half == True:
|
if is_half == True:
|
||||||
wav16k = wav16k.half().to(device)
|
wav16k = wav16k.half().to(device)
|
||||||
|
zero_wav_torch = zero_wav_torch.half().to(device)
|
||||||
else:
|
else:
|
||||||
wav16k = wav16k.to(device)
|
wav16k = wav16k.to(device)
|
||||||
|
zero_wav_torch = zero_wav_torch.to(device)
|
||||||
|
wav16k=torch.cat([wav16k,zero_wav_torch])
|
||||||
ssl_content = ssl_model.model(wav16k.unsqueeze(0))[
|
ssl_content = ssl_model.model(wav16k.unsqueeze(0))[
|
||||||
"last_hidden_state"
|
"last_hidden_state"
|
||||||
].transpose(
|
].transpose(
|
||||||
@ -170,10 +177,6 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language)
|
|||||||
phones1 = cleaned_text_to_sequence(phones1)
|
phones1 = cleaned_text_to_sequence(phones1)
|
||||||
texts = text.split("\n")
|
texts = text.split("\n")
|
||||||
audio_opt = []
|
audio_opt = []
|
||||||
zero_wav = np.zeros(
|
|
||||||
int(hps.data.sampling_rate * 0.3),
|
|
||||||
dtype=np.float16 if is_half == True else np.float32,
|
|
||||||
)
|
|
||||||
for text in texts:
|
for text in texts:
|
||||||
# 解决输入目标文本的空行导致报错的问题
|
# 解决输入目标文本的空行导致报错的问题
|
||||||
if (len(text.strip()) == 0):
|
if (len(text.strip()) == 0):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user