Merge branch 'main' into main

This commit is contained in:
C3EZ 2025-03-09 10:36:02 +11:00 committed by GitHub
commit f72237d668
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 354 additions and 170 deletions

View File

@ -5,7 +5,7 @@ from typing import List, Optional
import torch
from tqdm import tqdm
from AR.models.utils import make_pad_mask
from AR.models.utils import make_pad_mask, make_pad_mask_left
from AR.models.utils import (
topk_sampling,
sample,
@ -162,7 +162,7 @@ class T2SBlock:
)
return x, k_cache, v_cache
def decode_next_token(self, x:torch.Tensor, k_cache:torch.Tensor, v_cache:torch.Tensor, attn_mask:Optional[torch.Tensor]=None, torch_sdpa:bool=True):
def decode_next_token(self, x:torch.Tensor, k_cache:torch.Tensor, v_cache:torch.Tensor, attn_mask:torch.Tensor=None, torch_sdpa:bool=True):
q, k, v = F.linear(x, self.qkv_w, self.qkv_b).chunk(3, dim=-1)
k_cache = torch.cat([k_cache, k], dim=1)
@ -178,7 +178,7 @@ class T2SBlock:
if torch_sdpa:
attn = F.scaled_dot_product_attention(q, k, v)
attn = F.scaled_dot_product_attention(q, k, v, (~attn_mask) if attn_mask is not None else None)
else:
attn = scaled_dot_product_attention(q, k, v, attn_mask)
@ -223,7 +223,7 @@ class T2STransformer:
self, x:torch.Tensor,
k_cache: List[torch.Tensor],
v_cache: List[torch.Tensor],
attn_mask : Optional[torch.Tensor]=None,
attn_mask : torch.Tensor=None,
torch_sdpa:bool=True
):
for i in range(self.num_blocks):
@ -573,71 +573,88 @@ class Text2SemanticDecoder(nn.Module):
x_item = self.ar_text_embedding(x_item.unsqueeze(0))
x_item = x_item + self.bert_proj(bert_item.transpose(0, 1).unsqueeze(0))
x_item = self.ar_text_position(x_item).squeeze(0)
x_item = F.pad(x_item,(0,0,0,max_len-x_item.shape[0]),value=0) if x_item.shape[0]<max_len else x_item
# x_item = F.pad(x_item,(0,0,0,max_len-x_item.shape[0]),value=0) if x_item.shape[0]<max_len else x_item ### padding right
x_item = F.pad(x_item,(0,0,max_len-x_item.shape[0],0),value=0) if x_item.shape[0]<max_len else x_item ### padding left
x_list.append(x_item)
x = torch.stack(x_list, dim=0)
x:torch.Tensor = torch.stack(x_list, dim=0)
# AR Decoder
y = prompts
x_len = x.shape[1]
x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
stop = False
k_cache = None
v_cache = None
################### first step ##########################
if y is not None:
y_emb = self.ar_audio_embedding(y)
y_len = y_emb.shape[1]
prefix_len = y.shape[1]
y_lens = torch.LongTensor([y_emb.shape[1]]*y_emb.shape[0]).to(x.device)
y_pos = self.ar_audio_position(y_emb)
xy_pos = torch.concat([x, y_pos], dim=1)
ref_free = False
else:
y_emb = None
y_len = 0
prefix_len = 0
y_lens = torch.LongTensor([y_len]*x.shape[0]).to(x.device)
y_pos = None
xy_pos = x
y = torch.zeros(x.shape[0], 0, dtype=torch.int, device=x.device)
ref_free = True
assert y is not None, "Error: Prompt free is not supported batch_infer!"
ref_free = False
y_emb = self.ar_audio_embedding(y)
y_len = y_emb.shape[1]
prefix_len = y.shape[1]
y_lens = torch.LongTensor([y_emb.shape[1]]*y_emb.shape[0]).to(x.device)
y_pos = self.ar_audio_position(y_emb)
xy_pos = torch.concat([x, y_pos], dim=1)
##### create mask #####
bsz = x.shape[0]
src_len = x_len + y_len
y_paddind_mask = make_pad_mask(y_lens, y_len)
x_paddind_mask = make_pad_mask(x_lens, max_len)
y_paddind_mask = make_pad_mask_left(y_lens, y_len)
x_paddind_mask = make_pad_mask_left(x_lens, max_len)
# (bsz, x_len + y_len)
xy_padding_mask = torch.concat([x_paddind_mask, y_paddind_mask], dim=1)
padding_mask = torch.concat([x_paddind_mask, y_paddind_mask], dim=1)
x_mask = F.pad(
torch.zeros(x_len, x_len, dtype=torch.bool, device=x.device),
(0, y_len),
value=True,
)
x_mask = F.pad(
x_attn_mask,
(0, y_len), ###xx的纯0扩展到xx纯0+xy纯1(x,x+y)
value=True,
)
y_mask = F.pad( ###yy的右上1扩展到左边xy的0,(y,x+y)
torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
torch.triu(torch.ones(y_len, y_len, dtype=torch.bool, device=x.device), diagonal=1),
(x_len, 0),
value=False,
)
xy_mask = torch.concat([x_mask, y_mask], dim=0).view(1 , src_len, src_len).repeat(bsz, 1, 1).to(x.device)
_xy_padding_mask = xy_padding_mask.view(bsz, 1, src_len).repeat(1, src_len, 1)
for i in range(bsz):
l = x_lens[i]
_xy_padding_mask[i,l:max_len,:]=True
xy_attn_mask = xy_mask.logical_or(_xy_padding_mask)
xy_attn_mask = xy_attn_mask.unsqueeze(1).expand(-1, self.num_head, -1, -1)
xy_attn_mask = xy_attn_mask.bool()
xy_padding_mask = xy_padding_mask.view(bsz, src_len, 1)
causal_mask = torch.concat([x_mask, y_mask], dim=0).view(1 , src_len, src_len).repeat(bsz, 1, 1).to(x.device)
# padding_mask = padding_mask.unsqueeze(1) * padding_mask.unsqueeze(2) ### [b, x+y, x+y]
### 上面是错误的会导致padding的token被"看见"
# 正确的padding_mask应该是
# | pad_len | x_len | y_len |
# [[PAD, PAD, PAD, 1, 2, 3, 4, 5, 6],
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6],
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6], 前3行按理说也应该被mask掉但是为了防止计算attention时不出现nan还是保留了不影响结果
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6],
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6],
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6],
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6],
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6],
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6]]
padding_mask = padding_mask.view(bsz, 1, src_len).repeat(1, src_len, 1)
attn_mask:torch.Tensor = causal_mask.logical_or(padding_mask)
attn_mask = attn_mask.unsqueeze(1).expand(-1, self.num_head, -1, -1).bool()
# 正确的attn_mask应该是这样的
# | pad_len | x_len | y_len |
# [[PAD, PAD, PAD, 1, 2, 3, EOS, EOS, EOS],
# [PAD, PAD, PAD, 1, 2, 3, EOS, EOS, EOS],
# [PAD, PAD, PAD, 1, 2, 3, EOS, EOS, EOS], 前3行按理说也应该被mask掉但是为了防止计算attention时不出现nan还是保留了不影响结果
# [PAD, PAD, PAD, 1, 2, 3, EOS, EOS, EOS],
# [PAD, PAD, PAD, 1, 2, 3, EOS, EOS, EOS],
# [PAD, PAD, PAD, 1, 2, 3, EOS, EOS, EOS],
# [PAD, PAD, PAD, 1, 2, 3, 4, EOS, EOS],
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, EOS],
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6]]
###### decode #####
y_list = [None]*y.shape[0]
@ -645,18 +662,18 @@ class Text2SemanticDecoder(nn.Module):
idx_list = [None]*y.shape[0]
for idx in tqdm(range(1500)):
if idx == 0:
xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, xy_attn_mask, xy_padding_mask, False)
xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, attn_mask, None)
else:
xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache, xy_attn_mask, False)
xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache, attn_mask)
logits = self.ar_predict_layer(
xy_dec[:, -1]
)
if idx == 0:
xy_attn_mask = F.pad(xy_attn_mask[:,:,-1].unsqueeze(-2),(0,1),value=False)
attn_mask = F.pad(attn_mask[:,:,-1].unsqueeze(-2),(0,1),value=False)
logits = logits[:, :-1]
else:
xy_attn_mask = F.pad(xy_attn_mask,(0,1),value=False)
attn_mask = F.pad(attn_mask,(0,1),value=False)
samples = sample(
logits, y, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature
@ -686,7 +703,7 @@ class Text2SemanticDecoder(nn.Module):
if reserved_idx_of_batch_for_y is not None:
# index = torch.LongTensor(batch_idx_map).to(y.device)
y = torch.index_select(y, dim=0, index=reserved_idx_of_batch_for_y)
xy_attn_mask = torch.index_select(xy_attn_mask, dim=0, index=reserved_idx_of_batch_for_y)
attn_mask = torch.index_select(attn_mask, dim=0, index=reserved_idx_of_batch_for_y)
if k_cache is not None :
for i in range(len(k_cache)):
k_cache[i] = torch.index_select(k_cache[i], dim=0, index=reserved_idx_of_batch_for_y)

View File

@ -39,6 +39,39 @@ def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
return expaned_lengths >= lengths.unsqueeze(-1)
def make_pad_mask_left(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
"""
Args:
lengths:
A 1-D tensor containing sentence lengths.
max_len:
The length of masks.
Returns:
Return a 2-D bool tensor, where masked positions
are filled with `True` and non-masked positions are
filled with `False`.
#>>> lengths = torch.tensor([1, 3, 2, 5])
#>>> make_pad_mask(lengths)
tensor(
[
[True, True, False],
[True, False, False],
[True, True, False],
...
]
)
"""
assert lengths.ndim == 1, lengths.ndim
max_len = max(max_len, lengths.max())
n = lengths.size(0)
seq_range = torch.arange(0, max_len, device=lengths.device)
expaned_lengths = seq_range.unsqueeze(0).repeat(n, 1)
expaned_lengths -= (max_len-lengths).unsqueeze(-1)
return expaned_lengths<0
# https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py
def top_k_top_p_filtering(
logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1

View File

@ -145,7 +145,15 @@ class TTS_Config:
self.device = self.configs.get("device", torch.device("cpu"))
if "cuda" in str(self.device) and not torch.cuda.is_available():
print(f"Warning: CUDA is not available, set device to CPU.")
self.device = torch.device("cpu")
self.is_half = self.configs.get("is_half", False)
# if str(self.device) == "cpu" and self.is_half:
# print(f"Warning: Half precision is not supported on CPU, set is_half to False.")
# self.is_half = False
self.version = version
self.t2s_weights_path = self.configs.get("t2s_weights_path", None)
self.vits_weights_path = self.configs.get("vits_weights_path", None)

View File

@ -691,7 +691,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
wav_gen = bigvgan_model(cmf_res)
audio=wav_gen[0][0]#.cpu().detach().numpy()
max_audio=torch.abs(audio).max()#简单防止16bit爆音
if max_audio>1:audio/=max_audio
if max_audio>1:audio=audio/max_audio
audio_opt.append(audio)
audio_opt.append(zero_wav_torch)#zero_wav
t4 = ttime()

View File

@ -1162,6 +1162,7 @@ class SynthesizerTrnV3(nn.Module):
use_sdp=True,
semantic_frame_rate=None,
freeze_quantizer=None,
version="v3",
**kwargs):
super().__init__()
@ -1182,6 +1183,7 @@ class SynthesizerTrnV3(nn.Module):
self.segment_size = segment_size
self.n_speakers = n_speakers
self.gin_channels = gin_channels
self.version = version
self.model_dim=512
self.use_sdp = use_sdp

View File

@ -8,66 +8,7 @@ jieba.setLogLevel(logging.CRITICAL)
# 更改fast_langdetect大模型位置
from pathlib import Path
import fast_langdetect
fast_langdetect.ft_detect.infer.CACHE_DIRECTORY = Path(__file__).parent.parent.parent / "pretrained_models" / "fast_langdetect"
# 防止win下无法读取模型
import os
from typing import Optional
def load_fasttext_model(
model_path: Path,
download_url: Optional[str] = None,
proxy: Optional[str] = None,
):
"""
Load a FastText model, downloading it if necessary.
:param model_path: Path to the FastText model file
:param download_url: URL to download the model from
:param proxy: Proxy URL for downloading the model
:return: FastText model
:raises DetectError: If model loading fails
"""
if all([
fast_langdetect.ft_detect.infer.VERIFY_FASTTEXT_LARGE_MODEL,
model_path.exists(),
model_path.name == fast_langdetect.ft_detect.infer.FASTTEXT_LARGE_MODEL_NAME,
]):
if not fast_langdetect.ft_detect.infer.verify_md5(model_path, fast_langdetect.ft_detect.infer.VERIFY_FASTTEXT_LARGE_MODEL):
fast_langdetect.ft_detect.infer.logger.warning(
f"fast-langdetect: MD5 hash verification failed for {model_path}, "
f"please check the integrity of the downloaded file from {fast_langdetect.ft_detect.infer.FASTTEXT_LARGE_MODEL_URL}. "
"\n This may seriously reduce the prediction accuracy. "
"If you want to ignore this, please set `fast_langdetect.ft_detect.infer.VERIFY_FASTTEXT_LARGE_MODEL = None` "
)
if not model_path.exists():
if download_url:
fast_langdetect.ft_detect.infer.download_model(download_url, model_path, proxy)
if not model_path.exists():
raise fast_langdetect.ft_detect.infer.DetectError(f"FastText model file not found at {model_path}")
try:
# Load FastText model
if (re.match(r'^[A-Za-z0-9_/\\:.]*$', str(model_path))):
model = fast_langdetect.ft_detect.infer.fasttext.load_model(str(model_path))
else:
python_dir = os.getcwd()
if (str(model_path)[:len(python_dir)].upper() == python_dir.upper()):
model = fast_langdetect.ft_detect.infer.fasttext.load_model(os.path.relpath(model_path, python_dir))
else:
import tempfile
import shutil
with tempfile.NamedTemporaryFile(delete=False) as tmpfile:
shutil.copyfile(model_path, tmpfile.name)
model = fast_langdetect.ft_detect.infer.fasttext.load_model(tmpfile.name)
os.unlink(tmpfile.name)
return model
except Exception as e:
fast_langdetect.ft_detect.infer.logger.warning(f"fast-langdetect:Failed to load FastText model from {model_path}: {e}")
raise fast_langdetect.ft_detect.infer.DetectError(f"Failed to load FastText model: {e}")
if os.name == 'nt':
fast_langdetect.ft_detect.infer.load_fasttext_model = load_fasttext_model
fast_langdetect.infer._default_detector = fast_langdetect.infer.LangDetector(fast_langdetect.infer.LangDetectConfig(cache_dir=Path(__file__).parent.parent.parent / "pretrained_models" / "fast_langdetect"))
from split_lang import LangSplitter

View File

@ -17,6 +17,8 @@ pinyin_to_symbol_map = {
for line in open(os.path.join(current_file_path, "opencpop-strict.txt")).readlines()
}
import jieba_fast, logging
jieba_fast.setLogLevel(logging.CRITICAL)
import jieba_fast.posseg as psg

View File

@ -18,13 +18,15 @@ pinyin_to_symbol_map = {
for line in open(os.path.join(current_file_path, "opencpop-strict.txt")).readlines()
}
import jieba_fast, logging
jieba_fast.setLogLevel(logging.CRITICAL)
import jieba_fast.posseg as psg
# is_g2pw_str = os.environ.get("is_g2pw", "True")##默认开启
# is_g2pw = False#True if is_g2pw_str.lower() == 'true' else False
is_g2pw = True#True if is_g2pw_str.lower() == 'true' else False
if is_g2pw:
print("当前使用g2pw进行拼音推理")
# print("当前使用g2pw进行拼音推理")
from text.g2pw import G2PWPinyin, correct_pronunciation
parent_directory = os.path.dirname(current_file_path)
g2pw = G2PWPinyin(model_dir="GPT_SoVITS/text/G2PWModel",model_source=os.environ.get("bert_path","GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"),v_to_u=False, neutral_tone_with_five=True)

View File

@ -10,7 +10,7 @@ try:
if os.name == 'nt':
python_dir = os.getcwd()
OPEN_JTALK_DICT_DIR = pyopenjtalk.OPEN_JTALK_DICT_DIR.decode("utf-8")
if not (re.match(r'^[A-Za-z0-9_/\\:.]*$', OPEN_JTALK_DICT_DIR)):
if not (re.match(r'^[A-Za-z0-9_/\\:.\-]*$', OPEN_JTALK_DICT_DIR)):
if (OPEN_JTALK_DICT_DIR[:len(python_dir)].upper() == python_dir.upper()):
OPEN_JTALK_DICT_DIR = os.path.join(os.path.relpath(OPEN_JTALK_DICT_DIR,python_dir))
else:
@ -25,7 +25,7 @@ try:
OPEN_JTALK_DICT_DIR = os.path.join("TEMP", "ja", "open_jtalk_dic")
pyopenjtalk.OPEN_JTALK_DICT_DIR = OPEN_JTALK_DICT_DIR.encode("utf-8")
if not (re.match(r'^[A-Za-z0-9_/\\:.]*$', current_file_path)):
if not (re.match(r'^[A-Za-z0-9_/\\:.\-]*$', current_file_path)):
if (current_file_path[:len(python_dir)].upper() == python_dir.upper()):
current_file_path = os.path.join(os.path.relpath(current_file_path,python_dir))
else:

View File

@ -19,13 +19,13 @@ if os.name == 'nt':
print(f'you have to install eunjeon. install it...')
else:
installpath = spam_spec.submodule_search_locations[0]
if not (re.match(r'^[A-Za-z0-9_/\\:.]*$', installpath)):
if not (re.match(r'^[A-Za-z0-9_/\\:.\-]*$', installpath)):
import sys
from eunjeon import Mecab as _Mecab
class Mecab(_Mecab):
def get_dicpath(installpath):
if not (re.match(r'^[A-Za-z0-9_/\\:.]*$', installpath)):
if not (re.match(r'^[A-Za-z0-9_/\\:.\-]*$', installpath)):
import shutil
python_dir = os.getcwd()
if (installpath[:len(python_dir)].upper() == python_dir.upper()):

273
api.py
View File

@ -150,9 +150,9 @@ sys.path.append(now_dir)
sys.path.append("%s/GPT_SoVITS" % (now_dir))
import signal
import LangSegment
from text.LangSegmenter import LangSegmenter
from time import time as ttime
import torch
import torch, torchaudio
import librosa
import soundfile as sf
from fastapi import FastAPI, Request, Query, HTTPException
@ -162,7 +162,8 @@ from transformers import AutoModelForMaskedLM, AutoTokenizer
import numpy as np
from feature_extractor import cnhubert
from io import BytesIO
from module.models import SynthesizerTrn
from module.models import SynthesizerTrn, SynthesizerTrnV3
from peft import LoraConfig, PeftModel, get_peft_model
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
from text import cleaned_text_to_sequence
from text.cleaner import clean_text
@ -197,6 +198,61 @@ def is_full(*items): # 任意一项为空返回False
return True
def init_bigvgan():
global bigvgan_model
from BigVGAN import bigvgan
bigvgan_model = bigvgan.BigVGAN.from_pretrained("%s/GPT_SoVITS/pretrained_models/models--nvidia--bigvgan_v2_24khz_100band_256x" % (now_dir,), use_cuda_kernel=False) # if True, RuntimeError: Ninja is required to load C++ extensions
# remove weight norm in the model and set to eval mode
bigvgan_model.remove_weight_norm()
bigvgan_model = bigvgan_model.eval()
if is_half == True:
bigvgan_model = bigvgan_model.half().to(device)
else:
bigvgan_model = bigvgan_model.to(device)
resample_transform_dict={}
def resample(audio_tensor, sr0):
global resample_transform_dict
if sr0 not in resample_transform_dict:
resample_transform_dict[sr0] = torchaudio.transforms.Resample(
sr0, 24000
).to(device)
return resample_transform_dict[sr0](audio_tensor)
from module.mel_processing import spectrogram_torch,mel_spectrogram_torch
spec_min = -12
spec_max = 2
def norm_spec(x):
return (x - spec_min) / (spec_max - spec_min) * 2 - 1
def denorm_spec(x):
return (x + 1) / 2 * (spec_max - spec_min) + spec_min
mel_fn=lambda x: mel_spectrogram_torch(x, **{
"n_fft": 1024,
"win_size": 1024,
"hop_size": 256,
"num_mels": 100,
"sampling_rate": 24000,
"fmin": 0,
"fmax": None,
"center": False
})
sr_model=None
def audio_sr(audio,sr):
global sr_model
if sr_model==None:
from tools.audio_sr import AP_BWE
try:
sr_model=AP_BWE(device,DictToAttrRecursive)
except FileNotFoundError:
logger.info("你没有下载超分模型的参数,因此不进行超分。如想超分请先参照教程把文件下载")
return audio.cpu().detach().numpy(),sr
return sr_model(audio,sr)
class Speaker:
def __init__(self, name, gpt, sovits, phones = None, bert = None, prompt = None):
self.name = name
@ -214,31 +270,72 @@ class Sovits:
self.vq_model = vq_model
self.hps = hps
from process_ckpt import get_sovits_version_from_path_fast,load_sovits_new
def get_sovits_weights(sovits_path):
dict_s2 = torch.load(sovits_path, map_location="cpu")
path_sovits_v3="GPT_SoVITS/pretrained_models/s2Gv3.pth"
is_exist_s2gv3=os.path.exists(path_sovits_v3)
version, model_version, if_lora_v3=get_sovits_version_from_path_fast(sovits_path)
if if_lora_v3==True and is_exist_s2gv3==False:
logger.info("SoVITS V3 底模缺失,无法加载相应 LoRA 权重")
dict_s2 = load_sovits_new(sovits_path)
hps = dict_s2["config"]
hps = DictToAttrRecursive(hps)
hps.model.semantic_frame_rate = "25hz"
if dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322:
if 'enc_p.text_embedding.weight' not in dict_s2['weight']:
hps.model.version = "v2"#v3model,v2sybomls
elif dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322:
hps.model.version = "v1"
else:
hps.model.version = "v2"
logger.info(f"模型版本: {hps.model.version}")
if model_version == "v3":
hps.model.version = "v3"
model_params_dict = vars(hps.model)
vq_model = SynthesizerTrn(
hps.data.filter_length // 2 + 1,
hps.train.segment_size // hps.data.hop_length,
n_speakers=hps.data.n_speakers,
**model_params_dict
)
if model_version!="v3":
vq_model = SynthesizerTrn(
hps.data.filter_length // 2 + 1,
hps.train.segment_size // hps.data.hop_length,
n_speakers=hps.data.n_speakers,
**model_params_dict
)
else:
vq_model = SynthesizerTrnV3(
hps.data.filter_length // 2 + 1,
hps.train.segment_size // hps.data.hop_length,
n_speakers=hps.data.n_speakers,
**model_params_dict
)
init_bigvgan()
model_version=hps.model.version
logger.info(f"模型版本: {model_version}")
if ("pretrained" not in sovits_path):
del vq_model.enc_q
try:
del vq_model.enc_q
except:pass
if is_half == True:
vq_model = vq_model.half().to(device)
else:
vq_model = vq_model.to(device)
vq_model.eval()
vq_model.load_state_dict(dict_s2["weight"], strict=False)
if if_lora_v3 == False:
vq_model.load_state_dict(dict_s2["weight"], strict=False)
else:
vq_model.load_state_dict(load_sovits_new(path_sovits_v3)["weight"], strict=False)
lora_rank=dict_s2["lora_rank"]
lora_config = LoraConfig(
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
r=lora_rank,
lora_alpha=lora_rank,
init_lora_weights=True,
)
vq_model.cfm = get_peft_model(vq_model.cfm, lora_config)
vq_model.load_state_dict(dict_s2["weight"], strict=False)
vq_model.cfm = vq_model.cfm.merge_and_unload()
# torch.save(vq_model.state_dict(),"merge_win.pth")
vq_model.eval()
sovits = Sovits(vq_model, hps)
return sovits
@ -260,8 +357,8 @@ def get_gpt_weights(gpt_path):
t2s_model = t2s_model.half()
t2s_model = t2s_model.to(device)
t2s_model.eval()
total = sum([param.nelement() for param in t2s_model.parameters()])
logger.info("Number of parameter: %.2fM" % (total / 1e6))
# total = sum([param.nelement() for param in t2s_model.parameters()])
# logger.info("Number of parameter: %.2fM" % (total / 1e6))
gpt = Gpt(max_sec, t2s_model)
return gpt
@ -295,6 +392,7 @@ def get_bert_feature(text, word2ph):
def clean_text_inf(text, language, version):
language = language.replace("all_","")
phones, word2ph, norm_text = clean_text(text, language, version)
phones = cleaned_text_to_sequence(phones, version)
return phones, word2ph, norm_text
@ -315,16 +413,10 @@ def get_bert_inf(phones, word2ph, norm_text, language):
from text import chinese
def get_phones_and_bert(text,language,version,final=False):
if language in {"en", "all_zh", "all_ja", "all_ko", "all_yue"}:
language = language.replace("all_","")
if language == "en":
LangSegment.setfilters(["en"])
formattext = " ".join(tmp["text"] for tmp in LangSegment.getTexts(text))
else:
# 因无法区别中日韩文汉字,以用户输入为准
formattext = text
formattext = text
while " " in formattext:
formattext = formattext.replace(" ", " ")
if language == "zh":
if language == "all_zh":
if re.search(r'[A-Za-z]', formattext):
formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
formattext = chinese.mix_text_normalize(formattext)
@ -332,7 +424,7 @@ def get_phones_and_bert(text,language,version,final=False):
else:
phones, word2ph, norm_text = clean_text_inf(formattext, language, version)
bert = get_bert_feature(norm_text, word2ph).to(device)
elif language == "yue" and re.search(r'[A-Za-z]', formattext):
elif language == "all_yue" and re.search(r'[A-Za-z]', formattext):
formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
formattext = chinese.mix_text_normalize(formattext)
return get_phones_and_bert(formattext,"yue",version)
@ -345,19 +437,18 @@ def get_phones_and_bert(text,language,version,final=False):
elif language in {"zh", "ja", "ko", "yue", "auto", "auto_yue"}:
textlist=[]
langlist=[]
LangSegment.setfilters(["zh","ja","en","ko"])
if language == "auto":
for tmp in LangSegment.getTexts(text):
for tmp in LangSegmenter.getTexts(text):
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
elif language == "auto_yue":
for tmp in LangSegment.getTexts(text):
for tmp in LangSegmenter.getTexts(text):
if tmp["lang"] == "zh":
tmp["lang"] = "yue"
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
else:
for tmp in LangSegment.getTexts(text):
for tmp in LangSegmenter.getTexts(text):
if tmp["lang"] == "en":
langlist.append(tmp["lang"])
else:
@ -556,10 +647,11 @@ def only_punc(text):
splits = {"", "", "", "", ",", ".", "?", "!", "~", ":", "", "", "", }
def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, top_k= 15, top_p = 0.6, temperature = 0.6, speed = 1, inp_refs = None, spk = "default"):
def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, top_k= 15, top_p = 0.6, temperature = 0.6, speed = 1, inp_refs = None, sample_steps = 32, if_sr = False, spk = "default"):
infer_sovits = speaker_list[spk].sovits
vq_model = infer_sovits.vq_model
hps = infer_sovits.hps
version = vq_model.version
infer_gpt = speaker_list[spk].gpt
t2s_model = infer_gpt.t2s_model
@ -587,20 +679,22 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
prompt_semantic = codes[0, 0]
prompt = prompt_semantic.unsqueeze(0).to(device)
refers=[]
if(inp_refs):
for path in inp_refs:
try:
refer = get_spepc(hps, path).to(dtype).to(device)
refers.append(refer)
except Exception as e:
logger.error(e)
if(len(refers)==0):
refers = [get_spepc(hps, ref_wav_path).to(dtype).to(device)]
if version != "v3":
refers=[]
if(inp_refs):
for path in inp_refs:
try:
refer = get_spepc(hps, path).to(dtype).to(device)
refers.append(refer)
except Exception as e:
logger.error(e)
if(len(refers)==0):
refers = [get_spepc(hps, ref_wav_path).to(dtype).to(device)]
else:
refer = get_spepc(hps, ref_wav_path).to(device).to(dtype)
t1 = ttime()
version = vq_model.version
os.environ['version'] = version
# os.environ['version'] = version
prompt_language = dict_language[prompt_language.lower()]
text_language = dict_language[text_language.lower()]
phones1, bert1, norm_text1 = get_phones_and_bert(prompt_text, prompt_language, version)
@ -634,20 +728,82 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
early_stop_num=hz * max_sec)
pred_semantic = pred_semantic[:, -idx:].unsqueeze(0)
t3 = ttime()
audio = \
vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0),
refers,speed=speed).detach().cpu().numpy()[
0, 0] ###试试重建不带上prompt部分
if version != "v3":
audio = \
vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0),
refers,speed=speed).detach().cpu().numpy()[
0, 0] ###试试重建不带上prompt部分
else:
phoneme_ids0=torch.LongTensor(phones1).to(device).unsqueeze(0)
phoneme_ids1=torch.LongTensor(phones2).to(device).unsqueeze(0)
# print(11111111, phoneme_ids0, phoneme_ids1)
fea_ref,ge = vq_model.decode_encp(prompt.unsqueeze(0), phoneme_ids0, refer)
ref_audio, sr = torchaudio.load(ref_wav_path)
ref_audio=ref_audio.to(device).float()
if (ref_audio.shape[0] == 2):
ref_audio = ref_audio.mean(0).unsqueeze(0)
if sr!=24000:
ref_audio=resample(ref_audio,sr)
# print("ref_audio",ref_audio.abs().mean())
mel2 = mel_fn(ref_audio)
mel2 = norm_spec(mel2)
T_min = min(mel2.shape[2], fea_ref.shape[2])
mel2 = mel2[:, :, :T_min]
fea_ref = fea_ref[:, :, :T_min]
if (T_min > 468):
mel2 = mel2[:, :, -468:]
fea_ref = fea_ref[:, :, -468:]
T_min = 468
chunk_len = 934 - T_min
# print("fea_ref",fea_ref,fea_ref.shape)
# print("mel2",mel2)
mel2=mel2.to(dtype)
fea_todo, ge = vq_model.decode_encp(pred_semantic, phoneme_ids1, refer, ge,speed)
# print("fea_todo",fea_todo)
# print("ge",ge.abs().mean())
cfm_resss = []
idx = 0
while (1):
fea_todo_chunk = fea_todo[:, :, idx:idx + chunk_len]
if (fea_todo_chunk.shape[-1] == 0): break
idx += chunk_len
fea = torch.cat([fea_ref, fea_todo_chunk], 2).transpose(2, 1)
# set_seed(123)
cfm_res = vq_model.cfm.inference(fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps, inference_cfg_rate=0)
cfm_res = cfm_res[:, :, mel2.shape[2]:]
mel2 = cfm_res[:, :, -T_min:]
# print("fea", fea)
# print("mel2in", mel2)
fea_ref = fea_todo_chunk[:, :, -T_min:]
cfm_resss.append(cfm_res)
cmf_res = torch.cat(cfm_resss, 2)
cmf_res = denorm_spec(cmf_res)
if bigvgan_model==None:init_bigvgan()
with torch.inference_mode():
wav_gen = bigvgan_model(cmf_res)
audio=wav_gen[0][0].cpu().detach().numpy()
max_audio=np.abs(audio).max()
if max_audio>1:
audio/=max_audio
audio_opt.append(audio)
audio_opt.append(zero_wav)
audio_opt = np.concatenate(audio_opt, 0)
t4 = ttime()
sr = hps.data.sampling_rate if version != "v3" else 24000
if if_sr and sr == 24000:
audio_opt = torch.from_numpy(audio_opt).float().to(device)
audio_opt,sr=audio_sr(audio_opt.unsqueeze(0),sr)
max_audio=np.abs(audio_opt).max()
if max_audio > 1: audio_opt /= max_audio
sr = 48000
if is_int32:
audio_bytes = pack_audio(audio_bytes,(np.concatenate(audio_opt, 0) * 2147483647).astype(np.int32),hps.data.sampling_rate)
audio_bytes = pack_audio(audio_bytes,(audio_opt * 2147483647).astype(np.int32),sr)
else:
audio_bytes = pack_audio(audio_bytes,(np.concatenate(audio_opt, 0) * 32768).astype(np.int16),hps.data.sampling_rate)
audio_bytes = pack_audio(audio_bytes,(audio_opt * 32768).astype(np.int16),sr)
# logger.info("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
if stream_mode == "normal":
audio_bytes, audio_chunk = read_clean_buffer(audio_bytes)
@ -655,7 +811,9 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
if not stream_mode == "normal":
if media_type == "wav":
audio_bytes = pack_wav(audio_bytes,hps.data.sampling_rate)
sr = 48000 if if_sr else 24000
sr = hps.data.sampling_rate if version != "v3" else sr
audio_bytes = pack_wav(audio_bytes,sr)
yield audio_bytes.getvalue()
@ -688,7 +846,7 @@ def handle_change(path, text, language):
return JSONResponse({"code": 0, "message": "Success"}, status_code=200)
def handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cut_punc, top_k, top_p, temperature, speed, inp_refs):
def handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cut_punc, top_k, top_p, temperature, speed, inp_refs, sample_steps, if_sr):
if (
refer_wav_path == "" or refer_wav_path is None
or prompt_text == "" or prompt_text is None
@ -702,12 +860,15 @@ def handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cu
if not default_refer.is_ready():
return JSONResponse({"code": 400, "message": "未指定参考音频且接口无预设"}, status_code=400)
if not sample_steps in [4,8,16,32]:
sample_steps = 32
if cut_punc == None:
text = cut_text(text,default_cut_punc)
else:
text = cut_text(text,cut_punc)
return StreamingResponse(get_tts_wav(refer_wav_path, prompt_text, prompt_language, text, text_language, top_k, top_p, temperature, speed, inp_refs), media_type="audio/"+media_type)
return StreamingResponse(get_tts_wav(refer_wav_path, prompt_text, prompt_language, text, text_language, top_k, top_p, temperature, speed, inp_refs, sample_steps, if_sr), media_type="audio/"+media_type)
@ -915,7 +1076,9 @@ async def tts_endpoint(request: Request):
json_post_raw.get("top_p", 1.0),
json_post_raw.get("temperature", 1.0),
json_post_raw.get("speed", 1.0),
json_post_raw.get("inp_refs", [])
json_post_raw.get("inp_refs", []),
json_post_raw.get("sample_steps", 32),
json_post_raw.get("if_sr", False)
)
@ -931,9 +1094,11 @@ async def tts_endpoint(
top_p: float = 1.0,
temperature: float = 1.0,
speed: float = 1.0,
inp_refs: list = Query(default=[])
inp_refs: list = Query(default=[]),
sample_steps: int = 32,
if_sr: bool = False
):
return handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cut_punc, top_k, top_p, temperature, speed, inp_refs)
return handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cut_punc, top_k, top_p, temperature, speed, inp_refs, sample_steps, if_sr)
if __name__ == "__main__":

