Add token streaming in batches to the TTS class.

Jarod Mica 2024-12-23 00:11:17 -08:00
parent 38218e794d
commit 251160362f
4 changed files with 893 additions and 3 deletions
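
As a quick orientation, here is a minimal sketch of how the new streaming path is meant to be consumed; the config and reference-audio paths are placeholders, and the inputs dict follows the run_generator docstring in the diff below:

import soundfile as sf
from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config

tts_pipeline = TTS(TTS_Config("configs/tts_infer.yaml"))

inputs = {
    "text": "Hello there, this is a short streaming test.",
    "text_lang": "en",
    "ref_audio_path": "ref_audio.wav",       # placeholder reference clip
    "prompt_text": "Transcript of the reference clip.",
    "prompt_lang": "en",
    "return_fragment": True,                 # stream audio piece by piece
}

# run_generator yields (sampling_rate, np.ndarray) fragments as soon as enough
# semantic tokens have been decoded and spliced at a zero crossing.
for i, (sr, fragment) in enumerate(tts_pipeline.run_generator(inputs)):
    sf.write(f"fragment_{i}.wav", fragment, sr)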


@@ -13,6 +13,10 @@ import torch
import torch.nn.functional as F
import traceback
import yaml
import queue
import sounddevice as sd
import soundfile as sf
import threading
from huggingface_hub import snapshot_download, hf_hub_download
from importlib.resources import files
@@ -876,8 +880,63 @@ class TTS:
refer_audio_spec:torch.Tensor = [item.to(dtype=self.precision, device=self.configs.device) for item in self.prompt_cache["refer_spec"]]
# Split the semantic tokens into chunks
num_chunks = 10 # Number of chunks to split into
chunked_pred_semantic_list = [] # This will store the chunks for each sample
batch_audio_fragment = []
pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)]
for semantic_tokens in pred_semantic_list:
total_length = semantic_tokens.shape[0]
chunk_size = total_length // num_chunks
chunks = []
for i in range(num_chunks):
overlap = 0
# samples_per_token = 1280
# sample_rate = 32000
# start_idx = i * chunk_size - overlap
# (in that earlier overlap experiment, each chunk overlapped the previous one by 5120 samples)
start_idx = 0 # keep start_idx at 0 so every chunk carries the full preceding context
if start_idx < 0:
start_idx = 0
if i == num_chunks - 1:
# Make sure to include the remainder in the last chunk
end_idx = total_length
else:
end_idx = (i + 1) * chunk_size
chunk = semantic_tokens[start_idx:end_idx]
chunks.append(chunk)
chunked_pred_semantic_list.append(chunks)
# Process chunks through VITS
batch_audio_chunks = [] # List to hold audio chunks for each sample
for i, (chunks, phones) in enumerate(zip(chunked_pred_semantic_list, batch_phones)):
phones = phones.unsqueeze(0).to(self.configs.device)
audio_chunks = []
for chunk in chunks:
# Prepare the chunk for VITS
chunk = chunk.unsqueeze(0).unsqueeze(0).to(self.configs.device)
# Process the chunk through VITS
audio_fragment = self.vits_model.decode(
chunk, phones, refer_audio_spec, speed=speed_factor
).detach()[0, 0, :]
audio_chunks.append(audio_fragment.cpu().numpy())
batch_audio_chunks.append(audio_chunks)
output_dir = 'output_chunks'
os.makedirs(output_dir, exist_ok=True)
for sample_idx, audio_chunks in enumerate(batch_audio_chunks):
for chunk_idx, audio_chunk in enumerate(audio_chunks):
# Convert audio_chunk to float32
audio_chunk = audio_chunk.astype(np.float32)
# Create a filename for each chunk
filename = f'sample_{sample_idx}_chunk_{chunk_idx}.wav'
output_path = os.path.join(output_dir, filename)
# Save the audio chunk
sf.write(output_path, audio_chunk, self.configs.sampling_rate)
print(f'Saved audio chunk: {output_path}')
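# Note: the per-chunk wav files written above can be re-joined offline with
# GPT_SoVITS/TTS_infer_pack/zero_crossing.py (added below), which reads
# 'output_chunks' and splices the chunks at zero crossings into combined_output.wav.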
# ## VITS parallel inference, method 1
# pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)]
@@ -966,6 +1025,362 @@ class TTS:
finally:
self.empty_cache()
@torch.no_grad()
def run_generator(self, inputs:dict):
"""
Text to speech inference.
Args:
inputs (dict):
{
"text": "", # str.(required) text to be synthesized
"text_lang": "", # str.(required) language of the text to be synthesized
"ref_audio_path": "", # str.(required) reference audio path
"aux_ref_audio_paths": [], # list.(optional) auxiliary reference audio paths for multi-speaker tone fusion
"prompt_text": "", # str.(optional) prompt text for the reference audio
"prompt_lang": "", # str.(required) language of the prompt text for the reference audio
"top_k": 5, # int. top k sampling
"top_p": 1, # float. top p sampling
"temperature": 1, # float. temperature for sampling
"text_split_method": "cut0", # str. text split method, see text_segmentation_method.py for details.
"batch_size": 1, # int. batch size for inference
"batch_threshold": 0.75, # float. threshold for batch splitting.
"split_bucket": True, # bool. whether to split the batch into multiple buckets.
"return_fragment": False, # bool. step by step return the audio fragment.
"speed_factor":1.0, # float. control the speed of the synthesized audio.
"fragment_interval":0.3, # float. to control the interval of the audio fragment.
"seed": -1, # int. random seed for reproducibility.
"parallel_infer": True, # bool. whether to use parallel inference.
"repetition_penalty": 1.35 # float. repetition penalty for T2S model.
}
Yields:
Tuple[int, np.ndarray]: the sampling rate and an audio fragment.
"""
########## variables initialization ###########
self.stop_flag:bool = False
text:str = inputs.get("text", "")
text_lang:str = inputs.get("text_lang", "")
ref_audio_path:str = inputs.get("ref_audio_path", "")
aux_ref_audio_paths:list = inputs.get("aux_ref_audio_paths", [])
prompt_text:str = inputs.get("prompt_text", "")
prompt_lang:str = inputs.get("prompt_lang", "")
top_k:int = inputs.get("top_k", 5)
top_p:float = inputs.get("top_p", 1)
temperature:float = inputs.get("temperature", 1)
text_split_method:str = inputs.get("text_split_method", "cut0")
batch_size = inputs.get("batch_size", 1)
batch_threshold = inputs.get("batch_threshold", 0.75)
speed_factor = inputs.get("speed_factor", 1.0)
split_bucket = inputs.get("split_bucket", True)
return_fragment = inputs.get("return_fragment", False)
fragment_interval = inputs.get("fragment_interval", 0.3)
seed = inputs.get("seed", -1)
seed = -1 if seed in ["", None] else seed
actual_seed = set_seed(seed)
parallel_infer = inputs.get("parallel_infer", True)
repetition_penalty = inputs.get("repetition_penalty", 1.35)
if parallel_infer:
print(i18n("并行推理模式已开启"))
self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_batch_infer
else:
print(i18n("并行推理模式已关闭"))
self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_naive_batched
if return_fragment:
print(i18n("分段返回模式已开启"))
if split_bucket:
split_bucket = False
print(i18n("分段返回模式不支持分桶处理,已自动关闭分桶处理"))
if split_bucket and speed_factor==1.0:
print(i18n("分桶处理模式已开启"))
elif speed_factor!=1.0:
print(i18n("语速调节不支持分桶处理,已自动关闭分桶处理"))
split_bucket = False
else:
print(i18n("分桶处理模式已关闭"))
if fragment_interval<0.01:
fragment_interval = 0.01
print(i18n("分段间隔过小已自动设置为0.01"))
no_prompt_text = False
if prompt_text in [None, ""]:
no_prompt_text = True
assert text_lang in self.configs.languages
if not no_prompt_text:
assert prompt_lang in self.configs.languages
if ref_audio_path in [None, ""] and \
((self.prompt_cache["prompt_semantic"] is None) or (self.prompt_cache["refer_spec"] in [None, []])):
raise ValueError("ref_audio_path cannot be empty when the reference audio has not been set via set_ref_audio()")
###### setting reference audio and prompt text preprocessing ########
t0 = ttime()
if (ref_audio_path is not None) and (ref_audio_path != self.prompt_cache["ref_audio_path"]):
if not os.path.exists(ref_audio_path):
raise ValueError(f"{ref_audio_path} does not exist")
self.set_ref_audio(ref_audio_path)
aux_ref_audio_paths = aux_ref_audio_paths if aux_ref_audio_paths is not None else []
paths = set(aux_ref_audio_paths)&set(self.prompt_cache["aux_ref_audio_paths"])
if not (len(list(paths)) == len(aux_ref_audio_paths) == len(self.prompt_cache["aux_ref_audio_paths"])):
self.prompt_cache["aux_ref_audio_paths"] = aux_ref_audio_paths
self.prompt_cache["refer_spec"] = [self.prompt_cache["refer_spec"][0]]
for path in aux_ref_audio_paths:
if path in [None, ""]:
continue
if not os.path.exists(path):
print(i18n("音频文件不存在,跳过:{}").format(path))
continue
self.prompt_cache["refer_spec"].append(self._get_ref_spec(path))
if not no_prompt_text:
prompt_text = prompt_text.strip("\n")
if (prompt_text[-1] not in splits): prompt_text += "。" if prompt_lang != "en" else "."
print(i18n("实际输入的参考文本:"), prompt_text)
if self.prompt_cache["prompt_text"] != prompt_text:
self.prompt_cache["prompt_text"] = prompt_text
self.prompt_cache["prompt_lang"] = prompt_lang
phones, bert_features, norm_text = \
self.text_preprocessor.segment_and_extract_feature_for_text(
prompt_text,
prompt_lang,
self.configs.version)
self.prompt_cache["phones"] = phones
self.prompt_cache["bert_features"] = bert_features
self.prompt_cache["norm_text"] = norm_text
###### text preprocessing ########
t1 = ttime()
data:list = None
if not return_fragment:
data = self.text_preprocessor.preprocess(text, text_lang, text_split_method, self.configs.version)
if len(data) == 0:
yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate),
dtype=np.int16)
return
batch_index_list:list = None
data, batch_index_list = self.to_batch(data,
prompt_data=self.prompt_cache if not no_prompt_text else None,
batch_size=batch_size,
threshold=batch_threshold,
split_bucket=split_bucket,
device=self.configs.device,
precision=self.precision
)
else:
print(i18n("############ 切分文本 ############"))
texts = self.text_preprocessor.pre_seg_text(text, text_lang, text_split_method)
data = []
for i in range(len(texts)):
if i%batch_size == 0:
data.append([])
data[-1].append(texts[i])
def make_batch(batch_texts):
batch_data = []
print(i18n("############ 提取文本Bert特征 ############"))
for text in tqdm(batch_texts):
phones, bert_features, norm_text = self.text_preprocessor.segment_and_extract_feature_for_text(text, text_lang, self.configs.version)
if phones is None:
continue
res={
"phones": phones,
"bert_features": bert_features,
"norm_text": norm_text,
}
batch_data.append(res)
if len(batch_data) == 0:
return None
batch, _ = self.to_batch(batch_data,
prompt_data=self.prompt_cache if not no_prompt_text else None,
batch_size=batch_size,
threshold=batch_threshold,
split_bucket=False,
device=self.configs.device,
precision=self.precision
)
return batch[0]
t2 = ttime()
try:
print("############ Inference ############")
###### inference ######
t_34 = 0.0
t_45 = 0.0
audio = []
for item in data:
t3 = ttime()
if return_fragment:
item = make_batch(item)
if item is None:
continue
batch_phones:List[torch.LongTensor] = item["phones"]
# batch_phones:torch.LongTensor = item["phones"]
batch_phones_len:torch.LongTensor = item["phones_len"]
all_phoneme_ids:torch.LongTensor = item["all_phones"]
all_phoneme_lens:torch.LongTensor = item["all_phones_len"]
all_bert_features:torch.LongTensor = item["all_bert_features"]
norm_text:str = item["norm_text"]
max_len = item["max_len"]
print(i18n("前端处理后的文本(每句):"), norm_text)
if no_prompt_text :
prompt = None
else:
prompt = self.prompt_cache["prompt_semantic"].expand(len(all_phoneme_ids), -1).to(self.configs.device)
refer_audio_spec:torch.Tensor = [item.to(dtype=self.precision, device=self.configs.device) for item in self.prompt_cache["refer_spec"]]
generated_tokens_list = []
start = ttime()
from GPT_SoVITS.TTS_infer_pack.zero_crossing import find_zero_zone, find_matching_index
zc_index1 = 0
zc_index2 = 0
crossing_direction = 0
first_chunk = True
last_chunk = False
search_length = 32000*5
num_zeroes = 5
cumulation_amount=50
# Use infer_panel_generator to generate tokens in batches
for generated_tokens in self.t2s_model.model.infer_panel_generator(
all_phoneme_ids[0].unsqueeze(0),
all_phoneme_lens[0].unsqueeze(0),
prompt[0].unsqueeze(0) if prompt is not None else None,
all_bert_features[0].unsqueeze(0),
cumulation_amount=cumulation_amount,
top_k=top_k,
top_p=top_p,
temperature=temperature,
early_stop_num=self.configs.hz * self.configs.max_sec,
max_len=max_len,
repetition_penalty=repetition_penalty,
):
# Append the generated tokens
generated_tokens_list.append(generated_tokens)
total_tokens = sum([tokens.size(1) for tokens in generated_tokens_list])
tokens_to_process = torch.cat(generated_tokens_list, dim=1)[:, :total_tokens] # uses full context for decoding
# Check if tokens_to_process contains the EOS token (1024)
contains_eos = (tokens_to_process == 1024).any()
if contains_eos:
# Replace all instances of the EOS token (1024) with 0
tokens_to_process = tokens_to_process.masked_fill(tokens_to_process == 1024, 0)
print("Replaced EOS token (1024) with 0 in tokens_to_process")
last_chunk = True
first_chunk = False
# Prepare input for VITS model
_pred_semantic = tokens_to_process.unsqueeze(0)
phones = batch_phones[0].unsqueeze(0).to(self.configs.device)
# Generate audio for the tokens
audio_output = self.vits_model.decode(
_pred_semantic, phones, refer_audio_spec, speed=speed_factor
).detach()[0, 0, :]
audio_output = audio_output[:].cpu().numpy()
# Convert audio_output to float32 and normalize if the peak exceeds 1.0
audio_output = audio_output.astype(np.float32)
max_val = np.abs(audio_output).max()
if max_val > 1.0:
audio_output /= max_val
start_index = len(audio_output) - search_length
if start_index < 0:
search_length = len(audio_output)
print(f"search_length is too large; auto-adjusted to {search_length} samples because this chunk is only {len(audio_output)} samples long")
start_index = 0
center_index = zc_index2 # Start from previous zero crossing index and search outwards
max_offset = int(search_length // 2) # search outward in both directions from center_index
if center_index < 0:
raise RuntimeError("Zero-crossing center_index is less than 0; this should not happen")
elif center_index >= len(audio_output):
raise RuntimeError("Zero-crossing center_index is beyond the end of audio_output; this should not happen")
if first_chunk:
zc_index1, crossing_direction = find_zero_zone(
chunk=audio_output,
start_index=start_index,
search_length=search_length,
num_zeroes=num_zeroes
)
audio_fragment = audio_output[:zc_index1]
yield self.configs.sampling_rate, audio_fragment
first_chunk = False
zc_index2 = zc_index1
elif last_chunk:
zc_index1 = find_matching_index(
chunk=audio_output,
center_index=center_index,
max_offset=max_offset,
crossing_direction=crossing_direction
)
audio_fragment = audio_output[zc_index1:]
yield self.configs.sampling_rate, audio_fragment
else:
zc_index1 = find_matching_index(
chunk=audio_output,
center_index=center_index,
max_offset=max_offset,
crossing_direction=crossing_direction
)
zc_index2, crossing_direction = find_zero_zone(
chunk=audio_output,
start_index=start_index,
search_length=search_length,
num_zeroes=num_zeroes
)
audio_fragment = audio_output[zc_index1:zc_index2]
yield self.configs.sampling_rate, audio_fragment
end = ttime()
print(f"Time to speech: {end-start}")
except Exception as e:
traceback.print_exc()
# Must yield an empty audio clip here, otherwise GPU memory is not released.
yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate),
dtype=np.int16)
# Reset the models, otherwise GPU memory is not fully released.
del self.t2s_model
del self.vits_model
self.t2s_model = None
self.vits_model = None
self.init_t2s_weights(self.configs.t2s_weights_path)
self.init_vits_weights(self.configs.vits_weights_path)
raise e
finally:
self.empty_cache()
def empty_cache(self):
try:
gc.collect() # Trigger garbage collection so memory does not keep growing.
if "cuda" in str(self.configs.device):
torch.cuda.empty_cache()
elif str(self.configs.device) == "mps":
torch.mps.empty_cache()
except:
pass


@@ -0,0 +1,203 @@
import numpy as np
import wave
import struct
def read_wav_file(filename):
"""
Reads a WAV file and returns the sample rate and data as a numpy array.
"""
with wave.open(filename, 'rb') as wf:
sample_rate = wf.getframerate()
n_frames = wf.getnframes()
sample_width = wf.getsampwidth()
n_channels = wf.getnchannels()
audio_data = wf.readframes(n_frames)
# Determine the format string for struct unpacking
fmt = "<" + {1:'b', 2:'h', 4:'i'}[sample_width] * n_frames * n_channels
audio_samples = struct.unpack(fmt, audio_data)
audio_array = np.array(audio_samples, dtype=int)
# If stereo, reshape the array
if n_channels > 1:
audio_array = audio_array.reshape(-1, n_channels)
return sample_rate, audio_array, sample_width, n_channels
def write_wav_file(filename, sample_rate, data, sample_width, n_channels):
"""
Writes numpy array data to a WAV file.
"""
with wave.open(filename, 'wb') as wf:
wf.setnchannels(n_channels)
wf.setsampwidth(sample_width)
wf.setframerate(sample_rate)
# Flatten the array if it's multi-dimensional
if data.ndim > 1:
data = data.flatten()
# Pack the data into bytes
fmt = "<" + {1:'b', 2:'h', 4:'i'}[sample_width] * len(data)
byte_data = struct.pack(fmt, *data)
wf.writeframes(byte_data)
def find_zero_zone(chunk, start_index, search_length, num_zeroes=11):
zone = chunk[start_index:start_index + search_length]
print(f"Zero-crossing search zone: Start={start_index}, Length={len(zone)}")
zero_threshold = 1.0e-4
# Check for num_zeroes consecutive near-zero samples
for idx in range(len(zone), -1 + num_zeroes, -1):
index_to_start = idx-num_zeroes
abs_zone = np.abs(zone[index_to_start:idx])
if np.all(abs_zone < zero_threshold):
index_midpoint = index_to_start + int(num_zeroes // 2)
return (start_index + index_midpoint), None
print("No zero zone found; falling back to a plain zero crossing. You may hear more prominent pops and clicks in the audio. Try increasing the search length or the cumulative token amount.")
return find_zero_crossing(chunk, start_index, search_length)
def find_zero_crossing(chunk, start_index, search_length):
# If the code falls back to this function, it may be a sign that the search length is too low
zone = chunk[start_index:start_index + search_length]
sign_changes = np.where(np.diff(np.sign(zone)) != 0)[0]
if len(sign_changes) == 0:
raise RuntimeError("No zero-crossings found in this zone. This should not happen; time to debug.")
else:
zc_index = start_index + sign_changes[0] + 1
print(f"Zero-crossing found at index {zc_index}")
# Determine the crossing direction in chunk1
prev_value = chunk[zc_index - 1]
curr_value = chunk[zc_index]
crossing_direction = np.sign(curr_value) - np.sign(prev_value)
print(f"Crossing direction in chunk1: {np.sign(prev_value)} to {np.sign(curr_value)}")
return zc_index, crossing_direction
def find_matching_index(chunk, center_index, max_offset, crossing_direction):
"""
Finds a zero-crossing in data that matches the specified crossing direction,
starting from center_index and searching outward.
"""
if crossing_direction is None:
return center_index # a zero zone was used, so the same index can be reused directly
# Otherwise fall back to matching the zero-crossing direction
data_length = len(chunk)
print(f"Center index in chunk2: {center_index}")
for offset in range(max_offset + 1):
# Check index bounds
idx_forward = center_index + offset
idx_backward = center_index - offset
found = False
# Check forward direction
if idx_forward < data_length - 1:
prev_sign = np.sign(chunk[idx_forward])
curr_sign = np.sign(chunk[idx_forward + 1])
direction = curr_sign - prev_sign
if direction == crossing_direction:
print(f"Matching zero-crossing found at index {idx_forward + 1} (forward)")
return idx_forward + 1
# Check backward direction
if idx_backward > 0:
prev_sign = np.sign(chunk[idx_backward - 1])
curr_sign = np.sign(chunk[idx_backward])
direction = curr_sign - prev_sign
if direction == crossing_direction:
print(f"Matching zero-crossing found at index {idx_backward} (backward)")
return idx_backward
print("No matching zero-crossings found in this zone.")
return None
# legacy, just for history. delete me sometime
def splice_chunks(chunk1, chunk2, search_length, y):
"""
Splices two audio chunks at zero-crossing points.
"""
# Define the zone to search in chunk1
start_index1 = len(chunk1) - search_length
if start_index1 < 0:
start_index1 = 0
search_length = len(chunk1)
print(f"Searching for zero-crossing in chunk1 from index {start_index1} to {len(chunk1)}")
# Find zero-crossing in chunk1
zc_index1, crossing_direction = find_zero_crossing(chunk1, start_index1, search_length)
if zc_index1 is None:
print("No zero-crossing found in chunk1 within the specified zone.")
return None
# Define the zone to search in chunk2 near the same index
# Since chunk2 overlaps with chunk1, we can assume that index positions correspond
# Adjusted search in chunk2
# You can adjust this value if needed
center_index = zc_index1 # Assuming alignment between chunk1 and chunk2
max_offset = search_length
# Ensure center_index is within bounds
if center_index < 0:
center_index = 0
elif center_index >= len(chunk2):
center_index = len(chunk2) - 1
print(f"Searching for matching zero-crossing in chunk2 around index {center_index} with max offset {max_offset}")
zc_index2 = find_matching_index(chunk2, center_index, max_offset, crossing_direction)
if zc_index2 is None:
print("No matching zero-crossing found in chunk2.")
return None
print(f"Zero-crossing in chunk1 at index {zc_index1}, chunk2 at index {zc_index2}")
# Splice the chunks
new_chunk = np.concatenate((chunk1[:zc_index1], chunk2[zc_index2:]))
print(f"Spliced chunk length: {len(new_chunk)}")
return new_chunk
# legacy, just for history. delete me sometime
def process_audio_chunks(filenames, sample_rate, x, y, output_filename):
"""
Processes and splices a list of audio chunks.
"""
# Read the first chunk
sr, chunk_data, sample_width, n_channels = read_wav_file(filenames[0])
if sr != sample_rate:
print(f"Sample rate mismatch in {filenames[0]}")
return
print(f"Processing {filenames[0]}")
# Initialize the combined audio with the first chunk
combined_audio = chunk_data
# Process remaining chunks
for filename in filenames[1:]:
sr, next_chunk_data, _, _ = read_wav_file(filename)
if sr != sample_rate:
print(f"Sample rate mismatch in {filename}")
return
print(f"Processing {filename}")
# Splice the current combined audio with the next chunk
new_combined = splice_chunks(combined_audio, next_chunk_data, x, y)
if new_combined is None:
print(f"Failed to splice chunks between {filename} and previous chunk.")
return
combined_audio = new_combined
# Write the final combined audio to output file
write_wav_file(output_filename, sample_rate, combined_audio, sample_width, n_channels)
print(f"Final audio saved to {output_filename}")
# Main execution
if __name__ == "__main__":
# User-specified parameters
sample_rate = 32000 # Sample rate in Hz
x = 500 # Number of frames to search from the end of the chunk
y = 10 # Number of consecutive zeros to look for
output_filename = "combined_output.wav"
folder_with_chunks = "output_chunks"
import os
def absolute_file_paths(directory):
path = os.path.abspath(directory)
return [entry.path for entry in os.scandir(path) if entry.is_file()]
# List of input audio chunk filenames in sequential order
filenames = sorted(absolute_file_paths(folder_with_chunks))
# Process and splice the audio chunks
process_audio_chunks(filenames, sample_rate, x, y, output_filename)
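
For reference, a minimal self-contained sketch (on synthetic sine data, not part of this file) of how find_zero_zone and find_matching_index are combined in the run_generator streaming loop above: the earlier chunk is cut inside a quiet zone (or at a plain zero crossing as fallback), and the longer re-decoded chunk is cut at a matching crossing so the two pieces join without an audible click.

import numpy as np
from GPT_SoVITS.TTS_infer_pack.zero_crossing import find_zero_zone, find_matching_index

sample_rate = 32000
t1 = np.arange(sample_rate) / sample_rate
t2 = np.arange(2 * sample_rate) / sample_rate
chunk1 = 0.5 * np.sin(2 * np.pi * 220 * t1).astype(np.float32)  # first decoded fragment
chunk2 = 0.5 * np.sin(2 * np.pi * 220 * t2).astype(np.float32)  # re-decode with more context

search_length = 8000
start_index = len(chunk1) - search_length

# Cut point for the first chunk: a run of near-zero samples, or a zero crossing as fallback.
zc_index1, direction = find_zero_zone(chunk1, start_index=start_index,
                                      search_length=search_length, num_zeroes=5)

# Matching cut point in the longer chunk, searched outward from the same index.
zc_index2 = find_matching_index(chunk2, center_index=zc_index1,
                                max_offset=search_length // 2,
                                crossing_direction=direction)

# Join the head of chunk1 with the tail of chunk2 at the matched crossings.
spliced = np.concatenate([chunk1[:zc_index1], chunk2[zc_index2:]])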


@@ -358,7 +358,7 @@ async def tts_get_endpoint(
top_p:float = 1,
temperature:float = 1,
text_split_method:str = "cut0",
- batch_size:int = 1,
+ batch_size:int = 4,
batch_threshold:float = 0.75,
split_bucket:bool = True,
speed_factor:float = 1.0,
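
A hypothetical client call against the endpoint above, mainly to show the effect of the new batch_size default. The /tts route and port 9880 are assumptions carried over from upstream GPT-SoVITS and are not part of this diff; the parameter names mirror the run_generator inputs documented earlier.

import requests

params = {
    "text": "Streaming in batches test.",
    "text_lang": "en",
    "ref_audio_path": "ref_audio.wav",
    "prompt_text": "Transcript of the reference clip.",
    "prompt_lang": "en",
    "text_split_method": "cut0",
    "batch_size": 4,   # now the default for this endpoint
}
resp = requests.get("http://127.0.0.1:9880/tts", params=params)  # assumed route and port
with open("output.wav", "wb") as f:
    f.write(resp.content)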

infer_script.py (new file, 272 lines)

@@ -0,0 +1,272 @@
'''
This is just an example inference script to test batching with llama, mainly for my reference in the future.
'''
import os
import sys
import numpy as np
import soundfile as sf
import threading
import queue
import sounddevice as sd
import time
import speech_recognition as sr
# Ensure that GPT_SoVITS is in the Python path
now_dir = os.getcwd()
sys.path.append(now_dir)
sys.path.append(os.path.join(now_dir, 'GPT_SoVITS'))
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config
from llama_cpp import Llama
# Initialize the Llama model
llm = Llama(
model_path="ggml-model-q8_0.gguf",
n_gpu_layers=-1, # Offload all layers to the GPU
seed=1337, # Fixed seed for reproducibility
n_ctx=2048, # Context window size
chat_format="llama-3",
verbose=False
)
from time import time
def generate_chat_completion_openai_v1_stream(messages):
start = time()
stream = llm.create_chat_completion_openai_v1(
messages=messages,
temperature=0.8, # Adjust temperature as needed
top_p=0.95, # Adjust top_p as needed
top_k=40, # Adjust top_k as needed
max_tokens=50, # Adjust the maximum number of tokens as needed
# stop=["\n"], # Adjust the stop sequence as needed
stream=True # Enable streaming
)
end = time()
total = end - start
print(total)
for chunk in stream:
if chunk.choices[0].delta.content is not None:
yield chunk.choices[0].delta.content
def audio_playback_thread(audio_queue, sample_rate):
"""
Audio playback thread that plays audio fragments from the queue.
"""
sd.default.samplerate = sample_rate
sd.default.channels = 1
stream = sd.OutputStream(dtype='float32')
stream.start()
try:
while True:
# Get the next audio fragment
audio_fragment = audio_queue.get()
try:
if audio_fragment is None:
# Sentinel value received, exit the loop
break
# Write the audio fragment to the stream
stream.write(audio_fragment)
finally:
# Mark the item as processed
audio_queue.task_done()
finally:
stream.stop()
stream.close()
def main():
config_path = 'configs/tts_infer.yaml'
# GPT_model_path = 'pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt'
GPT_model_path = 'custom_trained.ckpt'
# SoVITS_model_path = 'pretrained_models/gsv-v2final-pretrained/s2G2333k.pth'
SoVITS_model_path = 'custom_trained.pth'
ref_audio_path = 'ref_audio.wav'
ref_text = 'でもなんか対処法ではないよなこれ対処法ではないけどそもそもの話みたいなことを言ってんのか'
target_text = """hahahaha, well well, let me tell you about that! it was perhaps the most exquisite day of my life! Phew, I've never had one better! """
output_path = 'output'
ref_language = 'ja'
target_language = 'ja'
# Ensure output directory exists
os.makedirs(output_path, exist_ok=True)
# Initialize TTS configuration and pipeline
tts_config = TTS_Config(config_path)
tts_pipeline = TTS(tts_config)
# Load model weights
tts_pipeline.init_t2s_weights(GPT_model_path)
tts_pipeline.init_vits_weights(SoVITS_model_path)
# Prepare inputs for TTS
inputs = {
"text": target_text,
"text_lang": target_language.lower(),
"ref_audio_path": ref_audio_path,
"prompt_text": ref_text,
"prompt_lang": ref_language.lower(),
"top_k": 5,
"top_p": 1.0,
"temperature": 1.0,
"text_split_method": "cut0",
"batch_size": 1,
"batch_threshold": 0.75,
"split_bucket": True,
"speed_factor": 1.0,
"fragment_interval": 0.3,
"seed": 2855904637,
"return_fragment": True,
"parallel_infer": False,
"repetition_penalty": 1.35,
}
# Run TTS inference
system_message = '''You are a friendly AI named Vivy.
HOW YOU SHOULD RESPOND:
- The responses should include only verbal responses, for example *laughs* should be replaced with haha
'''
# Initialize conversation history with system message
conversation_history = [
{"role": "system", "content": f"{system_message}"}
]
# Create a queue for audio fragments
audio_queue = queue.Queue(maxsize=100) # Adjust maxsize based on your needs
# Start the audio playback thread
playback_thread = threading.Thread(
target=audio_playback_thread,
args=(audio_queue, tts_pipeline.configs.sampling_rate)
)
playback_thread.start()
# Setup speech recognition
r = sr.Recognizer()
mic = sr.Microphone()
try:
while True:
# Prompt for speech input instead of text input
while True:
print("\nPlease speak your message (say 'quit' to exit):")
with mic as source:
# Adjust for ambient noise to improve recognition accuracy
r.adjust_for_ambient_noise(source, duration=1.0)
print("Listening...")
audio_data = r.listen(source, timeout=None, phrase_time_limit=60)
try:
# Replace 'recognize_whisper' with your actual recognition method
# Ensure that the method is correctly implemented or available
user_input = r.recognize_whisper(audio_data=audio_data, model="base")
print("You said: " + user_input)
# Check if the input is not empty or just whitespace
if user_input.strip() == "":
print("No speech detected. Please try again.")
continue # Continue listening
break # Valid input received, exit inner loop
except sr.UnknownValueError:
print("Sorry, I could not understand the audio. Please try again.")
continue # Continue listening
except sr.RequestError as e:
print(f"Could not request results from speech recognition service; {e}")
continue # Continue listening
# Check if the user wants to quit
if user_input.lower() == "quit":
print("Exiting the application. Goodbye!")
sys.exit()
# Append user message to conversation history
conversation_history.append({"role": "user", "content": user_input})
# Initialize variables to track character count and buffering
buffer = ""
char_count = 0
waiting_for_punctuation = False
assistant_buffer = ""
# Generate and print the chat completion with streaming
for token in generate_chat_completion_openai_v1_stream(conversation_history):
print(token, end="", flush=True) # Print each character as it's generated
buffer += token
assistant_buffer += token
char_count += len(token)
if not waiting_for_punctuation:
if char_count >= 100:
waiting_for_punctuation = True # Start looking for punctuation
else:
if any(punct in token for punct in ['.', '!', '?']):
# Send the buffer to TTS
inputs["text"] = buffer
synthesis_result = tts_pipeline.run_generator(inputs)
# Consume the generator and put audio fragments into the queue
for sampling_rate, audio_fragment in synthesis_result:
audio_queue.put(audio_fragment)
# Put silence into the audio queue after the TTS synthesis generator has finished
silence_duration = 0.5 # in seconds
num_samples = int(sampling_rate * silence_duration)
silence = np.zeros(num_samples, dtype='float32')
audio_queue.put(silence)
# Reset counters and buffer
char_count = 0
buffer = ""
waiting_for_punctuation = False
# Append assistant message to conversation history
conversation_history.append({"role": "assistant", "content": assistant_buffer})
# Handle any remaining text after the generator is done
if buffer.strip():
inputs["text"] = buffer
synthesis_result = tts_pipeline.run_generator(inputs)
# Consume the generator and put audio fragments into the queue
for sampling_rate, audio_fragment in synthesis_result:
audio_queue.put(audio_fragment)
# Put silence into the audio queue after the TTS synthesis generator has finished
silence_duration = 0.5 # in seconds
num_samples = int(sampling_rate * silence_duration)
silence = np.zeros(num_samples, dtype='float32')
audio_queue.put(silence)
conversation_history.append({"role": "assistant", "content": buffer})
buffer = ""
char_count = 0
waiting_for_punctuation = False
finally:
# After all processing is done, send a sentinel to the audio queue and wait for threads to finish
audio_queue.put(None)
audio_queue.join()
playback_thread.join()
# text = input("GO:")
# inputs["text"] = text
# synthesis_result = tts_pipeline.run_generator(inputs)
# audio_data_list = list(synthesis_result)
# if audio_data_list:
# # Since return_fragment is False, we expect only one tuple in audio_data_list
# sampling_rate, audio_data = audio_data_list[0]
# output_wav_path = os.path.join(output_path, "output.wav")
# # Save the audio data to a WAV file
# sf.write(output_wav_path, audio_data, sampling_rate)
# print(f"Audio saved to {output_wav_path}")
# else:
# print("No audio data generated.")
if __name__ == '__main__':
main()