Fixed imports and added initial setup for streaming support.

Sebastian Pava 2025-04-19 01:52:32 -04:00
parent c0b46314ca
commit 11e98462a2
61 changed files with 239 additions and 53 deletions

View File

@@ -10,9 +10,9 @@ from typing import Dict
 import torch
 from pytorch_lightning import LightningModule
-from AR.models.t2s_model import Text2SemanticDecoder
-from AR.modules.lr_schedulers import WarmupCosineLRSchedule
-from AR.modules.optim import ScaledAdam
+from GPT_SoVITS.AR.models.t2s_model import Text2SemanticDecoder
+from GPT_SoVITS.AR.modules.lr_schedulers import WarmupCosineLRSchedule
+from GPT_SoVITS.AR.modules.optim import ScaledAdam
 class Text2SemanticLightningModule(LightningModule):
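
Note: these package-qualified imports only resolve if the repository root (the directory containing the GPT_SoVITS package) is on sys.path. A minimal sketch of what an external entry point might do if it is not; the layout here is an assumption about a typical checkout, not code from this commit:

import os
import sys

# Assumption: this script lives in the repo root, alongside the GPT_SoVITS/ package directory.
repo_root = os.path.dirname(os.path.abspath(__file__))
if repo_root not in sys.path:
    sys.path.insert(0, repo_root)

from GPT_SoVITS.AR.models.t2s_model import Text2SemanticDecoder  # now resolves as a package import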

View File

@@ -9,7 +9,7 @@ from torch.nn import functional as F
 from torchmetrics.classification import MulticlassAccuracy
 from tqdm import tqdm
-from AR.models.utils import (
+from GPT_SoVITS.AR.models.utils import (
     dpo_loss,
     get_batch_logps,
     make_pad_mask,
@@ -18,8 +18,8 @@ from AR.models.utils import (
     sample,
     topk_sampling,
 )
-from AR.modules.embedding import SinePositionalEmbedding, TokenEmbedding
-from AR.modules.transformer import LayerNorm, TransformerEncoder, TransformerEncoderLayer
+from GPT_SoVITS.AR.modules.embedding import SinePositionalEmbedding, TokenEmbedding
+from GPT_SoVITS.AR.modules.transformer import LayerNorm, TransformerEncoder, TransformerEncoderLayer
 default_config = {
     "embedding_dim": 512,

View File

@@ -9,7 +9,7 @@ from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
 from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
 from torch.nn.parameter import Parameter
-from AR.modules.patched_mha_with_cache import multi_head_attention_forward_patched
+from GPT_SoVITS.AR.modules.patched_mha_with_cache import multi_head_attention_forward_patched
 F.multi_head_attention_forward = multi_head_attention_forward_patched

View File

@@ -10,8 +10,8 @@ from typing import Tuple
 from typing import Union
 import torch
-from AR.modules.activation import MultiheadAttention
-from AR.modules.scaling import BalancedDoubleSwish
+from GPT_SoVITS.AR.modules.activation import MultiheadAttention
+from GPT_SoVITS.AR.modules.scaling import BalancedDoubleSwish
 from torch import nn
 from torch import Tensor
 from torch.nn import functional as F

View File

@@ -21,20 +21,20 @@ import numpy as np
 import torch
 import torch.nn.functional as F
 import yaml
-from AR.models.t2s_lightning_module import Text2SemanticLightningModule
-from BigVGAN.bigvgan import BigVGAN
-from feature_extractor.cnhubert import CNHubert
-from module.mel_processing import mel_spectrogram_torch, spectrogram_torch
-from module.models import SynthesizerTrn, SynthesizerTrnV3
+from GPT_SoVITS.AR.models.t2s_lightning_module import Text2SemanticLightningModule
+from GPT_SoVITS.BigVGAN.bigvgan import BigVGAN
+from GPT_SoVITS.feature_extractor.cnhubert import CNHubert
+from GPT_SoVITS.module.mel_processing import mel_spectrogram_torch, spectrogram_torch
+from GPT_SoVITS.module.models import SynthesizerTrn, SynthesizerTrnV3
 from peft import LoraConfig, get_peft_model
-from process_ckpt import get_sovits_version_from_path_fast, load_sovits_new
+from GPT_SoVITS.process_ckpt import get_sovits_version_from_path_fast, load_sovits_new
 from transformers import AutoModelForMaskedLM, AutoTokenizer
 from tools.audio_sr import AP_BWE
 from tools.i18n.i18n import I18nAuto, scan_language_list
 from tools.my_utils import load_audio
-from TTS_infer_pack.text_segmentation_method import splits
-from TTS_infer_pack.TextPreprocessor import TextPreprocessor
+from GPT_SoVITS.TTS_infer_pack.text_segmentation_method import splits
+from GPT_SoVITS.TTS_infer_pack.TextPreprocessor import TextPreprocessor
 language = os.environ.get("language", "Auto")
 language = sys.argv[-1] if sys.argv[-1] in scan_language_list() else language

View File

@@ -9,13 +9,13 @@ sys.path.append(now_dir)
 import re
 import torch
-from text.LangSegmenter import LangSegmenter
-from text import chinese
+from GPT_SoVITS.text.LangSegmenter import LangSegmenter
+from GPT_SoVITS.text import chinese
 from typing import Dict, List, Tuple
-from text.cleaner import clean_text
-from text import cleaned_text_to_sequence
+from GPT_SoVITS.text.cleaner import clean_text
+from GPT_SoVITS.text import cleaned_text_to_sequence
 from transformers import AutoModelForMaskedLM, AutoTokenizer
-from TTS_infer_pack.text_segmentation_method import split_big_text, splits, get_method as get_seg_method
+from GPT_SoVITS.TTS_infer_pack.text_segmentation_method import split_big_text, splits, get_method as get_seg_method
 from tools.i18n.i18n import I18nAuto, scan_language_list

View File

@@ -0,0 +1,84 @@
import numpy as np


def find_zero_zone(chunk, start_index, search_length, search_window_size=11):
    """
    Returns the index of the middle of the zero zone plus the starting index.
    So if the starting index was 0 and the zero zone was found at 12789:12800,
    we would return 0 + 12795 (the window is of size 11, so its midpoint is 6).

    This works by sliding a window over the chunk from the end toward the start.
    If all values in the window are below the threshold, that window is the zero zone.

    TLDR: Returns a point inside a region of the audio with enough silence to cut at.
    """
    zone = chunk[start_index:start_index + search_length]
    # print(f"Zero-crossing search zone: Start={start_index}, Length={len(zone)}")
    zero_threshold = 1.0e-4
    # Check for search_window_size consecutive near-zero samples
    for idx in range(len(zone), search_window_size - 1, -1):
        index_to_start = idx - search_window_size
        abs_zone = np.abs(zone[index_to_start:idx])
        if np.all(abs_zone < zero_threshold):
            # print(f"Found Abs Zone: {abs_zone}")
            # print(f"Extended Abs Zone: {chunk[idx-21:idx+10]}")
            index_midpoint = index_to_start + search_window_size // 2
            # print(f"Returning {start_index} + {index_midpoint}")
            return (start_index + index_midpoint), None
    # print("Falling back to zero crossing due to no zero zone found. You may hear more prominent pops and clicks in the audio. Try increasing search length or cumulative tokens.")
    return find_zero_crossing(chunk, start_index, search_length)


def find_zero_crossing(chunk, start_index, search_length):
    # If the model is falling back on this function, it might be a bad indicator that the search length is too low
    zone = chunk[start_index:start_index + search_length]
    sign_changes = np.where(np.diff(np.sign(zone)) != 0)[0]
    if len(sign_changes) == 0:
        raise ValueError("No zero-crossings found in this zone. This should not be happening, debugging time.")
    zc_index = start_index + sign_changes[0] + 1
    # print(f"Zero-crossing found at index {zc_index}")
    # Determine the crossing direction in chunk1
    prev_value = chunk[zc_index - 1]
    curr_value = chunk[zc_index]
    crossing_direction = np.sign(curr_value) - np.sign(prev_value)
    # print(f"Crossing direction in chunk1: {np.sign(prev_value)} to {np.sign(curr_value)}")
    return zc_index, crossing_direction


def find_matching_index(chunk, center_index, max_offset, crossing_direction):
    """
    Finds a zero-crossing in the chunk that matches the specified crossing direction,
    starting from center_index and searching outward.
    """
    if crossing_direction is None:
        return center_index  # a zero zone was found, so no direction needs matching
    # Fallback path for find_zero_crossing
    data_length = len(chunk)
    # print(f"Center index in chunk2: {center_index}")
    for offset in range(max_offset + 1):
        idx_forward = center_index + offset
        idx_backward = center_index - offset
        # Check forward direction
        if idx_forward < data_length - 1:
            prev_sign = np.sign(chunk[idx_forward])
            curr_sign = np.sign(chunk[idx_forward + 1])
            direction = curr_sign - prev_sign
            if direction == crossing_direction:
                # print(f"Matching zero-crossing found at index {idx_forward + 1} (forward)")
                return idx_forward + 1
        # Check backward direction
        if idx_backward > 0:
            prev_sign = np.sign(chunk[idx_backward - 1])
            curr_sign = np.sign(chunk[idx_backward])
            direction = curr_sign - prev_sign
            if direction == crossing_direction:
                # print(f"Matching zero-crossing found at index {idx_backward} (backward)")
                return idx_backward
    # print("No matching zero-crossings found in this zone.")
    return None
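
For context, a minimal sketch of how these helpers might be used to splice two streamed audio chunks; the synthetic data and splice workflow here are illustrative assumptions, not code from this commit:

import numpy as np

rng = np.random.default_rng(0)
chunk1 = rng.uniform(-0.5, 0.5, 32000).astype(np.float32)
chunk1[20000:20100] = 0.0  # plant a silent stretch for the zero-zone search to find
chunk2 = rng.uniform(-0.5, 0.5, 32000).astype(np.float32)

# Search the tail of chunk1 for a quiet point to cut at
splice_index, direction = find_zero_zone(chunk1, start_index=16000, search_length=10000)

# Pick a compatible point in chunk2; direction is None here because a zero zone was found,
# so this returns center_index unchanged (matching only matters for the fallback path)
match_index = find_matching_index(chunk2, center_index=splice_index, max_offset=500, crossing_direction=direction)

# Join without an audible pop: end chunk1 at the quiet point, continue from chunk2
joined = np.concatenate([chunk1[:splice_index], chunk2[match_index:]])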

View File

@@ -25,7 +25,7 @@ from GPT_SoVITS.f5_tts.model.modules import (
     get_pos_embed_indices,
 )
-from module.commons import sequence_mask
+from GPT_SoVITS.module.commons import sequence_mask
 class TextEmbedding(nn.Module):

View File

@@ -13,7 +13,7 @@ from transformers import (
     HubertModel,
 )
-import utils
+import GPT_SoVITS.utils
 import torch.nn as nn
 cnhubert_base_path = None

View File

@@ -3,8 +3,8 @@ import torch
 from torch import nn
 from torch.nn import functional as F
-from module import commons
-from module.modules import LayerNorm
+from GPT_SoVITS.module import commons
+from GPT_SoVITS.module.modules import LayerNorm
 class Encoder(nn.Module):

View File

@@ -7,19 +7,19 @@ import torch
 from torch import nn
 from torch.nn import functional as F
-from module import commons
-from module import modules
-from module import attentions
-from f5_tts.model import DiT
+from GPT_SoVITS.module import commons
+from GPT_SoVITS.module import modules
+from GPT_SoVITS.module import attentions
+from GPT_SoVITS.f5_tts.model import DiT
 from torch.nn import Conv1d, ConvTranspose1d, Conv2d
 from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
-from module.commons import init_weights, get_padding
-from module.mrte_model import MRTE
-from module.quantize import ResidualVectorQuantizer
+from GPT_SoVITS.module.commons import init_weights, get_padding
+from GPT_SoVITS.module.mrte_model import MRTE
+from GPT_SoVITS.module.quantize import ResidualVectorQuantizer
 # from text import symbols
-from text import symbols as symbols_v1
-from text import symbols2 as symbols_v2
+from GPT_SoVITS.text import symbols as symbols_v1
+from GPT_SoVITS.text import symbols2 as symbols_v2
 from torch.cuda.amp import autocast
 import contextlib
 import random

View File

@@ -7,9 +7,9 @@ from torch.nn import functional as F
 from torch.nn import Conv1d
 from torch.nn.utils import weight_norm, remove_weight_norm
-from module import commons
-from module.commons import init_weights, get_padding
-from module.transforms import piecewise_rational_quadratic_transform
+from GPT_SoVITS.module import commons
+from GPT_SoVITS.module.commons import init_weights, get_padding
+from GPT_SoVITS.module.transforms import piecewise_rational_quadratic_transform
 import torch.distributions as D

View File

@@ -3,7 +3,7 @@
 import torch
 from torch import nn
 from torch.nn.utils import remove_weight_norm, weight_norm
-from module.attentions import MultiHeadAttention
+from GPT_SoVITS.module.attentions import MultiHeadAttention
 class MRTE(nn.Module):

View File

@@ -12,7 +12,7 @@ import typing as tp
 import torch
 from torch import nn
-from module.core_vq import ResidualVectorQuantization
+from GPT_SoVITS.module.core_vq import ResidualVectorQuantization
 @dataclass

View File

@@ -4,8 +4,8 @@ import os
 # else:
 #     from text.symbols2 import symbols
-from text import symbols as symbols_v1
-from text import symbols2 as symbols_v2
+from GPT_SoVITS.text import symbols as symbols_v1
+from GPT_SoVITS.text import symbols2 as symbols_v2
 _symbol_to_id_v1 = {s: i for i, s in enumerate(symbols_v1.symbols)}
 _symbol_to_id_v2 = {s: i for i, s in enumerate(symbols_v2.symbols)}

View File

@@ -4,9 +4,9 @@ import re
 import cn2an
 from pypinyin import lazy_pinyin, Style
-from text.symbols import punctuation
-from text.tone_sandhi import ToneSandhi
-from text.zh_normalization.text_normlization import TextNormalizer
+from GPT_SoVITS.text.symbols import punctuation
+from GPT_SoVITS.text.tone_sandhi import ToneSandhi
+from GPT_SoVITS.text.zh_normalization.text_normlization import TextNormalizer
 normalizer = lambda x: cn2an.transform(x, "an2cn")

View File

@@ -1,4 +1,4 @@
-from text import cleaned_text_to_sequence
+from GPT_SoVITS.text import cleaned_text_to_sequence
 import os
 # if os.environ.get("version","v1")=="v1":
 #     from text import chinese
@@ -7,8 +7,8 @@ import os
 #     from text import chinese2 as chinese
 #     from text.symbols2 import symbols
-from text import symbols as symbols_v1
-from text import symbols2 as symbols_v2
+from GPT_SoVITS.text import symbols as symbols_v1
+from GPT_SoVITS.text import symbols2 as symbols_v2
 special = [
     # ("%", "zh", "SP"),
@@ -34,7 +34,7 @@ def clean_text(text, language, version=None):
     for special_s, special_l, target_symbol in special:
         if special_s in text and language == special_l:
             return clean_special(text, language, special_s, target_symbol, version)
-    language_module = __import__("text." + language_module_map[language], fromlist=[language_module_map[language]])
+    language_module = __import__("GPT_SoVITS.text." + language_module_map[language], fromlist=[language_module_map[language]])
     if hasattr(language_module, "text_normalize"):
         norm_text = language_module.text_normalize(text)
     else:
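
Aside: `__import__` with a `fromlist` argument returns the leaf submodule rather than the top-level `GPT_SoVITS` package, which is why the call above works. A sketch of the equivalent, arguably clearer `importlib` form (behavior assumed identical for this lookup):

import importlib

# Returns the GPT_SoVITS.text.<language> submodule directly, no fromlist trick needed
language_module = importlib.import_module("GPT_SoVITS.text." + language_module_map[language])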

View File

@@ -4,12 +4,12 @@ import re
 import wordsegment
 from g2p_en import G2p
-from text.symbols import punctuation
-from text.symbols2 import symbols
+from GPT_SoVITS.text.symbols import punctuation
+from GPT_SoVITS.text.symbols2 import symbols
 from builtins import str as unicode
-from text.en_normalization.expend import normalize
+from GPT_SoVITS.text.en_normalization.expend import normalize
 from nltk.tokenize import TweetTokenizer
 word_tokenize = TweetTokenizer().tokenize

View File

@@ -11,4 +11,4 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from text.zh_normalization.text_normlization import *
+from GPT_SoVITS.text.zh_normalization.text_normlization import *

41 binary files not shown.

inference.py (Normal file, 102 lines)
View File

@@ -0,0 +1,102 @@
import torch
import sounddevice as sd
import time
from queue import Queue
from threading import Thread
import os


class TTS:
    def __init__(self):
        # Replace with your checkpoints and reference audio here
        # Note: Using a venv may require updating the default paths provided here
        self.bert_checkpoint = "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"
        self.cnhuhbert_checkpoint = "GPT_SoVITS/pretrained_models/chinese-hubert-base"
        # self.t2s_checkpoint = "GPT_SoVITS/pretrained_models/ayaka/Ayaka-e15.ckpt"
        # self.vits_checkpoint = "GPT_SoVITS/pretrained_models/ayaka/Ayaka_e3_s1848_l32.pth"
        self.t2s_checkpoint = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt"
        self.vits_checkpoint = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth"
        self.ref_audio = "audio/ayaka/ref_audio/10_audio.wav"

        # Imported locally so the engine's TTS class can share this wrapper's name without clashing
        from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config

        self.config = {
            "custom": {
                "bert_base_path": self.bert_checkpoint,
                "cnhuhbert_base_path": self.cnhuhbert_checkpoint,
                "device": "cuda" if torch.cuda.is_available() else "cpu",
                "is_half": True,
                "t2s_weights_path": self.t2s_checkpoint,
                "vits_weights_path": self.vits_checkpoint,
                "version": "v3",
            }
        }
        self.tts = TTS(TTS_Config(self.config))
        self.audio_queue = Queue()
        self.generating_audio = False

    def audio_stream(self, start_time):
        # Consumer thread: plays (sample_rate, audio) tuples until a (None, None) sentinel arrives
        with sd.OutputStream(samplerate=32000, channels=1, dtype="int16") as stream:
            while True:
                sr, audio_data = self.audio_queue.get()
                if audio_data is None:
                    print(f"Stream Thread Done ({time.time() - start_time:.2f}s)")
                    break
                print((sr, audio_data))
                stream.write(audio_data)
        self.generating_audio = False

    def synthesize(self, text, start_time, generating_text=False):
        if not self.generating_audio:
            Thread(target=self.audio_stream, args=(start_time,)).start()
            self.generating_audio = True

        path = "audio/ayaka/aux_ref_audio"
        aux_ref_audios = [f"{path}/{file_name}" for file_name in os.listdir(path)]
        args = {
            "text": text,
            "text_lang": "en",
            "ref_audio_path": self.ref_audio,
            "aux_ref_audio_paths": aux_ref_audios,
            "prompt_text": "Don't worry. Now that I've experienced the event once already, I won't be easily frightened. I'll see you later. Have a lovely chat with your friend.",
            "prompt_lang": "en",
            "temperature": 0.8,
            "top_k": 50,
            "top_p": 0.9,
            "parallel_infer": True,
            "sample_steps": 32,
            "super_sampling": True,
            "speed_factor": 1,
            "fragment_interval": 0.2,
            # "stream_output": True,
            # "max_chunk_size": 20,
        }
        if text:
            print(f"Synthesis Start: {time.time() - start_time}")
            for audio_chunk in self.tts.run(args):
                self.audio_queue.put(audio_chunk)
        if not generating_text:
            # No more text is coming: enqueue the sentinel so the stream thread can exit
            self.audio_queue.put((None, None))
        print(f"Synthesis End ({time.time() - start_time:.2f}s)")


# Usage
tts = TTS()
"""
Timing is only for debugging purposes; remove it if not needed.
Since this TTS model is built to be paired with LLM text streaming, synthesize takes a
generating_text flag: it signals whether more streamed text is still coming. Passing False
marks the final chunk, which closes the audio stream once it has played out.
"""
tts.synthesize("One day, a fierce storm rolled in, bringing heavy rain and strong winds that threatened to destroy the wheat crops.", time.time(), False)
while tts.generating_audio:
    time.sleep(0.1)
tts.synthesize("One day, a fierce storm rolled in, bringing heavy rain and strong winds that threatened to destroy the wheat crops.", time.time(), False)