Merge c65b448304e97220851ff45def6de62176aa0278 into 5dfce9a3f0def7f1ee1e075df569b0b2d41df9e3

This commit is contained in:
刘悦 2024-08-16 19:16:23 +08:00 committed by GitHub
commit c257f28954
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 147 additions and 1 deletions

View File

@ -908,10 +908,53 @@ class TTS:
else:
# audio = [item for batch in audio for item in batch]
audio = sum(audio, [])
def ms_to_srt_time(ms):
    """Format a millisecond offset as an SRT timestamp ``HH:MM:SS,mmm``.

    Args:
        ms: Offset in milliseconds (int or float). Fractional values are
            rounded to the nearest millisecond; negative values are
            clamped to 0.

    Returns:
        str: The timestamp, e.g. ``"01:02:03,004"``. The hour field is
        zero-padded to two digits and simply grows wider past 99 hours.
    """
    # Round rather than truncate: the offset arrives as a float quotient, so
    # a value a hair below a whole millisecond (e.g. 999.9999) must not lose
    # a full millisecond. Clamp at 0 because divmod() on a negative number
    # would otherwise produce a malformed timestamp.
    total_ms = max(0, int(round(ms)))
    hours, remainder = divmod(total_ms, 3600000)
    minutes, remainder = divmod(remainder, 60000)
    seconds, milliseconds = divmod(remainder, 1000)
    return f"{hours:02d}:{minutes:02d}:{seconds:02d},{milliseconds:03d}"
import soundfile as sf
print("打印")
text = ""
with open(r'./srt/tts-out.txt', 'r',encoding='utf-8') as f:
text = f.read()
text_list = eval(text)
audio_samples = 0
srtlines = []
audio_opt = []
try:
num = 0
for x in audio:
ad = (np.concatenate([x], 0) * 32768).astype(np.int16)
srtline_begin=ms_to_srt_time(audio_samples*1000.0 / int(sr))
audio_samples += ad.size
srtline_end=ms_to_srt_time(audio_samples*1000.0 / int(sr))
audio_opt.append(ad)
srtlines.append(f"{len(audio_opt):02d}\n")
srtlines.append(srtline_begin+' --> '+srtline_end+"\n")
srtlines.append(text_list[num]+"\n\n")
num += 1
except Exception as e:
print(e)
audio = np.concatenate(audio, 0)
audio = (audio * 32768).astype(np.int16)
with open('./srt/tts-out.srt', 'w', encoding='utf-8') as f:
f.writelines(srtlines)
try:
if speed_factor != 1.0:
@ -920,6 +963,8 @@ class TTS:
print(f"Failed to change speed of audio: \n{e}")
return sr, audio
@ -943,4 +988,4 @@ def speed_change(input_audio:np.ndarray, speed:float, sr:int):
# 将管道输出解码为 NumPy 数组
processed_audio = np.frombuffer(out, np.int16)
return processed_audio
return processed_audio

View File

@ -60,6 +60,8 @@ class TextPreprocessor:
texts = self.pre_seg_text(text, lang, text_split_method)
result = []
print(i18n("############ 提取文本Bert特征 ############"))
with open('./srt/tts-out.txt', 'w', encoding='utf-8') as f:
f.write(str(texts))
for text in tqdm(texts):
phones, bert_features, norm_text = self.segment_and_extract_feature_for_text(text, lang)
if phones is None:

View File

@ -114,6 +114,7 @@ import soundfile as sf
from fastapi import FastAPI, Request, HTTPException, Response
from fastapi.responses import StreamingResponse, JSONResponse
from fastapi import FastAPI, UploadFile, File
from fastapi.staticfiles import StaticFiles
import uvicorn
from io import BytesIO
from tools.i18n.i18n import I18nAuto
@ -143,6 +144,8 @@ tts_config = TTS_Config(config_path)
tts_pipeline = TTS(tts_config)

APP = FastAPI()

# StaticFiles raises at construction time when its directory does not exist,
# which would abort server start-up on a checkout that lacks these folders.
# Create them first; exist_ok keeps this a no-op when they are already there.
import os

os.makedirs("./srt", exist_ok=True)
os.makedirs("./audio", exist_ok=True)