View File

@ -1,4 +1,7 @@
#!/bin/bash
# 安装构建工具
# Install build tools
echo "Installing GCC..."
conda install -c conda-forge gcc=14
@ -8,6 +11,12 @@ conda install -c conda-forge gxx
echo "Installing ffmpeg and cmake..."
conda install ffmpeg cmake
# 设置编译环境
# Set up build environment
export CMAKE_MAKE_PROGRAM="$CONDA_PREFIX/bin/cmake"
export CC="$CONDA_PREFIX/bin/gcc"
export CXX="$CONDA_PREFIX/bin/g++"
echo "Checking for CUDA installation..."
if command -v nvidia-smi &> /dev/null; then
USE_CUDA=true
@ -49,6 +58,10 @@ fi
echo "Installing Python dependencies from requirements.txt..."
# 刷新环境
# Refresh environment
hash -r
pip install -r requirements.txt
if [ "$USE_ROCM" = true ] && [ "$IS_WSL" = true ] ; then
@ -60,3 +73,4 @@ if [ "$USE_ROCM" = true ] && [ "$IS_WSL" = true ] ; then
fi
echo "Installation completed successfully!"

View File

@ -25,7 +25,7 @@ psutil
jieba_fast
jieba
split-lang
fast_langdetect
fast_langdetect>=0.3.0
Faster_Whisper
wordsegment
rotary_embedding_torch