Merge 27664703d2fb3c86504d8168ae79639b784c56f7 into b7a904a67153170d334fdc0d7fbae220ee21f0e9

This commit is contained in:
Jin.W 2025-01-20 17:08:08 +08:00 committed by GitHub
commit d174fefdc6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 93 additions and 21 deletions

View File

@ -550,6 +550,7 @@ class TTS:
all_phones_len_list = [] all_phones_len_list = []
all_bert_features_list = [] all_bert_features_list = []
norm_text_batch = [] norm_text_batch = []
origin_text_batch = []
all_bert_max_len = 0 all_bert_max_len = 0
all_phones_max_len = 0 all_phones_max_len = 0
for item in item_list: for item in item_list:
@ -575,6 +576,7 @@ class TTS:
all_phones_len_list.append(all_phones.shape[-1]) all_phones_len_list.append(all_phones.shape[-1])
all_bert_features_list.append(all_bert_features) all_bert_features_list.append(all_bert_features)
norm_text_batch.append(item["norm_text"]) norm_text_batch.append(item["norm_text"])
origin_text_batch.append(item["origin_text"])
phones_batch = phones_list phones_batch = phones_list
all_phones_batch = all_phones_list all_phones_batch = all_phones_list
@ -606,6 +608,7 @@ class TTS:
"all_phones_len": torch.LongTensor(all_phones_len_list).to(device), "all_phones_len": torch.LongTensor(all_phones_len_list).to(device),
"all_bert_features": all_bert_features_batch, "all_bert_features": all_bert_features_batch,
"norm_text": norm_text_batch, "norm_text": norm_text_batch,
"origin_text": origin_text_batch,
"max_len": max_len, "max_len": max_len,
} }
_data.append(batch) _data.append(batch)
@ -658,6 +661,7 @@ class TTS:
"batch_threshold": 0.75, # float. threshold for batch splitting. "batch_threshold": 0.75, # float. threshold for batch splitting.
"split_bucket: True, # bool. whether to split the batch into multiple buckets. "split_bucket: True, # bool. whether to split the batch into multiple buckets.
"return_fragment": False, # bool. step by step return the audio fragment. "return_fragment": False, # bool. step by step return the audio fragment.
"return_with_srt": "", # str. return with or without("") subtitles, using "orig"inal or "norm"alized text
"speed_factor":1.0, # float. control the speed of the synthesized audio. "speed_factor":1.0, # float. control the speed of the synthesized audio.
"fragment_interval":0.3, # float. to control the interval of the audio fragment. "fragment_interval":0.3, # float. to control the interval of the audio fragment.
"seed": -1, # int. random seed for reproducibility. "seed": -1, # int. random seed for reproducibility.
@ -685,6 +689,7 @@ class TTS:
split_bucket = inputs.get("split_bucket", True) split_bucket = inputs.get("split_bucket", True)
return_fragment = inputs.get("return_fragment", False) return_fragment = inputs.get("return_fragment", False)
fragment_interval = inputs.get("fragment_interval", 0.3) fragment_interval = inputs.get("fragment_interval", 0.3)
return_with_srt = inputs.get("return_with_srt", "")
seed = inputs.get("seed", -1) seed = inputs.get("seed", -1)
seed = -1 if seed in ["", None] else seed seed = -1 if seed in ["", None] else seed
actual_seed = set_seed(seed) actual_seed = set_seed(seed)
@ -704,6 +709,9 @@ class TTS:
split_bucket = False split_bucket = False
print(i18n("分段返回模式不支持分桶处理,已自动关闭分桶处理")) print(i18n("分段返回模式不支持分桶处理,已自动关闭分桶处理"))
ret_width = 3 if return_with_srt else 2 # return (sr, audio, srt) or (sr, audio)
srt_text = "norm_text" if return_with_srt.startswith("norm") else "origin_text"
if split_bucket and speed_factor==1.0: if split_bucket and speed_factor==1.0:
print(i18n("分桶处理模式已开启")) print(i18n("分桶处理模式已开启"))
elif speed_factor!=1.0: elif speed_factor!=1.0:
@ -773,8 +781,7 @@ class TTS:
if not return_fragment: if not return_fragment:
data = self.text_preprocessor.preprocess(text, text_lang, text_split_method, self.configs.version) data = self.text_preprocessor.preprocess(text, text_lang, text_split_method, self.configs.version)
if len(data) == 0: if len(data) == 0:
yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate), yield self.audio_failure()[:ret_width]
dtype=np.int16)
return return
batch_index_list:list = None batch_index_list:list = None
@ -806,6 +813,7 @@ class TTS:
"phones": phones, "phones": phones,
"bert_features": bert_features, "bert_features": bert_features,
"norm_text": norm_text, "norm_text": norm_text,
"origin_text": text,
} }
batch_data.append(res) batch_data.append(res)
if len(batch_data) == 0: if len(batch_data) == 0:
@ -841,10 +849,11 @@ class TTS:
all_phoneme_ids:torch.LongTensor = item["all_phones"] all_phoneme_ids:torch.LongTensor = item["all_phones"]
all_phoneme_lens:torch.LongTensor = item["all_phones_len"] all_phoneme_lens:torch.LongTensor = item["all_phones_len"]
all_bert_features:torch.LongTensor = item["all_bert_features"] all_bert_features:torch.LongTensor = item["all_bert_features"]
norm_text:str = item["norm_text"] # norm_text:List[str] = item["norm_text"]
# origin_text:List[str] = item["origin_text"]
max_len = item["max_len"] max_len = item["max_len"]
print(i18n("前端处理后的文本(每句):"), norm_text) print(i18n("前端处理后的文本(每批):"), item["norm_text"])
if no_prompt_text : if no_prompt_text :
prompt = None prompt = None
else: else:
@ -915,39 +924,38 @@ class TTS:
if return_fragment: if return_fragment:
print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t4 - t3, t5 - t4)) print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t4 - t3, t5 - t4))
yield self.audio_postprocess([batch_audio_fragment], yield self.audio_postprocess([batch_audio_fragment],
[item[srt_text]],
self.configs.sampling_rate, self.configs.sampling_rate,
None, None,
speed_factor, speed_factor,
False, False,
fragment_interval fragment_interval
) )[:ret_width]
else: else:
audio.append(batch_audio_fragment) audio.append(batch_audio_fragment)
if self.stop_flag: if self.stop_flag:
yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate), yield self.audio_failure()[:ret_width]
dtype=np.int16)
return return
if not return_fragment: if not return_fragment:
print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t_34, t_45)) print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t_34, t_45))
if len(audio) == 0: if len(audio) == 0:
yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate), yield self.audio_failure()[:ret_width]
dtype=np.int16)
return return
yield self.audio_postprocess(audio, yield self.audio_postprocess(audio,
[v[srt_text] for v in data],
self.configs.sampling_rate, self.configs.sampling_rate,
batch_index_list, batch_index_list,
speed_factor, speed_factor,
split_bucket, split_bucket,
fragment_interval fragment_interval
) )[:ret_width]
except Exception as e: except Exception as e:
traceback.print_exc() traceback.print_exc()
# 必须返回一个空音频, 否则会导致显存不释放。 # 必须返回一个空音频, 否则会导致显存不释放。
yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate), yield self.audio_failure()[:ret_width]
dtype=np.int16)
# 重置模型, 否则会导致显存释放不完全。 # 重置模型, 否则会导致显存释放不完全。
del self.t2s_model del self.t2s_model
del self.vits_model del self.vits_model
@ -969,14 +977,18 @@ class TTS:
except: except:
pass pass
def audio_failure(self):
return self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate), dtype=np.int16), []
def audio_postprocess(self, def audio_postprocess(self,
audio:List[torch.Tensor], audio:List[torch.Tensor],
texts:List[List[str]],
sr:int, sr:int,
batch_index_list:list=None, batch_index_list:list=None,
speed_factor:float=1.0, speed_factor:float=1.0,
split_bucket:bool=True, split_bucket:bool=True,
fragment_interval:float=0.3 fragment_interval:float=0.3
)->Tuple[int, np.ndarray]: )->Tuple[int, np.ndarray, List]:
zero_wav = torch.zeros( zero_wav = torch.zeros(
int(self.configs.sampling_rate * fragment_interval), int(self.configs.sampling_rate * fragment_interval),
dtype=self.precision, dtype=self.precision,
@ -993,10 +1005,16 @@ class TTS:
if split_bucket: if split_bucket:
audio = self.recovery_order(audio, batch_index_list) audio = self.recovery_order(audio, batch_index_list)
texts = self.recovery_order(texts, batch_index_list)
else: else:
# audio = [item for batch in audio for item in batch] # audio = [item for batch in audio for item in batch]
audio = sum(audio, []) audio = sum(audio, [])
texts = sum(texts, [])
# 按顺序计算每段语音的起止时间,并与文字一一对应,用于生成字幕
from itertools import accumulate
stamps = [0.0] + [x/sr for x in accumulate([v.size for v in audio])]
srts = list(zip(stamps[:-1], stamps[1:], texts)) # time start, end, text
audio = np.concatenate(audio, 0) audio = np.concatenate(audio, 0)
audio = (audio * 32768).astype(np.int16) audio = (audio * 32768).astype(np.int16)
@ -1007,7 +1025,7 @@ class TTS:
# except Exception as e: # except Exception as e:
# print(f"Failed to change speed of audio: \n{e}") # print(f"Failed to change speed of audio: \n{e}")
return sr, audio return sr, audio, srts

View File

@ -69,6 +69,7 @@ class TextPreprocessor:
"phones": phones, "phones": phones,
"bert_features": bert_features, "bert_features": bert_features,
"norm_text": norm_text, "norm_text": norm_text,
"origin_text": text,
} }
result.append(res) result.append(res)
return result return result

