Review notes (text before the first "diff --git" line is ignored by patch tools):
- Adds SRT subtitle output to the TTS pipeline and /srt endpoints to api_v2.py.
- The patch was re-flowed from a whitespace-collapsed copy, so the indentation
  and blank context lines are reconstructed; apply with whitespace tolerance.
- The "index" blob hashes predate this revision and are informational only.
- Output file names are fixed (./srt/tts-out.*, ./audio/audio.wav), so
  concurrent requests overwrite each other's results.

diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py
index ca51915..a073196 100644
--- a/GPT_SoVITS/TTS_infer_pack/TTS.py
+++ b/GPT_SoVITS/TTS_infer_pack/TTS.py
@@ -908,10 +908,41 @@ class TTS:
         else:
             # audio = [item for batch in audio for item in batch]
             audio = sum(audio, [])
+
+        def ms_to_srt_time(ms):
+            """Format a millisecond offset as an SRT timestamp (HH:MM:SS,mmm)."""
+            hours, remainder = divmod(int(ms), 3600000)
+            minutes, remainder = divmod(remainder, 60000)
+            seconds, milliseconds = divmod(remainder, 1000)
+            return f"{hours:02d}:{minutes:02d}:{seconds:02d},{milliseconds:03d}"
+
+        # Subtitle generation is best-effort: a missing or malformed text dump
+        # must not abort synthesis. TextPreprocessor writes the segmented texts
+        # as a Python list literal; parse it with ast.literal_eval(), not eval(),
+        # so that the file's contents can never be executed as code.
+        import ast
+        import os
+
+        srt_cues = []  # (first_sample, end_sample, text), one per fragment
+        try:
+            with open('./srt/tts-out.txt', 'r', encoding='utf-8') as f:
+                text_list = ast.literal_eval(f.read())
+            consumed_samples = 0
+            # zip() stops at the shorter input, so a fragment/text count mismatch
+            # cannot raise IndexError or leave a cue without its text line.
+            for fragment, fragment_text in zip(audio, text_list):
+                first_sample = consumed_samples
+                consumed_samples += int(np.size(fragment))
+                srt_cues.append((first_sample, consumed_samples, str(fragment_text)))
+        except (OSError, ValueError, SyntaxError, TypeError) as e:
+            print(f"Failed to prepare subtitles: \n{e}")
+            srt_cues = []
 
 
         audio = np.concatenate(audio, 0)
         audio = (audio * 32768).astype(np.int16)
+        # Sample count before any speed change; used to rescale the cue times.
+        srt_source_samples = audio.size
 
         try:
             if speed_factor != 1.0:
@@ -920,6 +951,24 @@ class TTS:
             print(f"Failed to change speed of audio: \n{e}")
 
+        # Write the subtitles last so their timing matches the returned audio:
+        # a speed change alters the duration, so rescale by the length ratio.
+        try:
+            time_scale = audio.size / srt_source_samples if srt_source_samples else 1.0
+            ms_per_sample = 1000.0 * time_scale / int(sr)
+            srtlines = []
+            for index, (first_sample, end_sample, cue_text) in enumerate(srt_cues, start=1):
+                srtline_begin = ms_to_srt_time(first_sample * ms_per_sample)
+                srtline_end = ms_to_srt_time(end_sample * ms_per_sample)
+                srtlines.append(f"{index:02d}\n")
+                srtlines.append(srtline_begin + ' --> ' + srtline_end + "\n")
+                srtlines.append(cue_text + "\n\n")
+            os.makedirs('./srt', exist_ok=True)
+            with open('./srt/tts-out.srt', 'w', encoding='utf-8') as f:
+                f.writelines(srtlines)
+        except OSError as e:
+            print(f"Failed to write subtitles: \n{e}")
+
         return sr, audio
 
 
 
