diff --git a/GPT_SoVITS/AR/models/t2s_model.py b/GPT_SoVITS/AR/models/t2s_model.py index f8f6582..8a32d0d 100644 --- a/GPT_SoVITS/AR/models/t2s_model.py +++ b/GPT_SoVITS/AR/models/t2s_model.py @@ -5,7 +5,7 @@ from typing import List, Optional import torch from tqdm import tqdm -from AR.models.utils import make_pad_mask +from AR.models.utils import make_pad_mask, make_pad_mask_left from AR.models.utils import ( topk_sampling, sample, @@ -162,7 +162,7 @@ class T2SBlock: ) return x, k_cache, v_cache - def decode_next_token(self, x:torch.Tensor, k_cache:torch.Tensor, v_cache:torch.Tensor, attn_mask:Optional[torch.Tensor]=None, torch_sdpa:bool=True): + def decode_next_token(self, x:torch.Tensor, k_cache:torch.Tensor, v_cache:torch.Tensor, attn_mask:torch.Tensor=None, torch_sdpa:bool=True): q, k, v = F.linear(x, self.qkv_w, self.qkv_b).chunk(3, dim=-1) k_cache = torch.cat([k_cache, k], dim=1) @@ -178,7 +178,7 @@ class T2SBlock: if torch_sdpa: - attn = F.scaled_dot_product_attention(q, k, v) + attn = F.scaled_dot_product_attention(q, k, v, (~attn_mask) if attn_mask is not None else None) else: attn = scaled_dot_product_attention(q, k, v, attn_mask) @@ -223,7 +223,7 @@ class T2STransformer: self, x:torch.Tensor, k_cache: List[torch.Tensor], v_cache: List[torch.Tensor], - attn_mask : Optional[torch.Tensor]=None, + attn_mask : torch.Tensor=None, torch_sdpa:bool=True ): for i in range(self.num_blocks): @@ -573,71 +573,88 @@ class Text2SemanticDecoder(nn.Module): x_item = self.ar_text_embedding(x_item.unsqueeze(0)) x_item = x_item + self.bert_proj(bert_item.transpose(0, 1).unsqueeze(0)) x_item = self.ar_text_position(x_item).squeeze(0) - x_item = F.pad(x_item,(0,0,0,max_len-x_item.shape[0]),value=0) if x_item.shape[0] torch.Tensor: return expaned_lengths >= lengths.unsqueeze(-1) +def make_pad_mask_left(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor: + """ + Args: + lengths: + A 1-D tensor containing sentence lengths. + max_len: + The length of masks. 
+ Returns: + Return a 2-D bool tensor, where masked positions + are filled with `True` and non-masked positions are + filled with `False`. + + #>>> lengths = torch.tensor([1, 3, 2, 5]) + #>>> make_pad_mask_left(lengths) + tensor( + [ + [True, True, True, True, False], + [True, True, False, False, False], + [True, True, True, False, False], + ... + ] + ) + """ + assert lengths.ndim == 1, lengths.ndim + max_len = max(max_len, lengths.max()) + n = lengths.size(0) + seq_range = torch.arange(0, max_len, device=lengths.device) + expaned_lengths = seq_range.unsqueeze(0).repeat(n, 1) + expaned_lengths -= (max_len-lengths).unsqueeze(-1) + + return expaned_lengths<0 + + # https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py def top_k_top_p_filtering( logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1 diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py index ee2ec1e..012cbf8 100644 --- a/GPT_SoVITS/TTS_infer_pack/TTS.py +++ b/GPT_SoVITS/TTS_infer_pack/TTS.py @@ -145,7 +145,15 @@ class TTS_Config: self.device = self.configs.get("device", torch.device("cpu")) + if "cuda" in str(self.device) and not torch.cuda.is_available(): + print(f"Warning: CUDA is not available, set device to CPU.") + self.device = torch.device("cpu") + self.is_half = self.configs.get("is_half", False) + # if str(self.device) == "cpu" and self.is_half: + # print(f"Warning: Half precision is not supported on CPU, set is_half to False.") + # self.is_half = False + self.version = version self.t2s_weights_path = self.configs.get("t2s_weights_path", None) self.vits_weights_path = self.configs.get("vits_weights_path", None) diff --git a/GPT_SoVITS/inference_webui.py b/GPT_SoVITS/inference_webui.py index e5a8a60..afae2cf 100644 --- a/GPT_SoVITS/inference_webui.py +++ b/GPT_SoVITS/inference_webui.py @@ -691,7 +691,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, wav_gen = bigvgan_model(cmf_res) 
audio=wav_gen[0][0]#.cpu().detach().numpy() max_audio=torch.abs(audio).max()#简单防止16bit爆音 - if max_audio>1:audio/=max_audio + if max_audio>1:audio=audio/max_audio audio_opt.append(audio) audio_opt.append(zero_wav_torch)#zero_wav t4 = ttime() diff --git a/GPT_SoVITS/module/models.py b/GPT_SoVITS/module/models.py index 623da80..33bd607 100644 --- a/GPT_SoVITS/module/models.py +++ b/GPT_SoVITS/module/models.py @@ -1162,6 +1162,7 @@ class SynthesizerTrnV3(nn.Module): use_sdp=True, semantic_frame_rate=None, freeze_quantizer=None, + version="v3", **kwargs): super().__init__() @@ -1182,6 +1183,7 @@ class SynthesizerTrnV3(nn.Module): self.segment_size = segment_size self.n_speakers = n_speakers self.gin_channels = gin_channels + self.version = version self.model_dim=512 self.use_sdp = use_sdp diff --git a/GPT_SoVITS/text/LangSegmenter/langsegmenter.py b/GPT_SoVITS/text/LangSegmenter/langsegmenter.py index cca5bf2..c558348 100644 --- a/GPT_SoVITS/text/LangSegmenter/langsegmenter.py +++ b/GPT_SoVITS/text/LangSegmenter/langsegmenter.py @@ -8,66 +8,7 @@ jieba.setLogLevel(logging.CRITICAL) # 更改fast_langdetect大模型位置 from pathlib import Path import fast_langdetect -fast_langdetect.ft_detect.infer.CACHE_DIRECTORY = Path(__file__).parent.parent.parent / "pretrained_models" / "fast_langdetect" - -# 防止win下无法读取模型 -import os -from typing import Optional -def load_fasttext_model( - model_path: Path, - download_url: Optional[str] = None, - proxy: Optional[str] = None, -): - """ - Load a FastText model, downloading it if necessary. 
- :param model_path: Path to the FastText model file - :param download_url: URL to download the model from - :param proxy: Proxy URL for downloading the model - :return: FastText model - :raises DetectError: If model loading fails - """ - if all([ - fast_langdetect.ft_detect.infer.VERIFY_FASTTEXT_LARGE_MODEL, - model_path.exists(), - model_path.name == fast_langdetect.ft_detect.infer.FASTTEXT_LARGE_MODEL_NAME, - ]): - if not fast_langdetect.ft_detect.infer.verify_md5(model_path, fast_langdetect.ft_detect.infer.VERIFY_FASTTEXT_LARGE_MODEL): - fast_langdetect.ft_detect.infer.logger.warning( - f"fast-langdetect: MD5 hash verification failed for {model_path}, " - f"please check the integrity of the downloaded file from {fast_langdetect.ft_detect.infer.FASTTEXT_LARGE_MODEL_URL}. " - "\n This may seriously reduce the prediction accuracy. " - "If you want to ignore this, please set `fast_langdetect.ft_detect.infer.VERIFY_FASTTEXT_LARGE_MODEL = None` " - ) - if not model_path.exists(): - if download_url: - fast_langdetect.ft_detect.infer.download_model(download_url, model_path, proxy) - if not model_path.exists(): - raise fast_langdetect.ft_detect.infer.DetectError(f"FastText model file not found at {model_path}") - - try: - # Load FastText model - if (re.match(r'^[A-Za-z0-9_/\\:.]*$', str(model_path))): - model = fast_langdetect.ft_detect.infer.fasttext.load_model(str(model_path)) - else: - python_dir = os.getcwd() - if (str(model_path)[:len(python_dir)].upper() == python_dir.upper()): - model = fast_langdetect.ft_detect.infer.fasttext.load_model(os.path.relpath(model_path, python_dir)) - else: - import tempfile - import shutil - with tempfile.NamedTemporaryFile(delete=False) as tmpfile: - shutil.copyfile(model_path, tmpfile.name) - - model = fast_langdetect.ft_detect.infer.fasttext.load_model(tmpfile.name) - os.unlink(tmpfile.name) - return model - - except Exception as e: - fast_langdetect.ft_detect.infer.logger.warning(f"fast-langdetect:Failed to load FastText model 
from {model_path}: {e}") - raise fast_langdetect.ft_detect.infer.DetectError(f"Failed to load FastText model: {e}") - -if os.name == 'nt': - fast_langdetect.ft_detect.infer.load_fasttext_model = load_fasttext_model +fast_langdetect.infer._default_detector = fast_langdetect.infer.LangDetector(fast_langdetect.infer.LangDetectConfig(cache_dir=Path(__file__).parent.parent.parent / "pretrained_models" / "fast_langdetect")) from split_lang import LangSplitter diff --git a/GPT_SoVITS/text/chinese.py b/GPT_SoVITS/text/chinese.py index 2255c6e..55dc997 100644 --- a/GPT_SoVITS/text/chinese.py +++ b/GPT_SoVITS/text/chinese.py @@ -17,6 +17,8 @@ pinyin_to_symbol_map = { for line in open(os.path.join(current_file_path, "opencpop-strict.txt")).readlines() } +import jieba_fast, logging +jieba_fast.setLogLevel(logging.CRITICAL) import jieba_fast.posseg as psg diff --git a/GPT_SoVITS/text/chinese2.py b/GPT_SoVITS/text/chinese2.py index f716b41..2b4599d 100644 --- a/GPT_SoVITS/text/chinese2.py +++ b/GPT_SoVITS/text/chinese2.py @@ -18,13 +18,15 @@ pinyin_to_symbol_map = { for line in open(os.path.join(current_file_path, "opencpop-strict.txt")).readlines() } +import jieba_fast, logging +jieba_fast.setLogLevel(logging.CRITICAL) import jieba_fast.posseg as psg # is_g2pw_str = os.environ.get("is_g2pw", "True")##默认开启 # is_g2pw = False#True if is_g2pw_str.lower() == 'true' else False is_g2pw = True#True if is_g2pw_str.lower() == 'true' else False if is_g2pw: - print("当前使用g2pw进行拼音推理") + # print("当前使用g2pw进行拼音推理") from text.g2pw import G2PWPinyin, correct_pronunciation parent_directory = os.path.dirname(current_file_path) g2pw = G2PWPinyin(model_dir="GPT_SoVITS/text/G2PWModel",model_source=os.environ.get("bert_path","GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"),v_to_u=False, neutral_tone_with_five=True) diff --git a/GPT_SoVITS/text/japanese.py b/GPT_SoVITS/text/japanese.py index d815ef4..e023ce7 100644 --- a/GPT_SoVITS/text/japanese.py +++ b/GPT_SoVITS/text/japanese.py @@ -10,7 
+10,7 @@ try: if os.name == 'nt': python_dir = os.getcwd() OPEN_JTALK_DICT_DIR = pyopenjtalk.OPEN_JTALK_DICT_DIR.decode("utf-8") - if not (re.match(r'^[A-Za-z0-9_/\\:.]*$', OPEN_JTALK_DICT_DIR)): + if not (re.match(r'^[A-Za-z0-9_/\\:.\-]*$', OPEN_JTALK_DICT_DIR)): if (OPEN_JTALK_DICT_DIR[:len(python_dir)].upper() == python_dir.upper()): OPEN_JTALK_DICT_DIR = os.path.join(os.path.relpath(OPEN_JTALK_DICT_DIR,python_dir)) else: @@ -25,7 +25,7 @@ try: OPEN_JTALK_DICT_DIR = os.path.join("TEMP", "ja", "open_jtalk_dic") pyopenjtalk.OPEN_JTALK_DICT_DIR = OPEN_JTALK_DICT_DIR.encode("utf-8") - if not (re.match(r'^[A-Za-z0-9_/\\:.]*$', current_file_path)): + if not (re.match(r'^[A-Za-z0-9_/\\:.\-]*$', current_file_path)): if (current_file_path[:len(python_dir)].upper() == python_dir.upper()): current_file_path = os.path.join(os.path.relpath(current_file_path,python_dir)) else: diff --git a/GPT_SoVITS/text/korean.py b/GPT_SoVITS/text/korean.py index 79d89af..daae41f 100644 --- a/GPT_SoVITS/text/korean.py +++ b/GPT_SoVITS/text/korean.py @@ -19,13 +19,13 @@ if os.name == 'nt': print(f'you have to install eunjeon. 
install it...') else: installpath = spam_spec.submodule_search_locations[0] - if not (re.match(r'^[A-Za-z0-9_/\\:.]*$', installpath)): + if not (re.match(r'^[A-Za-z0-9_/\\:.\-]*$', installpath)): import sys from eunjeon import Mecab as _Mecab class Mecab(_Mecab): def get_dicpath(installpath): - if not (re.match(r'^[A-Za-z0-9_/\\:.]*$', installpath)): + if not (re.match(r'^[A-Za-z0-9_/\\:.\-]*$', installpath)): import shutil python_dir = os.getcwd() if (installpath[:len(python_dir)].upper() == python_dir.upper()): diff --git a/api.py b/api.py index c5f7024..d92d9c8 100644 --- a/api.py +++ b/api.py @@ -150,9 +150,9 @@ sys.path.append(now_dir) sys.path.append("%s/GPT_SoVITS" % (now_dir)) import signal -import LangSegment +from text.LangSegmenter import LangSegmenter from time import time as ttime -import torch +import torch, torchaudio import librosa import soundfile as sf from fastapi import FastAPI, Request, Query, HTTPException @@ -162,7 +162,8 @@ from transformers import AutoModelForMaskedLM, AutoTokenizer import numpy as np from feature_extractor import cnhubert from io import BytesIO -from module.models import SynthesizerTrn +from module.models import SynthesizerTrn, SynthesizerTrnV3 +from peft import LoraConfig, PeftModel, get_peft_model from AR.models.t2s_lightning_module import Text2SemanticLightningModule from text import cleaned_text_to_sequence from text.cleaner import clean_text @@ -197,6 +198,61 @@ def is_full(*items): # 任意一项为空返回False return True +def init_bigvgan(): + global bigvgan_model + from BigVGAN import bigvgan + bigvgan_model = bigvgan.BigVGAN.from_pretrained("%s/GPT_SoVITS/pretrained_models/models--nvidia--bigvgan_v2_24khz_100band_256x" % (now_dir,), use_cuda_kernel=False) # if True, RuntimeError: Ninja is required to load C++ extensions + # remove weight norm in the model and set to eval mode + bigvgan_model.remove_weight_norm() + bigvgan_model = bigvgan_model.eval() + if is_half == True: + bigvgan_model = bigvgan_model.half().to(device) + 
else: + bigvgan_model = bigvgan_model.to(device) + + +resample_transform_dict={} +def resample(audio_tensor, sr0): + global resample_transform_dict + if sr0 not in resample_transform_dict: + resample_transform_dict[sr0] = torchaudio.transforms.Resample( + sr0, 24000 + ).to(device) + return resample_transform_dict[sr0](audio_tensor) + + +from module.mel_processing import spectrogram_torch,mel_spectrogram_torch +spec_min = -12 +spec_max = 2 +def norm_spec(x): + return (x - spec_min) / (spec_max - spec_min) * 2 - 1 +def denorm_spec(x): + return (x + 1) / 2 * (spec_max - spec_min) + spec_min +mel_fn=lambda x: mel_spectrogram_torch(x, **{ + "n_fft": 1024, + "win_size": 1024, + "hop_size": 256, + "num_mels": 100, + "sampling_rate": 24000, + "fmin": 0, + "fmax": None, + "center": False +}) + + +sr_model=None +def audio_sr(audio,sr): + global sr_model + if sr_model==None: + from tools.audio_sr import AP_BWE + try: + sr_model=AP_BWE(device,DictToAttrRecursive) + except FileNotFoundError: + logger.info("你没有下载超分模型的参数,因此不进行超分。如想超分请先参照教程把文件下载") + return audio.cpu().detach().numpy(),sr + return sr_model(audio,sr) + + class Speaker: def __init__(self, name, gpt, sovits, phones = None, bert = None, prompt = None): self.name = name @@ -214,31 +270,72 @@ class Sovits: self.vq_model = vq_model self.hps = hps +from process_ckpt import get_sovits_version_from_path_fast,load_sovits_new def get_sovits_weights(sovits_path): - dict_s2 = torch.load(sovits_path, map_location="cpu") + path_sovits_v3="GPT_SoVITS/pretrained_models/s2Gv3.pth" + is_exist_s2gv3=os.path.exists(path_sovits_v3) + + version, model_version, if_lora_v3=get_sovits_version_from_path_fast(sovits_path) + if if_lora_v3==True and is_exist_s2gv3==False: + logger.info("SoVITS V3 底模缺失,无法加载相应 LoRA 权重") + + dict_s2 = load_sovits_new(sovits_path) hps = dict_s2["config"] hps = DictToAttrRecursive(hps) hps.model.semantic_frame_rate = "25hz" - if dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322: + if 
'enc_p.text_embedding.weight' not in dict_s2['weight']: + hps.model.version = "v2"#v3model,v2sybomls + elif dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322: hps.model.version = "v1" else: hps.model.version = "v2" - logger.info(f"模型版本: {hps.model.version}") + + if model_version == "v3": + hps.model.version = "v3" + model_params_dict = vars(hps.model) - vq_model = SynthesizerTrn( - hps.data.filter_length // 2 + 1, - hps.train.segment_size // hps.data.hop_length, - n_speakers=hps.data.n_speakers, - **model_params_dict - ) + if model_version!="v3": + vq_model = SynthesizerTrn( + hps.data.filter_length // 2 + 1, + hps.train.segment_size // hps.data.hop_length, + n_speakers=hps.data.n_speakers, + **model_params_dict + ) + else: + vq_model = SynthesizerTrnV3( + hps.data.filter_length // 2 + 1, + hps.train.segment_size // hps.data.hop_length, + n_speakers=hps.data.n_speakers, + **model_params_dict + ) + init_bigvgan() + model_version=hps.model.version + logger.info(f"模型版本: {model_version}") if ("pretrained" not in sovits_path): - del vq_model.enc_q + try: + del vq_model.enc_q + except:pass if is_half == True: vq_model = vq_model.half().to(device) else: vq_model = vq_model.to(device) vq_model.eval() - vq_model.load_state_dict(dict_s2["weight"], strict=False) + if if_lora_v3 == False: + vq_model.load_state_dict(dict_s2["weight"], strict=False) + else: + vq_model.load_state_dict(load_sovits_new(path_sovits_v3)["weight"], strict=False) + lora_rank=dict_s2["lora_rank"] + lora_config = LoraConfig( + target_modules=["to_k", "to_q", "to_v", "to_out.0"], + r=lora_rank, + lora_alpha=lora_rank, + init_lora_weights=True, + ) + vq_model.cfm = get_peft_model(vq_model.cfm, lora_config) + vq_model.load_state_dict(dict_s2["weight"], strict=False) + vq_model.cfm = vq_model.cfm.merge_and_unload() + # torch.save(vq_model.state_dict(),"merge_win.pth") + vq_model.eval() sovits = Sovits(vq_model, hps) return sovits @@ -260,8 +357,8 @@ def get_gpt_weights(gpt_path): t2s_model = 
t2s_model.half() t2s_model = t2s_model.to(device) t2s_model.eval() - total = sum([param.nelement() for param in t2s_model.parameters()]) - logger.info("Number of parameter: %.2fM" % (total / 1e6)) + # total = sum([param.nelement() for param in t2s_model.parameters()]) + # logger.info("Number of parameter: %.2fM" % (total / 1e6)) gpt = Gpt(max_sec, t2s_model) return gpt @@ -295,6 +392,7 @@ def get_bert_feature(text, word2ph): def clean_text_inf(text, language, version): + language = language.replace("all_","") phones, word2ph, norm_text = clean_text(text, language, version) phones = cleaned_text_to_sequence(phones, version) return phones, word2ph, norm_text @@ -315,16 +413,10 @@ def get_bert_inf(phones, word2ph, norm_text, language): from text import chinese def get_phones_and_bert(text,language,version,final=False): if language in {"en", "all_zh", "all_ja", "all_ko", "all_yue"}: - language = language.replace("all_","") - if language == "en": - LangSegment.setfilters(["en"]) - formattext = " ".join(tmp["text"] for tmp in LangSegment.getTexts(text)) - else: - # 因无法区别中日韩文汉字,以用户输入为准 - formattext = text + formattext = text while " " in formattext: formattext = formattext.replace(" ", " ") - if language == "zh": + if language == "all_zh": if re.search(r'[A-Za-z]', formattext): formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext) formattext = chinese.mix_text_normalize(formattext) @@ -332,7 +424,7 @@ def get_phones_and_bert(text,language,version,final=False): else: phones, word2ph, norm_text = clean_text_inf(formattext, language, version) bert = get_bert_feature(norm_text, word2ph).to(device) - elif language == "yue" and re.search(r'[A-Za-z]', formattext): + elif language == "all_yue" and re.search(r'[A-Za-z]', formattext): formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext) formattext = chinese.mix_text_normalize(formattext) return get_phones_and_bert(formattext,"yue",version) @@ -345,19 +437,18 @@ def 
get_phones_and_bert(text,language,version,final=False): elif language in {"zh", "ja", "ko", "yue", "auto", "auto_yue"}: textlist=[] langlist=[] - LangSegment.setfilters(["zh","ja","en","ko"]) if language == "auto": - for tmp in LangSegment.getTexts(text): + for tmp in LangSegmenter.getTexts(text): langlist.append(tmp["lang"]) textlist.append(tmp["text"]) elif language == "auto_yue": - for tmp in LangSegment.getTexts(text): + for tmp in LangSegmenter.getTexts(text): if tmp["lang"] == "zh": tmp["lang"] = "yue" langlist.append(tmp["lang"]) textlist.append(tmp["text"]) else: - for tmp in LangSegment.getTexts(text): + for tmp in LangSegmenter.getTexts(text): if tmp["lang"] == "en": langlist.append(tmp["lang"]) else: @@ -556,10 +647,11 @@ def only_punc(text): splits = {",", "。", "?", "!", ",", ".", "?", "!", "~", ":", ":", "—", "…", } -def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, top_k= 15, top_p = 0.6, temperature = 0.6, speed = 1, inp_refs = None, spk = "default"): +def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, top_k= 15, top_p = 0.6, temperature = 0.6, speed = 1, inp_refs = None, sample_steps = 32, if_sr = False, spk = "default"): infer_sovits = speaker_list[spk].sovits vq_model = infer_sovits.vq_model hps = infer_sovits.hps + version = vq_model.version infer_gpt = speaker_list[spk].gpt t2s_model = infer_gpt.t2s_model @@ -587,20 +679,22 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, prompt_semantic = codes[0, 0] prompt = prompt_semantic.unsqueeze(0).to(device) - refers=[] - if(inp_refs): - for path in inp_refs: - try: - refer = get_spepc(hps, path).to(dtype).to(device) - refers.append(refer) - except Exception as e: - logger.error(e) - if(len(refers)==0): - refers = [get_spepc(hps, ref_wav_path).to(dtype).to(device)] + if version != "v3": + refers=[] + if(inp_refs): + for path in inp_refs: + try: + refer = get_spepc(hps, path).to(dtype).to(device) + 
refers.append(refer) + except Exception as e: + logger.error(e) + if(len(refers)==0): + refers = [get_spepc(hps, ref_wav_path).to(dtype).to(device)] + else: + refer = get_spepc(hps, ref_wav_path).to(device).to(dtype) t1 = ttime() - version = vq_model.version - os.environ['version'] = version + # os.environ['version'] = version prompt_language = dict_language[prompt_language.lower()] text_language = dict_language[text_language.lower()] phones1, bert1, norm_text1 = get_phones_and_bert(prompt_text, prompt_language, version) @@ -634,20 +728,82 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, early_stop_num=hz * max_sec) pred_semantic = pred_semantic[:, -idx:].unsqueeze(0) t3 = ttime() - audio = \ - vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), - refers,speed=speed).detach().cpu().numpy()[ - 0, 0] ###试试重建不带上prompt部分 + + if version != "v3": + audio = \ + vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), + refers,speed=speed).detach().cpu().numpy()[ + 0, 0] ###试试重建不带上prompt部分 + else: + phoneme_ids0=torch.LongTensor(phones1).to(device).unsqueeze(0) + phoneme_ids1=torch.LongTensor(phones2).to(device).unsqueeze(0) + # print(11111111, phoneme_ids0, phoneme_ids1) + fea_ref,ge = vq_model.decode_encp(prompt.unsqueeze(0), phoneme_ids0, refer) + ref_audio, sr = torchaudio.load(ref_wav_path) + ref_audio=ref_audio.to(device).float() + if (ref_audio.shape[0] == 2): + ref_audio = ref_audio.mean(0).unsqueeze(0) + if sr!=24000: + ref_audio=resample(ref_audio,sr) + # print("ref_audio",ref_audio.abs().mean()) + mel2 = mel_fn(ref_audio) + mel2 = norm_spec(mel2) + T_min = min(mel2.shape[2], fea_ref.shape[2]) + mel2 = mel2[:, :, :T_min] + fea_ref = fea_ref[:, :, :T_min] + if (T_min > 468): + mel2 = mel2[:, :, -468:] + fea_ref = fea_ref[:, :, -468:] + T_min = 468 + chunk_len = 934 - T_min + # print("fea_ref",fea_ref,fea_ref.shape) + # print("mel2",mel2) + mel2=mel2.to(dtype) + fea_todo, ge = 
vq_model.decode_encp(pred_semantic, phoneme_ids1, refer, ge,speed) + # print("fea_todo",fea_todo) + # print("ge",ge.abs().mean()) + cfm_resss = [] + idx = 0 + while (1): + fea_todo_chunk = fea_todo[:, :, idx:idx + chunk_len] + if (fea_todo_chunk.shape[-1] == 0): break + idx += chunk_len + fea = torch.cat([fea_ref, fea_todo_chunk], 2).transpose(2, 1) + # set_seed(123) + cfm_res = vq_model.cfm.inference(fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps, inference_cfg_rate=0) + cfm_res = cfm_res[:, :, mel2.shape[2]:] + mel2 = cfm_res[:, :, -T_min:] + # print("fea", fea) + # print("mel2in", mel2) + fea_ref = fea_todo_chunk[:, :, -T_min:] + cfm_resss.append(cfm_res) + cmf_res = torch.cat(cfm_resss, 2) + cmf_res = denorm_spec(cmf_res) + if bigvgan_model==None:init_bigvgan() + with torch.inference_mode(): + wav_gen = bigvgan_model(cmf_res) + audio=wav_gen[0][0].cpu().detach().numpy() + max_audio=np.abs(audio).max() if max_audio>1: audio/=max_audio audio_opt.append(audio) audio_opt.append(zero_wav) + audio_opt = np.concatenate(audio_opt, 0) t4 = ttime() + + sr = hps.data.sampling_rate if version != "v3" else 24000 + if if_sr and sr == 24000: + audio_opt = torch.from_numpy(audio_opt).float().to(device) + audio_opt,sr=audio_sr(audio_opt.unsqueeze(0),sr) + max_audio=np.abs(audio_opt).max() + if max_audio > 1: audio_opt /= max_audio + sr = 48000 + if is_int32: - audio_bytes = pack_audio(audio_bytes,(np.concatenate(audio_opt, 0) * 2147483647).astype(np.int32),hps.data.sampling_rate) + audio_bytes = pack_audio(audio_bytes,(audio_opt * 2147483647).astype(np.int32),sr) else: - audio_bytes = pack_audio(audio_bytes,(np.concatenate(audio_opt, 0) * 32768).astype(np.int16),hps.data.sampling_rate) + audio_bytes = pack_audio(audio_bytes,(audio_opt * 32768).astype(np.int16),sr) # logger.info("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3)) if stream_mode == "normal": audio_bytes, audio_chunk = read_clean_buffer(audio_bytes) @@ -655,7 +811,9 @@ def 
get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, if not stream_mode == "normal": if media_type == "wav": - audio_bytes = pack_wav(audio_bytes,hps.data.sampling_rate) + sr = 48000 if if_sr else 24000 + sr = hps.data.sampling_rate if version != "v3" else sr + audio_bytes = pack_wav(audio_bytes,sr) yield audio_bytes.getvalue() @@ -688,7 +846,7 @@ def handle_change(path, text, language): return JSONResponse({"code": 0, "message": "Success"}, status_code=200) -def handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cut_punc, top_k, top_p, temperature, speed, inp_refs): +def handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cut_punc, top_k, top_p, temperature, speed, inp_refs, sample_steps, if_sr): if ( refer_wav_path == "" or refer_wav_path is None or prompt_text == "" or prompt_text is None @@ -702,12 +860,15 @@ def handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cu if not default_refer.is_ready(): return JSONResponse({"code": 400, "message": "未指定参考音频且接口无预设"}, status_code=400) + if not sample_steps in [4,8,16,32]: + sample_steps = 32 + if cut_punc == None: text = cut_text(text,default_cut_punc) else: text = cut_text(text,cut_punc) - return StreamingResponse(get_tts_wav(refer_wav_path, prompt_text, prompt_language, text, text_language, top_k, top_p, temperature, speed, inp_refs), media_type="audio/"+media_type) + return StreamingResponse(get_tts_wav(refer_wav_path, prompt_text, prompt_language, text, text_language, top_k, top_p, temperature, speed, inp_refs, sample_steps, if_sr), media_type="audio/"+media_type) @@ -915,7 +1076,9 @@ async def tts_endpoint(request: Request): json_post_raw.get("top_p", 1.0), json_post_raw.get("temperature", 1.0), json_post_raw.get("speed", 1.0), - json_post_raw.get("inp_refs", []) + json_post_raw.get("inp_refs", []), + json_post_raw.get("sample_steps", 32), + json_post_raw.get("if_sr", False) ) @@ -931,9 +1094,11 @@ async def tts_endpoint( 
top_p: float = 1.0, temperature: float = 1.0, speed: float = 1.0, - inp_refs: list = Query(default=[]) + inp_refs: list = Query(default=[]), + sample_steps: int = 32, + if_sr: bool = False ): - return handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cut_punc, top_k, top_p, temperature, speed, inp_refs) + return handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cut_punc, top_k, top_p, temperature, speed, inp_refs, sample_steps, if_sr) if __name__ == "__main__": diff --git a/install.sh b/install.sh index 2fd2192..d4d7349 100644 --- a/install.sh +++ b/install.sh @@ -1,4 +1,7 @@ #!/bin/bash + +# 安装构建工具 +# Install build tools echo "Installing GCC..." conda install -c conda-forge gcc=14 @@ -8,6 +11,12 @@ conda install -c conda-forge gxx echo "Installing ffmpeg and cmake..." conda install ffmpeg cmake +# 设置编译环境 +# Set up build environment +export CMAKE_MAKE_PROGRAM="$CONDA_PREFIX/bin/cmake" +export CC="$CONDA_PREFIX/bin/gcc" +export CXX="$CONDA_PREFIX/bin/g++" + echo "Checking for CUDA installation..." if command -v nvidia-smi &> /dev/null; then USE_CUDA=true @@ -49,6 +58,10 @@ fi echo "Installing Python dependencies from requirements.txt..." + +# 刷新环境 +# Refresh environment +hash -r pip install -r requirements.txt if [ "$USE_ROCM" = true ] && [ "$IS_WSL" = true ] ; then @@ -60,3 +73,4 @@ if [ "$USE_ROCM" = true ] && [ "$IS_WSL" = true ] ; then fi echo "Installation completed successfully!" + diff --git a/requirements.txt b/requirements.txt index 144c729..0c0a9f7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -25,7 +25,7 @@ psutil jieba_fast jieba split-lang -fast_langdetect +fast_langdetect>=0.3.0 Faster_Whisper wordsegment rotary_embedding_torch