Fixed imports and added initial setup for streaming support.

This commit is contained in:
Sebastian Pava 2025-04-19 01:52:32 -04:00
parent c0b46314ca
commit 11e98462a2
61 changed files with 239 additions and 53 deletions

View File

@@ -10,9 +10,9 @@ from typing import Dict
 import torch
 from pytorch_lightning import LightningModule
-from AR.models.t2s_model import Text2SemanticDecoder
-from AR.modules.lr_schedulers import WarmupCosineLRSchedule
-from AR.modules.optim import ScaledAdam
+from GPT_SoVITS.AR.models.t2s_model import Text2SemanticDecoder
+from GPT_SoVITS.AR.modules.lr_schedulers import WarmupCosineLRSchedule
+from GPT_SoVITS.AR.modules.optim import ScaledAdam
 class Text2SemanticLightningModule(LightningModule):
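
The pattern repeated across the files in this commit: bare top-level imports (AR, module, text, TTS_infer_pack) become package-qualified under GPT_SoVITS, so they resolve whenever the repo root is on sys.path instead of requiring the working directory to be inside GPT_SoVITS itself. A minimal sketch of the difference, assuming a checkout at an arbitrary location:

# hypothetical caller living outside the GPT_SoVITS directory
import sys

sys.path.insert(0, "/path/to/GPT-SoVITS")  # repo root; this path is an assumption

# Resolves with the new package-qualified form. The old bare form
# ("from AR.models.t2s_model import ...") only worked with the
# GPT_SoVITS directory itself on sys.path.
from GPT_SoVITS.AR.models.t2s_model import Text2SemanticDecoder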

View File

@@ -9,7 +9,7 @@ from torch.nn import functional as F
 from torchmetrics.classification import MulticlassAccuracy
 from tqdm import tqdm
-from AR.models.utils import (
+from GPT_SoVITS.AR.models.utils import (
     dpo_loss,
     get_batch_logps,
     make_pad_mask,
@@ -18,8 +18,8 @@ from AR.models.utils import (
     sample,
     topk_sampling,
 )
-from AR.modules.embedding import SinePositionalEmbedding, TokenEmbedding
-from AR.modules.transformer import LayerNorm, TransformerEncoder, TransformerEncoderLayer
+from GPT_SoVITS.AR.modules.embedding import SinePositionalEmbedding, TokenEmbedding
+from GPT_SoVITS.AR.modules.transformer import LayerNorm, TransformerEncoder, TransformerEncoderLayer
 default_config = {
     "embedding_dim": 512,

View File

@@ -9,7 +9,7 @@ from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
 from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
 from torch.nn.parameter import Parameter
-from AR.modules.patched_mha_with_cache import multi_head_attention_forward_patched
+from GPT_SoVITS.AR.modules.patched_mha_with_cache import multi_head_attention_forward_patched
 F.multi_head_attention_forward = multi_head_attention_forward_patched

View File

@@ -10,8 +10,8 @@ from typing import Tuple
 from typing import Union
 import torch
-from AR.modules.activation import MultiheadAttention
-from AR.modules.scaling import BalancedDoubleSwish
+from GPT_SoVITS.AR.modules.activation import MultiheadAttention
+from GPT_SoVITS.AR.modules.scaling import BalancedDoubleSwish
 from torch import nn
 from torch import Tensor
 from torch.nn import functional as F

View File

@@ -21,20 +21,20 @@ import numpy as np
 import torch
 import torch.nn.functional as F
 import yaml
-from AR.models.t2s_lightning_module import Text2SemanticLightningModule
-from BigVGAN.bigvgan import BigVGAN
-from feature_extractor.cnhubert import CNHubert
-from module.mel_processing import mel_spectrogram_torch, spectrogram_torch
-from module.models import SynthesizerTrn, SynthesizerTrnV3
+from GPT_SoVITS.AR.models.t2s_lightning_module import Text2SemanticLightningModule
+from GPT_SoVITS.BigVGAN.bigvgan import BigVGAN
+from GPT_SoVITS.feature_extractor.cnhubert import CNHubert
+from GPT_SoVITS.module.mel_processing import mel_spectrogram_torch, spectrogram_torch
+from GPT_SoVITS.module.models import SynthesizerTrn, SynthesizerTrnV3
 from peft import LoraConfig, get_peft_model
-from process_ckpt import get_sovits_version_from_path_fast, load_sovits_new
+from GPT_SoVITS.process_ckpt import get_sovits_version_from_path_fast, load_sovits_new
 from transformers import AutoModelForMaskedLM, AutoTokenizer
 from tools.audio_sr import AP_BWE
 from tools.i18n.i18n import I18nAuto, scan_language_list
 from tools.my_utils import load_audio
-from TTS_infer_pack.text_segmentation_method import splits
-from TTS_infer_pack.TextPreprocessor import TextPreprocessor
+from GPT_SoVITS.TTS_infer_pack.text_segmentation_method import splits
+from GPT_SoVITS.TTS_infer_pack.TextPreprocessor import TextPreprocessor
 language = os.environ.get("language", "Auto")
 language = sys.argv[-1] if sys.argv[-1] in scan_language_list() else language

View File

@@ -9,13 +9,13 @@ sys.path.append(now_dir)
 import re
 import torch
-from text.LangSegmenter import LangSegmenter
-from text import chinese
+from GPT_SoVITS.text.LangSegmenter import LangSegmenter
+from GPT_SoVITS.text import chinese
 from typing import Dict, List, Tuple
-from text.cleaner import clean_text
-from text import cleaned_text_to_sequence
+from GPT_SoVITS.text.cleaner import clean_text
+from GPT_SoVITS.text import cleaned_text_to_sequence
 from transformers import AutoModelForMaskedLM, AutoTokenizer
-from TTS_infer_pack.text_segmentation_method import split_big_text, splits, get_method as get_seg_method
+from GPT_SoVITS.TTS_infer_pack.text_segmentation_method import split_big_text, splits, get_method as get_seg_method
 from tools.i18n.i18n import I18nAuto, scan_language_list

View File

@@ -0,0 +1,84 @@
import numpy as np


def find_zero_zone(chunk, start_index, search_length, search_window_size=11):
    """
    Essentially returns the index of the middle of the zero zone plus the
    starting index. So if the starting index was 0 and the zero zone was found
    at 12789:12800, we would return 0 + 12789 + 5 = 12794 (the window has size
    11, so the midpoint offset is 11 // 2 = 5).
    This works by sliding a window over the chunk from the end toward the
    start. If all the values in the window meet the threshold, that window is
    assigned as the zero zone.
    TL;DR: returns the zero zone, a region of the audio with enough silence.
    """
    zone = chunk[start_index:start_index + search_length]
    # print(f"Zero-crossing search zone: Start={start_index}, Length={len(zone)}")
    zero_threshold = 1.0e-4
    # Check for search_window_size consecutive near-zero samples
    for idx in range(len(zone), search_window_size - 1, -1):
        index_to_start = idx - search_window_size
        abs_zone = np.abs(zone[index_to_start:idx])
        if np.all(abs_zone < zero_threshold):
            # print(f"Found Abs Zone: {abs_zone}")
            # print(f"Extended Abs Zone: {chunk[idx-21:idx+10]}")
            index_midpoint = index_to_start + search_window_size // 2
            # print(f"Returning {start_index} + {index_midpoint}")
            return (start_index + index_midpoint), None
    # print("Falling back to zero crossing because no zero zone was found. You may hear more prominent pops and clicks in the audio. Try increasing the search length or cumulative tokens.")
    return find_zero_crossing(chunk, start_index, search_length)


def find_zero_crossing(chunk, start_index, search_length):
    # Falling back on this function can be a sign that the search length is too low
    zone = chunk[start_index:start_index + search_length]
    sign_changes = np.where(np.diff(np.sign(zone)) != 0)[0]
    if len(sign_changes) == 0:
        raise ValueError("No zero-crossings found in this zone. This should not be happening; time to debug.")
    zc_index = start_index + sign_changes[0] + 1
    # print(f"Zero-crossing found at index {zc_index}")
    # Determine the crossing direction
    prev_value = chunk[zc_index - 1]
    curr_value = chunk[zc_index]
    crossing_direction = np.sign(curr_value) - np.sign(prev_value)
    # print(f"Crossing direction: {np.sign(prev_value)} to {np.sign(curr_value)}")
    return zc_index, crossing_direction


def find_matching_index(chunk, center_index, max_offset, crossing_direction):
    """
    Finds a zero-crossing in the chunk that matches the specified crossing
    direction, starting from center_index and searching outward.
    """
    if crossing_direction is None:
        return center_index  # a zero zone was found; nothing to match
    # Fallback path for find_zero_crossing
    data_length = len(chunk)
    # print(f"Center index: {center_index}")
    for offset in range(max_offset + 1):
        idx_forward = center_index + offset
        idx_backward = center_index - offset
        # Check forward direction
        if idx_forward < data_length - 1:
            prev_sign = np.sign(chunk[idx_forward])
            curr_sign = np.sign(chunk[idx_forward + 1])
            direction = curr_sign - prev_sign
            if direction == crossing_direction:
                # print(f"Matching zero-crossing found at index {idx_forward + 1} (forward)")
                return idx_forward + 1
        # Check backward direction
        if idx_backward > 0:
            prev_sign = np.sign(chunk[idx_backward - 1])
            curr_sign = np.sign(chunk[idx_backward])
            direction = curr_sign - prev_sign
            if direction == crossing_direction:
                # print(f"Matching zero-crossing found at index {idx_backward} (backward)")
                return idx_backward
    # print("No matching zero-crossings found in this zone.")
    return None
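
A hedged usage sketch of how these helpers fit together when stitching streamed audio chunks: cut the outgoing chunk at a quiet point (or, failing that, a zero crossing), then find a crossing of the same direction in the incoming chunk so the seam lands on similar sample values. The function names are the ones defined above; the chunk data, indices, and parameter values are illustrative assumptions.

import numpy as np

rng = np.random.default_rng(0)
# Placeholder audio; in practice these are consecutive float chunks from the TTS generator
chunk_a = rng.normal(0.0, 0.1, 32000).astype(np.float32)
chunk_b = rng.normal(0.0, 0.1, 32000).astype(np.float32)

# Cut chunk_a near its tail, preferring a silent zero zone
cut_a, direction = find_zero_zone(chunk_a, start_index=len(chunk_a) - 4000, search_length=4000)

# Align the start of chunk_b: direction is None when a zero zone was found,
# otherwise we hunt for a crossing with the same direction
cut_b = find_matching_index(chunk_b, center_index=2000, max_offset=1500, crossing_direction=direction)

if cut_b is not None:
    stitched = np.concatenate([chunk_a[:cut_a], chunk_b[cut_b:]])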

View File

@@ -25,7 +25,7 @@ from GPT_SoVITS.f5_tts.model.modules import (
     get_pos_embed_indices,
 )
-from module.commons import sequence_mask
+from GPT_SoVITS.module.commons import sequence_mask
 class TextEmbedding(nn.Module):

View File

@@ -13,7 +13,7 @@ from transformers import (
     HubertModel,
 )
-import utils
+import GPT_SoVITS.utils
 import torch.nn as nn
 cnhubert_base_path = None

View File

@@ -3,8 +3,8 @@ import torch
 from torch import nn
 from torch.nn import functional as F
-from module import commons
-from module.modules import LayerNorm
+from GPT_SoVITS.module import commons
+from GPT_SoVITS.module.modules import LayerNorm
 class Encoder(nn.Module):

View File

@@ -7,19 +7,19 @@ import torch
 from torch import nn
 from torch.nn import functional as F
-from module import commons
-from module import modules
-from module import attentions
-from f5_tts.model import DiT
+from GPT_SoVITS.module import commons
+from GPT_SoVITS.module import modules
+from GPT_SoVITS.module import attentions
+from GPT_SoVITS.f5_tts.model import DiT
 from torch.nn import Conv1d, ConvTranspose1d, Conv2d
 from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
-from module.commons import init_weights, get_padding
-from module.mrte_model import MRTE
-from module.quantize import ResidualVectorQuantizer
+from GPT_SoVITS.module.commons import init_weights, get_padding
+from GPT_SoVITS.module.mrte_model import MRTE
+from GPT_SoVITS.module.quantize import ResidualVectorQuantizer
 # from text import symbols
-from text import symbols as symbols_v1
-from text import symbols2 as symbols_v2
+from GPT_SoVITS.text import symbols as symbols_v1
+from GPT_SoVITS.text import symbols2 as symbols_v2
 from torch.cuda.amp import autocast
 import contextlib
 import random

View File

@@ -7,9 +7,9 @@ from torch.nn import functional as F
 from torch.nn import Conv1d
 from torch.nn.utils import weight_norm, remove_weight_norm
-from module import commons
-from module.commons import init_weights, get_padding
-from module.transforms import piecewise_rational_quadratic_transform
+from GPT_SoVITS.module import commons
+from GPT_SoVITS.module.commons import init_weights, get_padding
+from GPT_SoVITS.module.transforms import piecewise_rational_quadratic_transform
 import torch.distributions as D

View File

@@ -3,7 +3,7 @@
 import torch
 from torch import nn
 from torch.nn.utils import remove_weight_norm, weight_norm
-from module.attentions import MultiHeadAttention
+from GPT_SoVITS.module.attentions import MultiHeadAttention
 class MRTE(nn.Module):

View File

@@ -12,7 +12,7 @@ import typing as tp
 import torch
 from torch import nn
-from module.core_vq import ResidualVectorQuantization
+from GPT_SoVITS.module.core_vq import ResidualVectorQuantization
 @dataclass

View File

@@ -4,8 +4,8 @@ import os
 # else:
 # from text.symbols2 import symbols
-from text import symbols as symbols_v1
-from text import symbols2 as symbols_v2
+from GPT_SoVITS.text import symbols as symbols_v1
+from GPT_SoVITS.text import symbols2 as symbols_v2
 _symbol_to_id_v1 = {s: i for i, s in enumerate(symbols_v1.symbols)}
 _symbol_to_id_v2 = {s: i for i, s in enumerate(symbols_v2.symbols)}

View File

@@ -4,9 +4,9 @@ import re
 import cn2an
 from pypinyin import lazy_pinyin, Style
-from text.symbols import punctuation
-from text.tone_sandhi import ToneSandhi
-from text.zh_normalization.text_normlization import TextNormalizer
+from GPT_SoVITS.text.symbols import punctuation
+from GPT_SoVITS.text.tone_sandhi import ToneSandhi
+from GPT_SoVITS.text.zh_normalization.text_normlization import TextNormalizer
 normalizer = lambda x: cn2an.transform(x, "an2cn")

View File

@@ -1,4 +1,4 @@
-from text import cleaned_text_to_sequence
+from GPT_SoVITS.text import cleaned_text_to_sequence
 import os
 # if os.environ.get("version","v1")=="v1":
 # from text import chinese
@@ -7,8 +7,8 @@ import os
 # from text import chinese2 as chinese
 # from text.symbols2 import symbols
-from text import symbols as symbols_v1
-from text import symbols2 as symbols_v2
+from GPT_SoVITS.text import symbols as symbols_v1
+from GPT_SoVITS.text import symbols2 as symbols_v2
 special = [
     # ("%", "zh", "SP"),
@@ -34,7 +34,7 @@ def clean_text(text, language, version=None):
     for special_s, special_l, target_symbol in special:
         if special_s in text and language == special_l:
             return clean_special(text, language, special_s, target_symbol, version)
-    language_module = __import__("text." + language_module_map[language], fromlist=[language_module_map[language]])
+    language_module = __import__("GPT_SoVITS.text." + language_module_map[language], fromlist=[language_module_map[language]])
     if hasattr(language_module, "text_normalize"):
         norm_text = language_module.text_normalize(text)
     else:
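
As an aside, the dynamic lookup in clean_text can also be written with importlib.import_module, which is equivalent to the __import__ call above and a little easier to read. A minimal sketch, with "chinese" standing in for language_module_map[language]:

import importlib

# Equivalent to: __import__("GPT_SoVITS.text.chinese", fromlist=["chinese"])
language_module = importlib.import_module("GPT_SoVITS.text.chinese")
norm_text = language_module.text_normalize("some input text")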

View File

@@ -4,12 +4,12 @@ import re
 import wordsegment
 from g2p_en import G2p
-from text.symbols import punctuation
+from GPT_SoVITS.text.symbols import punctuation
-from text.symbols2 import symbols
+from GPT_SoVITS.text.symbols2 import symbols
 from builtins import str as unicode
-from text.en_normalization.expend import normalize
+from GPT_SoVITS.text.en_normalization.expend import normalize
 from nltk.tokenize import TweetTokenizer
 word_tokenize = TweetTokenizer().tokenize

View File

@@ -11,4 +11,4 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from text.zh_normalization.text_normlization import *
+from GPT_SoVITS.text.zh_normalization.text_normlization import *

41 binary files changed (not shown)

inference.py (new file, 102 lines)

View File

@@ -0,0 +1,102 @@
import torch
import sounddevice as sd
import time
from queue import Queue
from threading import Thread
import os


class TTS:
    def __init__(self):
        # Replace these with your checkpoints and reference audio
        # Note: using a venv may require updating the default paths provided here
        self.bert_checkpoint = "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"
        self.cnhuhbert_checkpoint = "GPT_SoVITS/pretrained_models/chinese-hubert-base"
        # self.t2s_checkpoint = "GPT_SoVITS/pretrained_models/ayaka/Ayaka-e15.ckpt"
        # self.vits_checkpoint = "GPT_SoVITS/pretrained_models/ayaka/Ayaka_e3_s1848_l32.pth"
        self.t2s_checkpoint = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt"
        self.vits_checkpoint = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth"
        self.ref_audio = "audio/ayaka/ref_audio/10_audio.wav"

        # Imported lazily so the checkpoint paths above can be edited first
        from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config

        self.config = {
            "custom": {
                "bert_base_path": self.bert_checkpoint,
                "cnhuhbert_base_path": self.cnhuhbert_checkpoint,
                "device": "cuda" if torch.cuda.is_available() else "cpu",
                "is_half": True,
                "t2s_weights_path": self.t2s_checkpoint,
                "vits_weights_path": self.vits_checkpoint,
                "version": "v3",
            }
        }
        self.tts = TTS(TTS_Config(self.config))
        self.audio_queue = Queue()
        self.generating_audio = False

    def audio_stream(self, start_time):
        # Consumer thread: plays queued chunks until a (None, None) sentinel arrives
        with sd.OutputStream(samplerate=32000, channels=1, dtype="int16") as stream:
            while True:
                sr, audio_data = self.audio_queue.get()
                if audio_data is None:
                    print(f"Stream Thread Done ({time.time() - start_time:.2f}s)")
                    break
                print((sr, audio_data))
                stream.write(audio_data)
        self.generating_audio = False

    def synthesize(self, text, start_time, generating_text=False):
        if not self.generating_audio:
            Thread(target=self.audio_stream, args=(start_time,)).start()
            self.generating_audio = True
        path = "audio/ayaka/aux_ref_audio"
        aux_ref_audios = [f"{path}/{file_name}" for file_name in os.listdir(path)]
        args = {
            "text": text,
            "text_lang": "en",
            "ref_audio_path": self.ref_audio,
            "aux_ref_audio_paths": aux_ref_audios,
            "prompt_text": "Don't worry. Now that I've experienced the event once already, I won't be easily frightened. I'll see you later. Have a lovely chat with your friend.",
            "prompt_lang": "en",
            "temperature": 0.8,
            "top_k": 50,
            "top_p": 0.9,
            "parallel_infer": True,
            "sample_steps": 32,
            "super_sampling": True,
            "speed_factor": 1,
            "fragment_interval": 0.2,
            # "stream_output": True,
            # "max_chunk_size": 20,
        }
        if text:
            print(f"Synthesis Start ({time.time() - start_time:.2f}s)")
            for audio_chunk in self.tts.run(args):
                self.audio_queue.put(audio_chunk)
        if not generating_text:
            # No more text is coming: signal the stream thread to finish
            self.audio_queue.put((None, None))
            print(f"Synthesis End ({time.time() - start_time:.2f}s)")


# Usage
tts = TTS()
"""
The timestamps are only for debugging; remove them if not needed.
Since this TTS model is meant to be paired with LLM text streaming, synthesize
takes a generating_text flag: True means more streamed text is still coming,
False means this is the final chunk and the audio stream can be closed.
"""
tts.synthesize("One day, a fierce storm rolled in, bringing heavy rain and strong winds that threatened to destroy the wheat crops.", time.time(), False)
while tts.generating_audio:
    time.sleep(0.1)
tts.synthesize("One day, a fierce storm rolled in, bringing heavy rain and strong winds that threatened to destroy the wheat crops.", time.time(), False)
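
A hedged sketch of the LLM-streaming flow the generating_text flag exists for: intermediate chunks pass True so the (None, None) sentinel is withheld and the audio stream stays open; only the final call passes False to close it. The chunk texts are placeholder stand-ins for sentences arriving from an LLM.

streamed_chunks = [
    "The farmers watched the sky with growing unease.",
    "By nightfall, the first drops had begun to fall.",
    "They could only hope the fields would hold.",
]
start = time.time()
for i, chunk in enumerate(streamed_chunks):
    is_last = i == len(streamed_chunks) - 1
    # generating_text=True keeps the stream open between chunks
    tts.synthesize(chunk, start, generating_text=not is_last)
while tts.generating_audio:
    time.sleep(0.1)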