@@ -943,4 +992,4 @@ def speed_change(input_audio:np.ndarray, speed:float, sr:int):
     # 将管道输出解码为 NumPy 数组
     processed_audio = np.frombuffer(out, np.int16)
 
-    return processed_audio
\ No newline at end of file
+    return processed_audio
diff --git a/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py b/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py
index 0891227..9786b7f 100644
--- a/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py
+++ b/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py
@@ -60,6 +60,13 @@ class TextPreprocessor:
         texts = self.pre_seg_text(text, lang, text_split_method)
         result = []
         print(i18n("############ 提取文本Bert特征 ############"))
+        # Dump the segmented texts for the subtitle writer in TTS.py, which parses
+        # this file back as a Python list literal. Create the directory first so
+        # the open() below cannot fail with FileNotFoundError on a fresh checkout.
+        import os
+        os.makedirs('./srt', exist_ok=True)
+        with open('./srt/tts-out.txt', 'w', encoding='utf-8') as f:
+            f.write(repr(texts))
         for text in tqdm(texts):
             phones, bert_features, norm_text = self.segment_and_extract_feature_for_text(text, lang)
             if phones is None:
diff --git a/api_v2.py b/api_v2.py
index aaa56e0..1d16dd2 100644
--- a/api_v2.py
+++ b/api_v2.py
@@ -114,6 +114,8 @@ import soundfile as sf
 from fastapi import FastAPI, Request, HTTPException, Response
 from fastapi.responses import StreamingResponse, JSONResponse
 from fastapi import FastAPI, UploadFile, File
+from fastapi.staticfiles import StaticFiles
 import uvicorn
+import os
 from io import BytesIO
 from tools.i18n.i18n import I18nAuto
@@ -143,6 +145,12 @@ tts_config = TTS_Config(config_path)
 tts_pipeline = TTS(tts_config)
 
 APP = FastAPI()
+# StaticFiles raises RuntimeError at start-up if its directory is missing,
+# so create both output directories before mounting them.
+for _static_dir in ("./srt", "./audio"):
+    os.makedirs(_static_dir, exist_ok=True)
+APP.mount("/srt", StaticFiles(directory="./srt"), name="srt")
+APP.mount("/audio", StaticFiles(directory="./audio"), name="audio")
 class TTS_Request(BaseModel):
     text: str = None
     text_lang: str = None
@@ -329,7 +337,61 @@ async def tts_handle(req:dict):
         return JSONResponse(status_code=400, content={"message": f"tts failed", "Exception": str(e)})
 
 
 
+async def tts_handle_srt(req:dict, request:Request):
+    """
+    Text to speech handler that also publishes an SRT subtitle file.
+
+    Runs the TTS pipeline once, saves the audio to ./audio/audio.wav and
+    returns the URLs of that file and of ./srt/tts-out.srt, which the
+    pipeline writes. Both file names are fixed, so concurrent requests
+    overwrite each other's output.
+
+    Args:
+        req (dict):
+            {
+                "text": "",                   # str.(required) text to be synthesized
+                "text_lang": "",              # str.(required) language of the text to be synthesized
+                "ref_audio_path": "",         # str.(required) reference audio path
+                "prompt_text": "",            # str.(optional) prompt text for the reference audio
+                "prompt_lang": "",            # str.(required) language of the prompt text for the reference audio
+                "top_k": 5,                   # int. top k sampling
+                "top_p": 1,                   # float. top p sampling
+                "temperature": 1,             # float. temperature for sampling
+                "text_split_method": "cut5",  # str. text split method, see text_segmentation_method.py for details.
+                "batch_size": 1,              # int. batch size for inference
+                "batch_threshold": 0.75,      # float. threshold for batch splitting.
+                "split_bucket": True,         # bool. whether to split the batch into multiple buckets.
+                "speed_factor":1.0,           # float. control the speed of the synthesized audio.
+                "fragment_interval":0.3,      # float. to control the interval of the audio fragment.
+                "seed": -1,                   # int. random seed for reproducibility.
+                "media_type": "wav",          # str. not read by this handler itself; the audio is always saved as WAV.
+                "streaming_mode": False,      # bool. not read by this handler itself.
+                "parallel_infer": True,       # bool.(optional) whether to use parallel inference.
+                "repetition_penalty": 1.35    # float.(optional) repetition penalty for T2S model.
+            }
+        request (Request): incoming request, used to build absolute URLs.
+    returns:
+        JSONResponse: {"code", "srt", "audio"} with the two file URLs, or
+            status 400 with the error message when synthesis fails.
+    """
+
+    check_res = check_params(req)
+    if check_res is not None:
+        return check_res
+
+    try:
+        tts_generator=tts_pipeline.run(req)
+        sr, audio_data = next(tts_generator)
+        # Save the audio where the /audio mount serves it, so that the
+        # "audio" URL returned below actually resolves.
+        sf.write("./audio/audio.wav", audio_data, sr)
+        # request.base_url keeps the real scheme and leaves out an implicit
+        # port; formatting request.url.port by hand yields ":None" there.
+        base_url = str(request.base_url)
+        return JSONResponse({"code":"200", "srt":f"{base_url}srt/tts-out.srt", "audio":f"{base_url}audio/audio.wav"})
+    except Exception as e:
+        return JSONResponse(status_code=400, content={"message": f"tts failed", "Exception": str(e)})
 
 
 