View File

@ -36,6 +36,7 @@ POST:
"split_bucket: True, # bool. whether to split the batch into multiple buckets. "split_bucket: True, # bool. whether to split the batch into multiple buckets.
"speed_factor":1.0, # float. control the speed of the synthesized audio. "speed_factor":1.0, # float. control the speed of the synthesized audio.
"streaming_mode": False, # bool. whether to return a streaming response. "streaming_mode": False, # bool. whether to return a streaming response.
"with_srt_format": "", # str. ""(no srt) or "raw" or "srt", "lrc", "vtt", ... formats (not implemented yet)
"seed": -1, # int. random seed for reproducibility. "seed": -1, # int. random seed for reproducibility.
"parallel_infer": True, # bool. whether to use parallel inference. "parallel_infer": True, # bool. whether to use parallel inference.
"repetition_penalty": 1.35 # float. repetition penalty for T2S model. "repetition_penalty": 1.35 # float. repetition penalty for T2S model.
@ -98,7 +99,7 @@ RESP:
import os import os
import sys import sys
import traceback import traceback
from typing import Generator from typing import Generator, List, Union
now_dir = os.getcwd() now_dir = os.getcwd()
sys.path.append(now_dir) sys.path.append(now_dir)
@ -162,6 +163,7 @@ class TTS_Request(BaseModel):
seed:int = -1 seed:int = -1
media_type:str = "wav" media_type:str = "wav"
streaming_mode:bool = False streaming_mode:bool = False
with_srt_format:str = ""
parallel_infer:bool = True parallel_infer:bool = True
repetition_penalty:float = 1.35 repetition_penalty:float = 1.35
@ -211,7 +213,38 @@ def pack_audio(io_buffer:BytesIO, data:np.ndarray, rate:int, media_type:str):
io_buffer.seek(0) io_buffer.seek(0)
return io_buffer return io_buffer
def pack_srt(srt:List, fmt:str):
if fmt == "raw":
return srt
# TODO: support formats like "srt", "lrc", "vtt", ...
return srt
def load_base64_audio(audio):
import base64
if isinstance(audio, (bytes, bytearray)):
audio = bytes(audio)
elif hasattr(audio, 'read'): # file-like obj
audio = audio.read()
else: # path-like
audio = open(audio, 'rb').read()
return base64.b64encode(audio).decode('ascii')
_base64_audio_cache = {}
def save_base64_audio(b64str:str):
import filetype, base64, uuid
global _base64_audio_cache
if b64str in _base64_audio_cache:
return _base64_audio_cache[b64str]
savedir = 'TEMP/upload'
data = base64.b64decode(b64str)
ft = filetype.guess(data)
ext = f'.{ft.extension}' if ft else ''
os.makedirs(savedir, exist_ok=True)
saveto = f'{savedir}/{uuid.uuid1()}{ext}'
with open(saveto, 'wb') as outf:
outf.write(data)
_base64_audio_cache[b64str] = saveto
return saveto
# from https://huggingface.co/spaces/coqui/voice-chat-with-mistral/blob/main/app.py # from https://huggingface.co/spaces/coqui/voice-chat-with-mistral/blob/main/app.py
def wave_header_chunk(frame_input=b"", channels=1, sample_width=2, sample_rate=32000): def wave_header_chunk(frame_input=b"", channels=1, sample_width=2, sample_rate=32000):
@ -277,7 +310,7 @@ async def tts_handle(req:dict):
{ {
"text": "", # str.(required) text to be synthesized "text": "", # str.(required) text to be synthesized
"text_lang: "", # str.(required) language of the text to be synthesized "text_lang: "", # str.(required) language of the text to be synthesized
"ref_audio_path": "", # str.(required) reference audio path "ref_audio_path": "", # str.(required) reference audio path ; allow data of format base64:xxxxxx
"aux_ref_audio_paths": [], # list.(optional) auxiliary reference audio paths for multi-speaker synthesis "aux_ref_audio_paths": [], # list.(optional) auxiliary reference audio paths for multi-speaker synthesis
"prompt_text": "", # str.(optional) prompt text for the reference audio "prompt_text": "", # str.(optional) prompt text for the reference audio
"prompt_lang": "", # str.(required) language of the prompt text for the reference audio "prompt_lang": "", # str.(required) language of the prompt text for the reference audio
@ -293,6 +326,7 @@ async def tts_handle(req:dict):
"seed": -1, # int. random seed for reproducibility. "seed": -1, # int. random seed for reproducibility.
"media_type": "wav", # str. media type of the output audio, support "wav", "raw", "ogg", "aac". "media_type": "wav", # str. media type of the output audio, support "wav", "raw", "ogg", "aac".
"streaming_mode": False, # bool. whether to return a streaming response. "streaming_mode": False, # bool. whether to return a streaming response.
"with_srt_format": "", # str. ""(no srt) or "raw" or "srt", "lrc", "vtt", ... formats (not implemented yet)
"parallel_infer": True, # bool.(optional) whether to use parallel inference. "parallel_infer": True, # bool.(optional) whether to use parallel inference.
"repetition_penalty": 1.35 # float.(optional) repetition penalty for T2S model. "repetition_penalty": 1.35 # float.(optional) repetition penalty for T2S model.
} }
@ -303,6 +337,10 @@ async def tts_handle(req:dict):
streaming_mode = req.get("streaming_mode", False) streaming_mode = req.get("streaming_mode", False)
return_fragment = req.get("return_fragment", False) return_fragment = req.get("return_fragment", False)
media_type = req.get("media_type", "wav") media_type = req.get("media_type", "wav")
with_srt_format = req.get("with_srt_format", "")
ref_audio_path = req.get("ref_audio_path", "")
if ref_audio_path.startswith("base64:"):
req['ref_audio_path'] = ref_audio_path = save_base64_audio(ref_audio_path[len("base64:"):])
check_res = check_params(req) check_res = check_params(req)
if check_res is not None: if check_res is not None:
@ -311,6 +349,9 @@ async def tts_handle(req:dict):
if streaming_mode or return_fragment: if streaming_mode or return_fragment:
req["return_fragment"] = True req["return_fragment"] = True
if streaming_mode: with_srt_format = "" # streaming not support srt
req["return_with_srt"] = "orig" if with_srt_format else ""
try: try:
tts_generator=tts_pipeline.run(req) tts_generator=tts_pipeline.run(req)
@ -324,6 +365,16 @@ async def tts_handle(req:dict):
# _media_type = f"audio/{media_type}" if not (streaming_mode and media_type in ["wav", "raw"]) else f"audio/x-{media_type}" # _media_type = f"audio/{media_type}" if not (streaming_mode and media_type in ["wav", "raw"]) else f"audio/x-{media_type}"
return StreamingResponse(streaming_generator(tts_generator, media_type, ), media_type=f"audio/{media_type}") return StreamingResponse(streaming_generator(tts_generator, media_type, ), media_type=f"audio/{media_type}")
elif with_srt_format:
output = []
for sr, audio_data, srt_data in tts_generator:
audio_data = pack_audio(BytesIO(), audio_data, sr, media_type).getvalue()
output.append({
"audio": load_base64_audio(audio_data), "media_type": f"audio/{media_type}",
"srt": pack_srt(srt_data, with_srt_format), "srt_fmt": with_srt_format,
})
return { "message":"succeed", "output":output } # Jsonresponse(status_code=200, content=...)
else: else:
sr, audio_data = next(tts_generator) sr, audio_data = next(tts_generator)
audio_data = pack_audio(BytesIO(), audio_data, sr, media_type).getvalue() audio_data = pack_audio(BytesIO(), audio_data, sr, media_type).getvalue()
@ -364,6 +415,7 @@ async def tts_get_endpoint(
seed:int = -1, seed:int = -1,
media_type:str = "wav", media_type:str = "wav",
streaming_mode:bool = False, streaming_mode:bool = False,
with_srt_format:str = "",
parallel_infer:bool = True, parallel_infer:bool = True,
repetition_penalty:float = 1.35 repetition_penalty:float = 1.35
): ):

View File

@ -34,3 +34,4 @@ opencc; sys_platform != 'linux'
opencc==1.1.1; sys_platform == 'linux' opencc==1.1.1; sys_platform == 'linux'
python_mecab_ko; sys_platform != 'win32' python_mecab_ko; sys_platform != 'win32'
fastapi<0.112.2 fastapi<0.112.2
filetype