From 251160362f2d385ef2a3a89f24362bae342dcc9b Mon Sep 17 00:00:00 2001 From: Jarod Mica Date: Mon, 23 Dec 2024 00:11:17 -0800 Subject: [PATCH] Add token streaming in batches to the TTS class. --- GPT_SoVITS/TTS_infer_pack/TTS.py | 419 ++++++++++++++++++++- GPT_SoVITS/TTS_infer_pack/zero_crossing.py | 203 ++++++++++ GPT_SoVITS/api_v2.py | 2 +- infer_script.py | 272 +++++++++++++ 4 files changed, 893 insertions(+), 3 deletions(-) create mode 100644 GPT_SoVITS/TTS_infer_pack/zero_crossing.py create mode 100644 infer_script.py diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py index 809601b3..42105064 100644 --- a/GPT_SoVITS/TTS_infer_pack/TTS.py +++ b/GPT_SoVITS/TTS_infer_pack/TTS.py @@ -13,6 +13,10 @@ import torch import torch.nn.functional as F import traceback import yaml +import queue +import sounddevice as sd +import soundfile as sf +import threading from huggingface_hub import snapshot_download, hf_hub_download from importlib.resources import files @@ -875,9 +879,64 @@ class TTS: t_34 += t4 - t3 refer_audio_spec:torch.Tensor = [item.to(dtype=self.precision, device=self.configs.device) for item in self.prompt_cache["refer_spec"]] - + + # Split the semantic tokens into chunks + num_chunks = 10 # Number of chunks to split into + chunked_pred_semantic_list = [] # This will store the chunks for each sample + + pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)] + + for semantic_tokens in pred_semantic_list: + total_length = semantic_tokens.shape[0] + chunk_size = total_length // num_chunks + chunks = [] + for i in range(num_chunks): + overlap = 0 + # samples_per_token = 1280 + # sample_rate = 32000 + # start_idx = i * chunk_size - overlap + start_idx = 0 + # so each subsequent sample is overlapping by 5120 samples + if start_idx < 0: + start_idx = 0 + if i == num_chunks - 1: + # Make sure to include the remainder in the last chunk + end_idx = total_length + else: + end_idx = (i + 1) * chunk_size + chunk = semantic_tokens[start_idx:end_idx] + chunks.append(chunk) + chunked_pred_semantic_list.append(chunks) + + # Process chunks through VITS + batch_audio_chunks = [] # List to hold audio chunks for each sample - batch_audio_fragment = [] + for i, (chunks, phones) in enumerate(zip(chunked_pred_semantic_list, batch_phones)): + phones = phones.unsqueeze(0).to(self.configs.device) + audio_chunks = [] + for chunk in chunks: + # Prepare the chunk for VITS + chunk = chunk.unsqueeze(0).unsqueeze(0).to(self.configs.device) + # Process the chunk through VITS + audio_fragment = self.vits_model.decode( + chunk, phones, refer_audio_spec, speed=speed_factor + ).detach()[0, 0, :] + audio_chunks.append(audio_fragment.cpu().numpy()) + batch_audio_chunks.append(audio_chunks) + + output_dir = 'output_chunks' + os.makedirs(output_dir, exist_ok=True) + + for sample_idx, audio_chunks in enumerate(batch_audio_chunks): + for chunk_idx, audio_chunk in enumerate(audio_chunks): + # Convert audio_chunk to float32 + audio_chunk = audio_chunk.astype(np.float32) + # Create a filename for each chunk + filename = f'sample_{sample_idx}_chunk_{chunk_idx}.wav' + output_path = os.path.join(output_dir, filename) + # Save the audio chunk + sf.write(output_path, audio_chunk, self.configs.sampling_rate) + print(f'Saved audio chunk: {output_path}') # ## vits并行推理 method 1 # pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)] @@ -965,6 +1024,362 @@ class TTS: raise e finally: self.empty_cache() + + @torch.no_grad() + def run_generator(self, 
inputs:dict):
+        """
+        Text to speech inference that streams audio fragments as the semantic tokens are generated.
+
+        Args:
+            inputs (dict):
+                {
+                    "text": "",                   # str.(required) text to be synthesized
+                    "text_lang": "",              # str.(required) language of the text to be synthesized
+                    "ref_audio_path": "",         # str.(required) reference audio path
+                    "aux_ref_audio_paths": [],    # list.(optional) auxiliary reference audio paths for multi-speaker tone fusion
+                    "prompt_text": "",            # str.(optional) prompt text for the reference audio
+                    "prompt_lang": "",            # str.(required) language of the prompt text for the reference audio
+                    "top_k": 5,                   # int. top k sampling
+                    "top_p": 1,                   # float. top p sampling
+                    "temperature": 1,             # float. temperature for sampling
+                    "text_split_method": "cut0",  # str. text split method, see text_segmentation_method.py for details.
+                    "batch_size": 1,              # int. batch size for inference
+                    "batch_threshold": 0.75,      # float. threshold for batch splitting.
+                    "split_bucket": True,         # bool. whether to split the batch into multiple buckets.
+                    "return_fragment": False,     # bool. step by step return the audio fragment.
+                    "speed_factor": 1.0,          # float. control the speed of the synthesized audio.
+                    "fragment_interval": 0.3,     # float. to control the interval of the audio fragment.
+                    "seed": -1,                   # int. random seed for reproducibility.
+                    "parallel_infer": True,       # bool. whether to use parallel inference.
+                    "repetition_penalty": 1.35    # float. repetition penalty for the T2S model.
+                }
+        yields:
+            Tuple[int, np.ndarray]: the sampling rate and an audio fragment.
+        """
+        ########## variables initialization ###########
+        self.stop_flag:bool = False
+        text:str = inputs.get("text", "")
+        text_lang:str = inputs.get("text_lang", "")
+        ref_audio_path:str = inputs.get("ref_audio_path", "")
+        aux_ref_audio_paths:list = inputs.get("aux_ref_audio_paths", [])
+        prompt_text:str = inputs.get("prompt_text", "")
+        prompt_lang:str = inputs.get("prompt_lang", "")
+        top_k:int = inputs.get("top_k", 5)
+        top_p:float = inputs.get("top_p", 1)
+        temperature:float = inputs.get("temperature", 1)
+        text_split_method:str = inputs.get("text_split_method", "cut0")
+        batch_size = inputs.get("batch_size", 1)
+        batch_threshold = inputs.get("batch_threshold", 0.75)
+        speed_factor = inputs.get("speed_factor", 1.0)
+        split_bucket = inputs.get("split_bucket", True)
+        return_fragment = inputs.get("return_fragment", False)
+        fragment_interval = inputs.get("fragment_interval", 0.3)
+        seed = inputs.get("seed", -1)
+        seed = -1 if seed in ["", None] else seed
+        actual_seed = set_seed(seed)
+        parallel_infer = inputs.get("parallel_infer", True)
+        repetition_penalty = inputs.get("repetition_penalty", 1.35)
+
+        if parallel_infer:
+            print(i18n("并行推理模式已开启"))
+            self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_batch_infer
+        else:
+            print(i18n("并行推理模式已关闭"))
+            self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_naive_batched
+
+        if return_fragment:
+            print(i18n("分段返回模式已开启"))
+            if split_bucket:
+                split_bucket = False
+                print(i18n("分段返回模式不支持分桶处理,已自动关闭分桶处理"))
+
+        if split_bucket and speed_factor==1.0:
+            print(i18n("分桶处理模式已开启"))
+        elif speed_factor!=1.0:
+            print(i18n("语速调节不支持分桶处理,已自动关闭分桶处理"))
+            split_bucket = False
+        else:
+            print(i18n("分桶处理模式已关闭"))
+
+        if fragment_interval<0.01:
+            fragment_interval = 0.01
+            print(i18n("分段间隔过小,已自动设置为0.01"))
+
+        no_prompt_text = False
+        if prompt_text in [None, ""]:
+            no_prompt_text = True
+
+        assert text_lang in self.configs.languages
+        if not no_prompt_text:
+            assert prompt_lang in self.configs.languages
+
+        if ref_audio_path in [None, ""] and \
+            ((self.prompt_cache["prompt_semantic"] is
None) or (self.prompt_cache["refer_spec"] in [None, []])): + raise ValueError("ref_audio_path cannot be empty, when the reference audio is not set using set_ref_audio()") + + ###### setting reference audio and prompt text preprocessing ######## + t0 = ttime() + if (ref_audio_path is not None) and (ref_audio_path != self.prompt_cache["ref_audio_path"]): + if not os.path.exists(ref_audio_path): + raise ValueError(f"{ref_audio_path} not exists") + self.set_ref_audio(ref_audio_path) + + aux_ref_audio_paths = aux_ref_audio_paths if aux_ref_audio_paths is not None else [] + paths = set(aux_ref_audio_paths)&set(self.prompt_cache["aux_ref_audio_paths"]) + if not (len(list(paths)) == len(aux_ref_audio_paths) == len(self.prompt_cache["aux_ref_audio_paths"])): + self.prompt_cache["aux_ref_audio_paths"] = aux_ref_audio_paths + self.prompt_cache["refer_spec"] = [self.prompt_cache["refer_spec"][0]] + for path in aux_ref_audio_paths: + if path in [None, ""]: + continue + if not os.path.exists(path): + print(i18n("音频文件不存在,跳过:{}").format(path)) + continue + self.prompt_cache["refer_spec"].append(self._get_ref_spec(path)) + + if not no_prompt_text: + prompt_text = prompt_text.strip("\n") + if (prompt_text[-1] not in splits): prompt_text += "。" if prompt_lang != "en" else "." + print(i18n("实际输入的参考文本:"), prompt_text) + if self.prompt_cache["prompt_text"] != prompt_text: + self.prompt_cache["prompt_text"] = prompt_text + self.prompt_cache["prompt_lang"] = prompt_lang + phones, bert_features, norm_text = \ + self.text_preprocessor.segment_and_extract_feature_for_text( + prompt_text, + prompt_lang, + self.configs.version) + self.prompt_cache["phones"] = phones + self.prompt_cache["bert_features"] = bert_features + self.prompt_cache["norm_text"] = norm_text + + + + + ###### text preprocessing ######## + t1 = ttime() + data:list = None + if not return_fragment: + data = self.text_preprocessor.preprocess(text, text_lang, text_split_method, self.configs.version) + if len(data) == 0: + yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate), + dtype=np.int16) + return + + batch_index_list:list = None + data, batch_index_list = self.to_batch(data, + prompt_data=self.prompt_cache if not no_prompt_text else None, + batch_size=batch_size, + threshold=batch_threshold, + split_bucket=split_bucket, + device=self.configs.device, + precision=self.precision + ) + else: + print(i18n("############ 切分文本 ############")) + texts = self.text_preprocessor.pre_seg_text(text, text_lang, text_split_method) + data = [] + for i in range(len(texts)): + if i%batch_size == 0: + data.append([]) + data[-1].append(texts[i]) + + def make_batch(batch_texts): + batch_data = [] + print(i18n("############ 提取文本Bert特征 ############")) + for text in tqdm(batch_texts): + phones, bert_features, norm_text = self.text_preprocessor.segment_and_extract_feature_for_text(text, text_lang, self.configs.version) + if phones is None: + continue + res={ + "phones": phones, + "bert_features": bert_features, + "norm_text": norm_text, + } + batch_data.append(res) + if len(batch_data) == 0: + return None + batch, _ = self.to_batch(batch_data, + prompt_data=self.prompt_cache if not no_prompt_text else None, + batch_size=batch_size, + threshold=batch_threshold, + split_bucket=False, + device=self.configs.device, + precision=self.precision + ) + return batch[0] + + t2 = ttime() + + try: + print("############ 推理 ############") + ###### inference ###### + t_34 = 0.0 + t_45 = 0.0 + audio = [] + for item in data: + t3 = ttime() + if return_fragment: + item = 
make_batch(item) + if item is None: + continue + + batch_phones:List[torch.LongTensor] = item["phones"] + # batch_phones:torch.LongTensor = item["phones"] + batch_phones_len:torch.LongTensor = item["phones_len"] + all_phoneme_ids:torch.LongTensor = item["all_phones"] + all_phoneme_lens:torch.LongTensor = item["all_phones_len"] + all_bert_features:torch.LongTensor = item["all_bert_features"] + norm_text:str = item["norm_text"] + max_len = item["max_len"] + + print(i18n("前端处理后的文本(每句):"), norm_text) + if no_prompt_text : + prompt = None + else: + prompt = self.prompt_cache["prompt_semantic"].expand(len(all_phoneme_ids), -1).to(self.configs.device) + + refer_audio_spec:torch.Tensor = [item.to(dtype=self.precision, device=self.configs.device) for item in self.prompt_cache["refer_spec"]] + + generated_tokens_list = [] + start = ttime() + + from GPT_SoVITS.TTS_infer_pack.zero_crossing import find_zero_zone, find_matching_index + zc_index1 = 0 + zc_index2 = 0 + crossing_direction = 0 + first_chunk = True + last_chunk = False + search_length = 32000*5 + num_zeroes = 5 + cumulation_amount=50 + + # Use infer_panel_generator to generate tokens in batches + for generated_tokens in self.t2s_model.model.infer_panel_generator( + all_phoneme_ids[0].unsqueeze(0), + all_phoneme_lens[0].unsqueeze(0), + prompt[0].unsqueeze(0) if prompt is not None else None, + all_bert_features[0].unsqueeze(0), + cumulation_amount=cumulation_amount, + top_k=top_k, + top_p=top_p, + temperature=temperature, + early_stop_num=self.configs.hz * self.configs.max_sec, + max_len=max_len, + repetition_penalty=repetition_penalty, + ): + # Append the generated tokens + generated_tokens_list.append(generated_tokens) + total_tokens = sum([tokens.size(1) for tokens in generated_tokens_list]) + + tokens_to_process = torch.cat(generated_tokens_list, dim=1)[:, :total_tokens] # uses full context for decoding + + # Check if tokens_to_process contains the EOS token (1024) + contains_eos = (tokens_to_process == 1024).any() + + if contains_eos: + # Replace all instances of the EOS token (1024) with 0 + tokens_to_process = tokens_to_process.masked_fill(tokens_to_process == 1024, 0) + print("Replaced EOS token (1024) with 0 in tokens_to_process") + last_chunk = True + first_chunk = False + + # Prepare input for VITS model + _pred_semantic = tokens_to_process.unsqueeze(0) + phones = batch_phones[0].unsqueeze(0).to(self.configs.device) + + # Generate audio for the tokens + audio_output = self.vits_model.decode( + _pred_semantic, phones, refer_audio_spec, speed=speed_factor + ).detach()[0, 0, :] + + audio_output = audio_output[:].cpu().numpy() + # Convert audio_fragment to float32 and normalize + audio_output = audio_output.astype(np.float32) + max_val = np.abs(audio_output).max() + if max_val > 1.0: + audio_output /= max_val + + start_index = len(audio_output) - search_length + if start_index < 0: + search_length = len(audio_output) + print(f"search_length is too HIGH! 
Auto adjusted to {search_length} frames as the chunks are only {len(audio_output)} frames large")
+                        start_index = 0
+                    center_index = zc_index2  # Start from the previous zero-crossing index and search outwards
+                    max_offset = int(search_length // 2)  # branches out in both directions
+
+                    if center_index < 0:
+                        raise ValueError("Something is wrong here: center_index is less than 0")
+                    elif center_index >= len(audio_output):
+                        raise ValueError("Something is wrong here: center_index is beyond the end of audio_output")
+
+                    if first_chunk:
+
+                        zc_index1, crossing_direction = find_zero_zone(
+                            chunk=audio_output,
+                            start_index=start_index,
+                            search_length=search_length,
+                            num_zeroes=num_zeroes
+                        )
+                        audio_fragment = audio_output[:zc_index1]
+                        yield self.configs.sampling_rate, audio_fragment
+                        first_chunk = False
+                        zc_index2 = zc_index1
+                    elif last_chunk:
+                        zc_index1 = find_matching_index(
+                            chunk=audio_output,
+                            center_index=center_index,
+                            max_offset=max_offset,
+                            crossing_direction=crossing_direction
+                        )
+                        audio_fragment = audio_output[zc_index1:]
+                        yield self.configs.sampling_rate, audio_fragment
+
+                    else:
+                        zc_index1 = find_matching_index(
+                            chunk=audio_output,
+                            center_index=center_index,
+                            max_offset=max_offset,
+                            crossing_direction=crossing_direction
+                        )
+
+                        zc_index2, crossing_direction = find_zero_zone(
+                            chunk=audio_output,
+                            start_index=start_index,
+                            search_length=search_length,
+                            num_zeroes=num_zeroes
+                        )
+                        audio_fragment = audio_output[zc_index1:zc_index2]
+                        yield self.configs.sampling_rate, audio_fragment
+
+                end = ttime()
+                print(f"Time to speech: {end-start}")
+
+        except Exception as e:
+            traceback.print_exc()
+            # Must yield an empty audio clip, otherwise GPU memory will not be released.
+            yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate),
+                                                       dtype=np.int16)
+            # Reset the models, otherwise GPU memory will not be fully released.
+            del self.t2s_model
+            del self.vits_model
+            self.t2s_model = None
+            self.vits_model = None
+            self.init_t2s_weights(self.configs.t2s_weights_path)
+            self.init_vits_weights(self.configs.vits_weights_path)
+            raise e
+        finally:
+            self.empty_cache()
+
+    def empty_cache(self):
+        try:
+            gc.collect()  # Trigger garbage collection to keep memory from growing indefinitely.
+            if "cuda" in str(self.configs.device):
+                torch.cuda.empty_cache()
+            elif str(self.configs.device) == "mps":
+                torch.mps.empty_cache()
+        except:
+            pass
 
     def empty_cache(self):
         try:
diff --git a/GPT_SoVITS/TTS_infer_pack/zero_crossing.py b/GPT_SoVITS/TTS_infer_pack/zero_crossing.py
new file mode 100644
index 00000000..542a6de9
--- /dev/null
+++ b/GPT_SoVITS/TTS_infer_pack/zero_crossing.py
@@ -0,0 +1,203 @@
+import numpy as np
+import wave
+import struct
+
+def read_wav_file(filename):
+    """
+    Reads a WAV file and returns the sample rate and data as a numpy array.
+    """
+    with wave.open(filename, 'rb') as wf:
+        sample_rate = wf.getframerate()
+        n_frames = wf.getnframes()
+        sample_width = wf.getsampwidth()
+        n_channels = wf.getnchannels()
+
+        audio_data = wf.readframes(n_frames)
+        # Determine the format string for struct unpacking
+        fmt = "<" + {1:'b', 2:'h', 4:'i'}[sample_width] * n_frames * n_channels
+        audio_samples = struct.unpack(fmt, audio_data)
+        audio_array = np.array(audio_samples, dtype=int)
+
+        # If stereo, reshape the array
+        if n_channels > 1:
+            audio_array = audio_array.reshape(-1, n_channels)
+    return sample_rate, audio_array, sample_width, n_channels
+
+def write_wav_file(filename, sample_rate, data, sample_width, n_channels):
+    """
+    Writes numpy array data to a WAV file.
+    """
+    with wave.open(filename, 'wb') as wf:
+        wf.setnchannels(n_channels)
+        wf.setsampwidth(sample_width)
+        wf.setframerate(sample_rate)
+        # Flatten the array if it's multi-dimensional
+        if data.ndim > 1:
+            data = data.flatten()
+        # Pack the data into bytes
+        fmt = "<" + {1:'b', 2:'h', 4:'i'}[sample_width] * len(data)
+        byte_data = struct.pack(fmt, *data)
+        wf.writeframes(byte_data)
+
+def find_zero_zone(chunk, start_index, search_length, num_zeroes=11):
+    zone = chunk[start_index:start_index + search_length]
+    print(f"Zero-crossing search zone: Start={start_index}, Length={len(zone)}")
+
+    zero_threshold = 1.0e-4
+    # Check for num_zeroes consecutive near-zero samples, scanning from the end of the zone backwards
+    for idx in range(len(zone), -1 + num_zeroes, -1):
+        index_to_start = idx-num_zeroes
+        abs_zone = np.abs(zone[index_to_start:idx])
+        if np.all(abs_zone < zero_threshold):
+            index_midpoint = index_to_start + int(num_zeroes // 2)
+            return (start_index + index_midpoint), None
+
+    print("Falling back to zero crossing due to no zero zone found. You may hear more prominent pops and clicks in the audio. Try increasing search length or cumulative tokens.")
+    return find_zero_crossing(chunk, start_index, search_length)
+
+def find_zero_crossing(chunk, start_index, search_length):
+    # If the model is falling back on this function, it may be a sign that the search length is too low
+
+    zone = chunk[start_index:start_index + search_length]
+    sign_changes = np.where(np.diff(np.sign(zone)) != 0)[0]
+
+    if len(sign_changes) == 0:
+        raise ValueError("No zero-crossings found in this zone. This should not be happening, debugging time.")
+    else:
+        zc_index = start_index + sign_changes[0] + 1
+        print(f"Zero-crossing found at index {zc_index}")
+        # Determine the crossing direction in chunk1
+        prev_value = chunk[zc_index - 1]
+        curr_value = chunk[zc_index]
+        crossing_direction = np.sign(curr_value) - np.sign(prev_value)
+        print(f"Crossing direction in chunk1: {np.sign(prev_value)} to {np.sign(curr_value)}")
+        return zc_index, crossing_direction
+
+def find_matching_index(chunk, center_index, max_offset, crossing_direction):
+    """
+    Finds a zero-crossing in data that matches the specified crossing direction,
+    starting from center_index and searching outward.
+    """
+    if crossing_direction is None:
+        return center_index  # a zero zone was used, so there is no direction to match
+
+    # fall back for zero_crossing
+    data_length = len(chunk)
+    print(f"Center index in chunk2: {center_index}")
+    for offset in range(max_offset + 1):
+        # Check index bounds
+        idx_forward = center_index + offset
+        idx_backward = center_index - offset
+        found = False
+
+        # Check forward direction
+        if idx_forward < data_length - 1:
+            prev_sign = np.sign(chunk[idx_forward])
+            curr_sign = np.sign(chunk[idx_forward + 1])
+            direction = curr_sign - prev_sign
+            if direction == crossing_direction:
+                print(f"Matching zero-crossing found at index {idx_forward + 1} (forward)")
+                return idx_forward + 1
+
+        # Check backward direction
+        if idx_backward > 0:
+            prev_sign = np.sign(chunk[idx_backward - 1])
+            curr_sign = np.sign(chunk[idx_backward])
+            direction = curr_sign - prev_sign
+            if direction == crossing_direction:
+                print(f"Matching zero-crossing found at index {idx_backward} (backward)")
+                return idx_backward
+
+    print("No matching zero-crossings found in this zone.")
+    return None
+
+# legacy, just for history. delete me sometime
+def splice_chunks(chunk1, chunk2, search_length, y):
+    """
+    Splices two audio chunks at zero-crossing points.
+    """
+    # Define the zone to search in chunk1
+    start_index1 = len(chunk1) - search_length
+    if start_index1 < 0:
+        start_index1 = 0
+        search_length = len(chunk1)
+    print(f"Searching for zero-crossing in chunk1 from index {start_index1} to {len(chunk1)}")
+    # Find zero-crossing in chunk1 (y is unused here; find_zero_crossing only needs the chunk, start index and search length)
+    zc_index1, crossing_direction = find_zero_crossing(chunk1, start_index1, search_length)
+    if zc_index1 is None:
+        print("No zero-crossing found in chunk1 within the specified zone.")
+        return None
+
+    # Define the zone to search in chunk2 near the same index
+    # Since chunk2 overlaps with chunk1, we can assume that index positions correspond
+    # Adjusted search in chunk2
+    # You can adjust this value if needed
+    center_index = zc_index1  # Assuming alignment between chunk1 and chunk2
+    max_offset = search_length
+
+    # Ensure center_index is within bounds
+    if center_index < 0:
+        center_index = 0
+    elif center_index >= len(chunk2):
+        center_index = len(chunk2) - 1
+
+    print(f"Searching for matching zero-crossing in chunk2 around index {center_index} with max offset {max_offset}")
+
+    zc_index2 = find_matching_index(chunk2, center_index, max_offset, crossing_direction)
+
+    if zc_index2 is None:
+        print("No matching zero-crossing found in chunk2.")
+        return None
+
+    print(f"Zero-crossing in chunk1 at index {zc_index1}, chunk2 at index {zc_index2}")
+    # Splice the chunks
+    new_chunk = np.concatenate((chunk1[:zc_index1], chunk2[zc_index2:]))
+    print(f"Spliced chunk length: {len(new_chunk)}")
+    return new_chunk
+
+# legacy, just for history. delete me sometime
+def process_audio_chunks(filenames, sample_rate, x, y, output_filename):
+    """
+    Processes and splices a list of audio chunks.
+    """
+    # Read the first chunk
+    sr, chunk_data, sample_width, n_channels = read_wav_file(filenames[0])
+    if sr != sample_rate:
+        print(f"Sample rate mismatch in {filenames[0]}")
+        return
+    print(f"Processing {filenames[0]}")
+    # Initialize the combined audio with the first chunk
+    combined_audio = chunk_data
+    # Process remaining chunks
+    for filename in filenames[1:]:
+        sr, next_chunk_data, _, _ = read_wav_file(filename)
+        if sr != sample_rate:
+            print(f"Sample rate mismatch in {filename}")
+            return
+        print(f"Processing {filename}")
+        # Splice the current combined audio with the next chunk
+        new_combined = splice_chunks(combined_audio, next_chunk_data, x, y)
+        if new_combined is None:
+            print(f"Failed to splice chunks between {filename} and previous chunk.")
+            return
+        combined_audio = new_combined
+    # Write the final combined audio to output file
+    write_wav_file(output_filename, sample_rate, combined_audio, sample_width, n_channels)
+    print(f"Final audio saved to {output_filename}")
+
+# Main execution
+if __name__ == "__main__":
+    # User-specified parameters
+    sample_rate = 32000  # Sample rate in Hz
+    x = 500  # Number of frames to search from the end of the chunk
+    y = 10  # Number of consecutive zeros to look for
+    output_filename = "combined_output.wav"
+    folder_with_chunks = "output_chunks"
+    import os
+    def absolute_file_paths(directory):
+        path = os.path.abspath(directory)
+        return [entry.path for entry in os.scandir(path) if entry.is_file()]
+    # List of input audio chunk filenames in sequential order
+    filenames = absolute_file_paths(folder_with_chunks)
+    # Process and splice the audio chunks
+    process_audio_chunks(filenames, sample_rate, x, y, output_filename)
diff --git a/GPT_SoVITS/api_v2.py b/GPT_SoVITS/api_v2.py
index 5dfbebec..84e6ffd7 100644
--- a/GPT_SoVITS/api_v2.py
+++ b/GPT_SoVITS/api_v2.py
@@
-358,7 +358,7 @@ async def tts_get_endpoint( top_p:float = 1, temperature:float = 1, text_split_method:str = "cut0", - batch_size:int = 1, + batch_size:int = 4, batch_threshold:float = 0.75, split_bucket:bool = True, speed_factor:float = 1.0, diff --git a/infer_script.py b/infer_script.py new file mode 100644 index 00000000..4bb38f8c --- /dev/null +++ b/infer_script.py @@ -0,0 +1,272 @@ +''' +This is just an example inference script to test batching with llama, mainly for my reference in the future. +''' + +import os +import sys +import numpy as np +import soundfile as sf +import threading +import queue +import sounddevice as sd +import time +import speech_recognition as sr + +# Ensure that GPT_SoVITS is in the Python path +now_dir = os.getcwd() +sys.path.append(now_dir) +sys.path.append(os.path.join(now_dir, 'GPT_SoVITS')) +os.environ['CUDA_LAUNCH_BLOCKING'] = '1' + + +from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config + +from llama_cpp import Llama +import sys + +# Initialize the Llama model +llm = Llama( + model_path="ggml-model-q8_0.gguf", + n_gpu_layers=-1, # Uncomment to use GPU acceleration + seed=1337, # Uncomment to set a specific seed + n_ctx=2048, # Uncomment to increase the context window + chat_format="llama-3", + verbose=False +) + +from time import time + +def generate_chat_completion_openai_v1_stream(messages): + start = time() + stream = llm.create_chat_completion_openai_v1( + messages=messages, + temperature=0.8, # Adjust temperature as needed + top_p=0.95, # Adjust top_p as needed + top_k=40, # Adjust top_k as needed + max_tokens=50, # Adjust the maximum number of tokens as needed + # stop=["\n"], # Adjust the stop sequence as needed + stream=True # Enable streaming + ) + end = time() + total = end - start + print(total) + for chunk in stream: + if chunk.choices[0].delta.content is not None: + yield chunk.choices[0].delta.content + +def audio_playback_thread(audio_queue, sample_rate): + """ + Audio playback thread that plays audio fragments from the queue. + """ + sd.default.samplerate = sample_rate + sd.default.channels = 1 + stream = sd.OutputStream(dtype='float32') + stream.start() + + try: + while True: + # Get the next audio fragment + audio_fragment = audio_queue.get() + try: + if audio_fragment is None: + # Sentinel value received, exit the loop + break + # Write the audio fragment to the stream + stream.write(audio_fragment) + finally: + # Mark the item as processed + audio_queue.task_done() + finally: + stream.stop() + stream.close() + +def main(): + + config_path = 'configs/tts_infer.yaml' + # GPT_model_path = 'pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt' + GPT_model_path = 'custom_trained.ckpt' + # SoVITS_model_path = 'pretrained_models/gsv-v2final-pretrained/s2G2333k.pth' + SoVITS_model_path = 'custom_trained.pth' + ref_audio_path = 'ref_audio.wav' + ref_text = 'でもなんか対処法ではないよなこれ対処法ではないけどそもそもの話みたいなことを言ってんのか' + target_text = """hahahaha, well well, let me tell you about that! it was perhaps the most exquisite day of my life! Phew, I've never had one better! 
""" + output_path = 'output' + ref_language = 'ja' + target_language = 'ja' + + + # Ensure output directory exists + os.makedirs(output_path, exist_ok=True) + + # Initialize TTS configuration and pipeline + tts_config = TTS_Config(config_path) + tts_pipeline = TTS(tts_config) + + # Load model weights + tts_pipeline.init_t2s_weights(GPT_model_path) + tts_pipeline.init_vits_weights(SoVITS_model_path) + + # Prepare inputs for TTS + inputs = { + "text": target_text, + "text_lang": target_language.lower(), + "ref_audio_path": ref_audio_path, + "prompt_text": ref_text, + "prompt_lang": ref_language.lower(), + "top_k": 5, + "top_p": 1.0, + "temperature": 1.0, + "text_split_method": "cut0", + "batch_size": 1, + "batch_threshold": 0.75, + "split_bucket": True, + "speed_factor": 1.0, + "fragment_interval": 0.3, + "seed": 2855904637, + "return_fragment": True, + "parallel_infer": False, + "repetition_penalty": 1.35, + } + + # Run TTS inference + + system_message = '''You are a friendly AI named Vivy. + + HOW YOU SHOULD RESPOND: + - The responses should include only verbal responses, for example *laughs* should be replaced with haha + ''' + + # Initialize conversation history with system message + conversation_history = [ + {"role": "system", "content": f"{system_message}"} + ] + + # Create a queue for audio fragments + audio_queue = queue.Queue(maxsize=100) # Adjust maxsize based on your needs + + # Start the audio playback thread + playback_thread = threading.Thread( + target=audio_playback_thread, + args=(audio_queue, tts_pipeline.configs.sampling_rate) + ) + playback_thread.start() + + # Setup speech recognition + r = sr.Recognizer() + mic = sr.Microphone() + + try: + while True: + # Prompt for speech input instead of text input + while True: + print("\nPlease speak your message (say 'quit' to exit):") + with mic as source: + # Adjust for ambient noise to improve recognition accuracy + r.adjust_for_ambient_noise(source, duration=1.0) + print("Listening...") + audio_data = r.listen(source, timeout=None, phrase_time_limit=60) + try: + # Replace 'recognize_whisper' with your actual recognition method + # Ensure that the method is correctly implemented or available + user_input = r.recognize_whisper(audio_data=audio_data, model="base") + print("You said: " + user_input) + + # Check if the input is not empty or just whitespace + if user_input.strip() == "": + print("No speech detected. Please try again.") + continue # Continue listening + break # Valid input received, exit inner loop + except sr.UnknownValueError: + print("Sorry, I could not understand the audio. Please try again.") + continue # Continue listening + except sr.RequestError as e: + print(f"Could not request results from speech recognition service; {e}") + continue # Continue listening + + # Check if the user wants to quit + if user_input.lower() == "quit": + print("Exiting the application. 
Goodbye!")
+                sys.exit()
+
+            # Append user message to conversation history
+            conversation_history.append({"role": "user", "content": user_input})
+
+            # Initialize variables to track character count and buffering
+            buffer = ""
+            char_count = 0
+            waiting_for_punctuation = False
+            assistant_buffer = ""
+
+            # Generate and print the chat completion with streaming
+            for token in generate_chat_completion_openai_v1_stream(conversation_history):
+                print(token, end="", flush=True)  # Print each token as it's generated
+                buffer += token
+                assistant_buffer += token
+                char_count += len(token)
+
+                if not waiting_for_punctuation:
+                    if char_count >= 100:
+                        waiting_for_punctuation = True  # Start looking for punctuation
+                else:
+                    if any(punct in token for punct in ['.', '!', '?']):
+                        # Send the buffer to TTS
+                        inputs["text"] = buffer
+                        synthesis_result = tts_pipeline.run_generator(inputs)
+                        # Consume the generator and put audio fragments into the queue
+                        for sampling_rate, audio_fragment in synthesis_result:
+                            audio_queue.put(audio_fragment)
+                        # Put silence into the audio queue after the TTS synthesis generator has finished
+                        silence_duration = 0.5  # in seconds
+                        num_samples = int(sampling_rate * silence_duration)
+                        silence = np.zeros(num_samples, dtype='float32')
+                        audio_queue.put(silence)
+
+                        # Reset counters and buffer
+                        char_count = 0
+                        buffer = ""
+                        waiting_for_punctuation = False
+
+            # Append assistant message to conversation history
+            conversation_history.append({"role": "assistant", "content": assistant_buffer})
+
+            # Handle any remaining text after the generator is done
+            if buffer.strip():
+                inputs["text"] = buffer
+                synthesis_result = tts_pipeline.run_generator(inputs)
+
+                # Consume the generator and put audio fragments into the queue
+                for sampling_rate, audio_fragment in synthesis_result:
+                    audio_queue.put(audio_fragment)
+                # Put silence into the audio queue after the TTS synthesis generator has finished
+                silence_duration = 0.5  # in seconds
+                num_samples = int(sampling_rate * silence_duration)
+                silence = np.zeros(num_samples, dtype='float32')
+                audio_queue.put(silence)
+
+                conversation_history.append({"role": "assistant", "content": buffer})
+                buffer = ""
+                char_count = 0
+                waiting_for_punctuation = False
+    finally:
+        # After all processing is done, send a sentinel to the audio queue and wait for threads to finish
+        audio_queue.put(None)
+        audio_queue.join()
+        playback_thread.join()
+
+
+    # text = input("GO:")
+    # inputs["text"] = text
+    # synthesis_result = tts_pipeline.run_generator(inputs)
+    # audio_data_list = list(synthesis_result)
+    # if audio_data_list:
+    #     # Since return_fragment is False, we expect only one tuple in audio_data_list
+    #     sampling_rate, audio_data = audio_data_list[0]
+    #     output_wav_path = os.path.join(output_path, "output.wav")
+    #     # Save the audio data to a WAV file
+    #     sf.write(output_wav_path, audio_data, sampling_rate)
+    #     print(f"Audio saved to {output_wav_path}")
+    # else:
+    #     print("No audio data generated.")
+
+if __name__ == '__main__':
+    main()
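
Note (not part of the patch above): a minimal sketch of how the new streaming entry point is meant to be consumed, based on the commented-out example at the end of infer_script.py. It assumes the same config and weight paths used in that script; the input text and output filename are illustrative only.

    import numpy as np
    import soundfile as sf
    from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config

    tts_pipeline = TTS(TTS_Config('configs/tts_infer.yaml'))
    tts_pipeline.init_t2s_weights('custom_trained.ckpt')   # paths taken from infer_script.py
    tts_pipeline.init_vits_weights('custom_trained.pth')

    inputs = {
        "text": "Hello there, this is a short streaming test.",  # illustrative text
        "text_lang": "en",                                        # must be enabled in the loaded config
        "ref_audio_path": "ref_audio.wav",
        "prompt_text": "でもなんか対処法ではないよなこれ対処法ではないけどそもそもの話みたいなことを言ってんのか",
        "prompt_lang": "ja",
        "return_fragment": True,      # run_generator yields (sampling_rate, fragment) tuples
        "parallel_infer": False,
    }

    # Each yielded fragment is cut at a zero-crossing zone, so the pieces can be
    # concatenated (or fed to a playback queue) without audible clicks.
    fragments = []
    for sampling_rate, fragment in tts_pipeline.run_generator(inputs):
        fragments.append(fragment)

    if fragments:
        sf.write('streamed_output.wav', np.concatenate(fragments), sampling_rate)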