@@ -339,7 +401,62 @@ async def control(command: str = None):
         return JSONResponse(status_code=400, content={"message": "command is required"})
     handle_control(command)
 
+
+@APP.get("/srt")
+async def tts_get_endpoint_srt(request: Request,
+                        text: str = None,
+                        text_lang: str = None,
+                        ref_audio_path: str = None,
+                        prompt_lang: str = None,
+                        prompt_text: str = "",
+                        top_k:int = 5,
+                        top_p:float = 1,
+                        temperature:float = 1,
+                        text_split_method:str = "cut5",
+                        batch_size:int = 10,
+                        batch_threshold:float = 0.75,
+                        split_bucket:bool = True,
+                        speed_factor:float = 1.0,
+                        fragment_interval:float = 0.3,
+                        seed:int = -1,
+                        media_type:str = "wav",
+                        streaming_mode:bool = False,
+                        parallel_infer:bool = True,
+                        repetition_penalty:float = 1.35
+                        ):
+    """GET /srt: collect the query parameters into a request dict and delegate to tts_handle_srt()."""
+    req = {
+        "text": text,
+        # Pass a missing language through unchanged so the check_params() call in
+        # tts_handle_srt() sees it; .lower() on None would raise AttributeError first.
+        "text_lang": text_lang.lower() if text_lang else text_lang,
+        "ref_audio_path": ref_audio_path,
+        "prompt_text": prompt_text,
+        "prompt_lang": prompt_lang.lower() if prompt_lang else prompt_lang,
+        "top_k": top_k,
+        "top_p": top_p,
+        "temperature": temperature,
+        "text_split_method": text_split_method,
+        "batch_size":int(batch_size),
+        "batch_threshold":float(batch_threshold),
+        "speed_factor":float(speed_factor),
+        "split_bucket":split_bucket,
+        "fragment_interval":fragment_interval,
+        "seed":seed,
+        "media_type":media_type,
+        "streaming_mode":streaming_mode,
+        "parallel_infer":parallel_infer,
+        "repetition_penalty":float(repetition_penalty)
+    }
+    return await tts_handle_srt(req,request)
+
+
+@APP.post("/srt")
+async def tts_post_endpoint_srt(request: TTS_Request,req1: Request):
+    """POST /srt: `request` is the parsed JSON body, `req1` the raw HTTP request used for URL building."""
+    req = request.dict()
+    return await tts_handle_srt(req,req1)
 
 
 @APP.get("/tts")
 async def tts_get_endpoint(