mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-10-06 14:40:00 +08:00
Generate & return subtitles with the audio.
生成与音频同步的字幕并返回: - TTS_infer_pack/TTS.py 生成与音频对应的字幕信息 - api_v2.py /tts 接口可用JSON同时返回生成的音频(转为base64)和字幕 - 通过参数控制是否生成字幕,默认关闭,不影响其他模块
This commit is contained in:
parent
0a17694ede
commit
27664703d2
@ -550,6 +550,7 @@ class TTS:
|
||||
all_phones_len_list = []
|
||||
all_bert_features_list = []
|
||||
norm_text_batch = []
|
||||
origin_text_batch = []
|
||||
all_bert_max_len = 0
|
||||
all_phones_max_len = 0
|
||||
for item in item_list:
|
||||
@ -575,6 +576,7 @@ class TTS:
|
||||
all_phones_len_list.append(all_phones.shape[-1])
|
||||
all_bert_features_list.append(all_bert_features)
|
||||
norm_text_batch.append(item["norm_text"])
|
||||
origin_text_batch.append(item["origin_text"])
|
||||
|
||||
phones_batch = phones_list
|
||||
all_phones_batch = all_phones_list
|
||||
@ -606,6 +608,7 @@ class TTS:
|
||||
"all_phones_len": torch.LongTensor(all_phones_len_list).to(device),
|
||||
"all_bert_features": all_bert_features_batch,
|
||||
"norm_text": norm_text_batch,
|
||||
"origin_text": origin_text_batch,
|
||||
"max_len": max_len,
|
||||
}
|
||||
_data.append(batch)
|
||||
@ -658,6 +661,7 @@ class TTS:
|
||||
"batch_threshold": 0.75, # float. threshold for batch splitting.
|
||||
"split_bucket: True, # bool. whether to split the batch into multiple buckets.
|
||||
"return_fragment": False, # bool. step by step return the audio fragment.
|
||||
"return_with_srt": "", # str. return with or without("") subtitles, using "orig"inal or "norm"alized text
|
||||
"speed_factor":1.0, # float. control the speed of the synthesized audio.
|
||||
"fragment_interval":0.3, # float. to control the interval of the audio fragment.
|
||||
"seed": -1, # int. random seed for reproducibility.
|
||||
@ -685,6 +689,7 @@ class TTS:
|
||||
split_bucket = inputs.get("split_bucket", True)
|
||||
return_fragment = inputs.get("return_fragment", False)
|
||||
fragment_interval = inputs.get("fragment_interval", 0.3)
|
||||
return_with_srt = inputs.get("return_with_srt", "")
|
||||
seed = inputs.get("seed", -1)
|
||||
seed = -1 if seed in ["", None] else seed
|
||||
actual_seed = set_seed(seed)
|
||||
@ -704,6 +709,9 @@ class TTS:
|
||||
split_bucket = False
|
||||
print(i18n("分段返回模式不支持分桶处理,已自动关闭分桶处理"))
|
||||
|
||||
ret_width = 3 if return_with_srt else 2 # return (sr, audio, srt) or (sr, audio)
|
||||
srt_text = "norm_text" if return_with_srt.startswith("norm") else "origin_text"
|
||||
|
||||
if split_bucket and speed_factor==1.0:
|
||||
print(i18n("分桶处理模式已开启"))
|
||||
elif speed_factor!=1.0:
|
||||
@ -773,8 +781,7 @@ class TTS:
|
||||
if not return_fragment:
|
||||
data = self.text_preprocessor.preprocess(text, text_lang, text_split_method, self.configs.version)
|
||||
if len(data) == 0:
|
||||
yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate),
|
||||
dtype=np.int16)
|
||||
yield self.audio_failure()[:ret_width]
|
||||
return
|
||||
|
||||
batch_index_list:list = None
|
||||
@ -806,6 +813,7 @@ class TTS:
|
||||
"phones": phones,
|
||||
"bert_features": bert_features,
|
||||
"norm_text": norm_text,
|
||||
"origin_text": text,
|
||||
}
|
||||
batch_data.append(res)
|
||||
if len(batch_data) == 0:
|
||||
@ -841,10 +849,11 @@ class TTS:
|
||||
all_phoneme_ids:torch.LongTensor = item["all_phones"]
|
||||
all_phoneme_lens:torch.LongTensor = item["all_phones_len"]
|
||||
all_bert_features:torch.LongTensor = item["all_bert_features"]
|
||||
norm_text:str = item["norm_text"]
|
||||
# norm_text:List[str] = item["norm_text"]
|
||||
# origin_text:List[str] = item["origin_text"]
|
||||
max_len = item["max_len"]
|
||||
|
||||
print(i18n("前端处理后的文本(每句):"), norm_text)
|
||||
print(i18n("前端处理后的文本(每批):"), item["norm_text"])
|
||||
if no_prompt_text :
|
||||
prompt = None
|
||||
else:
|
||||
@ -915,39 +924,38 @@ class TTS:
|
||||
if return_fragment:
|
||||
print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t4 - t3, t5 - t4))
|
||||
yield self.audio_postprocess([batch_audio_fragment],
|
||||
[item[srt_text]],
|
||||
self.configs.sampling_rate,
|
||||
None,
|
||||
speed_factor,
|
||||
False,
|
||||
fragment_interval
|
||||
)
|
||||
)[:ret_width]
|
||||
else:
|
||||
audio.append(batch_audio_fragment)
|
||||
|
||||
if self.stop_flag:
|
||||
yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate),
|
||||
dtype=np.int16)
|
||||
yield self.audio_failure()[:ret_width]
|
||||
return
|
||||
|
||||
if not return_fragment:
|
||||
print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t_34, t_45))
|
||||
if len(audio) == 0:
|
||||
yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate),
|
||||
dtype=np.int16)
|
||||
yield self.audio_failure()[:ret_width]
|
||||
return
|
||||
yield self.audio_postprocess(audio,
|
||||
[v[srt_text] for v in data],
|
||||
self.configs.sampling_rate,
|
||||
batch_index_list,
|
||||
speed_factor,
|
||||
split_bucket,
|
||||
fragment_interval
|
||||
)
|
||||
)[:ret_width]
|
||||
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
# 必须返回一个空音频, 否则会导致显存不释放。
|
||||
yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate),
|
||||
dtype=np.int16)
|
||||
yield self.audio_failure()[:ret_width]
|
||||
# 重置模型, 否则会导致显存释放不完全。
|
||||
del self.t2s_model
|
||||
del self.vits_model
|
||||
@ -969,14 +977,18 @@ class TTS:
|
||||
except:
|
||||
pass
|
||||
|
||||
def audio_failure(self):
|
||||
return self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate), dtype=np.int16), []
|
||||
|
||||
def audio_postprocess(self,
|
||||
audio:List[torch.Tensor],
|
||||
texts:List[List[str]],
|
||||
sr:int,
|
||||
batch_index_list:list=None,
|
||||
speed_factor:float=1.0,
|
||||
split_bucket:bool=True,
|
||||
fragment_interval:float=0.3
|
||||
)->Tuple[int, np.ndarray]:
|
||||
)->Tuple[int, np.ndarray, List]:
|
||||
zero_wav = torch.zeros(
|
||||
int(self.configs.sampling_rate * fragment_interval),
|
||||
dtype=self.precision,
|
||||
@ -993,10 +1005,16 @@ class TTS:
|
||||
|
||||
if split_bucket:
|
||||
audio = self.recovery_order(audio, batch_index_list)
|
||||
texts = self.recovery_order(texts, batch_index_list)
|
||||
else:
|
||||
# audio = [item for batch in audio for item in batch]
|
||||
audio = sum(audio, [])
|
||||
texts = sum(texts, [])
|
||||
|
||||
# 按顺序计算每段语音的起止时间,并与文字一一对应,用于生成字幕
|
||||
from itertools import accumulate
|
||||
stamps = [0.0] + [x/sr for x in accumulate([v.size for v in audio])]
|
||||
srts = list(zip(stamps[:-1], stamps[1:], texts)) # time start, end, text
|
||||
|
||||
audio = np.concatenate(audio, 0)
|
||||
audio = (audio * 32768).astype(np.int16)
|
||||
@ -1007,7 +1025,7 @@ class TTS:
|
||||
# except Exception as e:
|
||||
# print(f"Failed to change speed of audio: \n{e}")
|
||||
|
||||
return sr, audio
|
||||
return sr, audio, srts
|
||||
|
||||
|
||||
|
||||
|
@ -69,6 +69,7 @@ class TextPreprocessor:
|
||||
"phones": phones,
|
||||
"bert_features": bert_features,
|
||||
"norm_text": norm_text,
|
||||
"origin_text": text,
|
||||
}
|
||||
result.append(res)
|
||||
return result
|
||||
|
35
api_v2.py
35
api_v2.py
@ -36,6 +36,7 @@ POST:
|
||||
"split_bucket: True, # bool. whether to split the batch into multiple buckets.
|
||||
"speed_factor":1.0, # float. control the speed of the synthesized audio.
|
||||
"streaming_mode": False, # bool. whether to return a streaming response.
|
||||
"with_srt_format": "", # str. ""(no srt) or "raw" or "srt", "lrc", "vtt", ... formats (not implemented yet)
|
||||
"seed": -1, # int. random seed for reproducibility.
|
||||
"parallel_infer": True, # bool. whether to use parallel inference.
|
||||
"repetition_penalty": 1.35 # float. repetition penalty for T2S model.
|
||||
@ -98,7 +99,7 @@ RESP:
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
from typing import Generator
|
||||
from typing import Generator, List, Union
|
||||
|
||||
now_dir = os.getcwd()
|
||||
sys.path.append(now_dir)
|
||||
@ -162,6 +163,7 @@ class TTS_Request(BaseModel):
|
||||
seed:int = -1
|
||||
media_type:str = "wav"
|
||||
streaming_mode:bool = False
|
||||
with_srt_format:str = ""
|
||||
parallel_infer:bool = True
|
||||
repetition_penalty:float = 1.35
|
||||
|
||||
@ -211,6 +213,21 @@ def pack_audio(io_buffer:BytesIO, data:np.ndarray, rate:int, media_type:str):
|
||||
io_buffer.seek(0)
|
||||
return io_buffer
|
||||
|
||||
def pack_srt(srt:List, fmt:str):
|
||||
if fmt == "raw":
|
||||
return srt
|
||||
# TODO: support formats like "srt", "lrc", "vtt", ...
|
||||
return srt
|
||||
|
||||
def load_base64_audio(audio):
|
||||
import base64
|
||||
if isinstance(audio, (bytes, bytearray)):
|
||||
audio = bytes(audio)
|
||||
elif hasattr(audio, 'read'): # file-like obj
|
||||
audio = audio.read()
|
||||
else: # path-like
|
||||
audio = open(audio, 'rb').read()
|
||||
return base64.b64encode(audio).decode('ascii')
|
||||
|
||||
_base64_audio_cache = {}
|
||||
def save_base64_audio(b64str:str):
|
||||
@ -309,6 +326,7 @@ async def tts_handle(req:dict):
|
||||
"seed": -1, # int. random seed for reproducibility.
|
||||
"media_type": "wav", # str. media type of the output audio, support "wav", "raw", "ogg", "aac".
|
||||
"streaming_mode": False, # bool. whether to return a streaming response.
|
||||
"with_srt_format": "", # str. ""(no srt) or "raw" or "srt", "lrc", "vtt", ... formats (not implemented yet)
|
||||
"parallel_infer": True, # bool.(optional) whether to use parallel inference.
|
||||
"repetition_penalty": 1.35 # float.(optional) repetition penalty for T2S model.
|
||||
}
|
||||
@ -319,6 +337,7 @@ async def tts_handle(req:dict):
|
||||
streaming_mode = req.get("streaming_mode", False)
|
||||
return_fragment = req.get("return_fragment", False)
|
||||
media_type = req.get("media_type", "wav")
|
||||
with_srt_format = req.get("with_srt_format", "")
|
||||
ref_audio_path = req.get("ref_audio_path", "")
|
||||
if ref_audio_path.startswith("base64:"):
|
||||
req['ref_audio_path'] = ref_audio_path = save_base64_audio(ref_audio_path[len("base64:"):])
|
||||
@ -330,6 +349,9 @@ async def tts_handle(req:dict):
|
||||
if streaming_mode or return_fragment:
|
||||
req["return_fragment"] = True
|
||||
|
||||
if streaming_mode: with_srt_format = "" # streaming not support srt
|
||||
req["return_with_srt"] = "orig" if with_srt_format else ""
|
||||
|
||||
try:
|
||||
tts_generator=tts_pipeline.run(req)
|
||||
|
||||
@ -343,6 +365,16 @@ async def tts_handle(req:dict):
|
||||
# _media_type = f"audio/{media_type}" if not (streaming_mode and media_type in ["wav", "raw"]) else f"audio/x-{media_type}"
|
||||
return StreamingResponse(streaming_generator(tts_generator, media_type, ), media_type=f"audio/{media_type}")
|
||||
|
||||
elif with_srt_format:
|
||||
output = []
|
||||
for sr, audio_data, srt_data in tts_generator:
|
||||
audio_data = pack_audio(BytesIO(), audio_data, sr, media_type).getvalue()
|
||||
output.append({
|
||||
"audio": load_base64_audio(audio_data), "media_type": f"audio/{media_type}",
|
||||
"srt": pack_srt(srt_data, with_srt_format), "srt_fmt": with_srt_format,
|
||||
})
|
||||
return { "message":"succeed", "output":output } # Jsonresponse(status_code=200, content=...)
|
||||
|
||||
else:
|
||||
sr, audio_data = next(tts_generator)
|
||||
audio_data = pack_audio(BytesIO(), audio_data, sr, media_type).getvalue()
|
||||
@ -383,6 +415,7 @@ async def tts_get_endpoint(
|
||||
seed:int = -1,
|
||||
media_type:str = "wav",
|
||||
streaming_mode:bool = False,
|
||||
with_srt_format:str = "",
|
||||
parallel_infer:bool = True,
|
||||
repetition_penalty:float = 1.35
|
||||
):
|
||||
|
Loading…
x
Reference in New Issue
Block a user