# Expose the ./srt and ./audio directories as static files under /srt and
# /audio respectively.
APP.mount("/srt", StaticFiles(directory="./srt"), name="srt")
APP.mount("/audio", StaticFiles(directory="./audio"), name="audio")
class TTS_Request(BaseModel):
text: str = None
text_lang: str = None
@ -329,7 +332,55 @@ async def tts_handle(req:dict):
return JSONResponse(status_code=400, content={"message": f"tts failed", "Exception": str(e)})
async def tts_handle_srt(req:dict,request):
    """
    Text to speech handler that answers with links instead of audio bytes.

    Validates ``req``, drives the TTS pipeline to its first result and returns
    a JSON document holding absolute URLs of the subtitle file
    (``/srt/tts-out.srt``) and the audio file (``/audio/audio.wav``).
    NOTE(review): this handler writes neither file itself; presumably
    ``tts_pipeline.run`` produces them as a side effect -- confirm.

    Args:
        req (dict):
            {
                "text": "",                   # str.(required) text to be synthesized
                "text_lang": "",              # str.(required) language of the text to be synthesized
                "ref_audio_path": "",         # str.(required) reference audio path
                "prompt_text": "",            # str.(optional) prompt text for the reference audio
                "prompt_lang": "",            # str.(required) language of the prompt text for the reference audio
                "top_k": 5,                   # int. top k sampling
                "top_p": 1,                   # float. top p sampling
                "temperature": 1,             # float. temperature for sampling
                "text_split_method": "cut5",  # str. text split method, see text_segmentation_method.py for details.
                "batch_size": 1,              # int. batch size for inference
                "batch_threshold": 0.75,      # float. threshold for batch splitting.
                "split_bucket": True,         # bool. whether to split the batch into multiple buckets.
                "speed_factor": 1.0,          # float. control the speed of the synthesized audio.
                "fragment_interval": 0.3,     # float. to control the interval of the audio fragment.
                "seed": -1,                   # int. random seed for reproducibility.
                "media_type": "wav",          # str. media type of the output audio, support "wav", "raw", "ogg", "aac".
                "streaming_mode": False,      # bool. whether to return a streaming response.
                "parallel_infer": True,       # bool.(optional) whether to use parallel inference.
                "repetition_penalty": 1.35    # float.(optional) repetition penalty for T2S model.
            }
        request: the incoming HTTP request; only ``request.url.hostname`` and
            ``request.url.port`` are read, to build the absolute links.
    returns:
        JSONResponse: ``{"code": "200", "srt": <url>, "audio": <url>}`` on
        success, the response produced by ``check_params`` when validation
        fails, or a 400 response carrying the exception text when the
        pipeline raises.
    """
    check_res = check_params(req)
    if check_res is not None:
        return check_res

    try:
        tts_generator = tts_pipeline.run(req)
        # Only the first item is pulled from the generator; its value is not
        # needed here because the response carries links, not audio bytes.
        next(tts_generator)

        host = request.url.hostname
        port = request.url.port
        # ``port`` is None when the request URL has no explicit port; leave it
        # out in that case instead of emitting "http://host:None/...".
        base_url = f"http://{host}" if port is None else f"http://{host}:{port}"
        return JSONResponse({
            "code": "200",
            "srt": f"{base_url}/srt/tts-out.srt",
            "audio": f"{base_url}/audio/audio.wav",
        })
    except Exception as e:
        return JSONResponse(status_code=400, content={"message": "tts failed", "Exception": str(e)})
@ -339,7 +390,55 @@ async def control(command: str = None):
return JSONResponse(status_code=400, content={"message": "command is required"})
handle_control(command)
@APP.get("/srt")
async def tts_get_endpoint_srt(request: Request,
                        text: str = None,
                        text_lang: str = None,
                        ref_audio_path: str = None,
                        prompt_lang: str = None,
                        prompt_text: str = "",
                        top_k:int = 5,
                        top_p:float = 1,
                        temperature:float = 1,
                        text_split_method:str = "cut5",
                        batch_size:int = 10,
                        batch_threshold:float = 0.75,
                        split_bucket:bool = True,
                        speed_factor:float = 1.0,
                        fragment_interval:float = 0.3,
                        seed:int = -1,
                        media_type:str = "wav",
                        streaming_mode:bool = False,
                        parallel_infer:bool = True,
                        repetition_penalty:float = 1.35
                        ):
    """GET variant of the ``/srt`` endpoint.

    Packs the query parameters into the request dict and delegates to
    ``tts_handle_srt``, returning whatever that handler returns. The two
    language codes are lower-cased when present.
    """
    req = {
        "text": text,
        # Both language parameters default to None. Guard the lower() call so
        # an omitted parameter is passed through as None and left to the
        # handler's validation, instead of raising AttributeError here.
        "text_lang": text_lang.lower() if text_lang is not None else None,
        "ref_audio_path": ref_audio_path,
        "prompt_text": prompt_text,
        "prompt_lang": prompt_lang.lower() if prompt_lang is not None else None,
        "top_k": top_k,
        "top_p": top_p,
        "temperature": temperature,
        "text_split_method": text_split_method,
        "batch_size":int(batch_size),
        "batch_threshold":float(batch_threshold),
        "speed_factor":float(speed_factor),
        "split_bucket":split_bucket,
        "fragment_interval":fragment_interval,
        "seed":seed,
        "media_type":media_type,
        "streaming_mode":streaming_mode,
        "parallel_infer":parallel_infer,
        "repetition_penalty":float(repetition_penalty)
    }
    return await tts_handle_srt(req,request)
@APP.post("/srt")
async def tts_post_endpoint_srt(request: TTS_Request,req1: Request):
    """POST variant of the ``/srt`` endpoint.

    ``request`` is the ``TTS_Request`` model instance and ``req1`` the raw
    HTTP request. The model is converted to a plain dict and both are handed
    to ``tts_handle_srt``, whose response is returned unchanged.
    """
    return await tts_handle_srt(request.dict(), req1)
@APP.get("/tts")
async def tts_get_endpoint(