From 134602cd99119365f4772db81c42b33037547a93 Mon Sep 17 00:00:00 2001 From: chasonjiang <1440499136@qq.com> Date: Mon, 11 Mar 2024 17:37:30 +0800 Subject: [PATCH 1/6] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E4=BA=86=E9=BB=98?= =?UTF-8?q?=E8=AE=A4batch=5Fsize=E4=B8=BA20?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- GPT_SoVITS/inference_webui.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/GPT_SoVITS/inference_webui.py b/GPT_SoVITS/inference_webui.py index 5a26615a..9c8c83d7 100644 --- a/GPT_SoVITS/inference_webui.py +++ b/GPT_SoVITS/inference_webui.py @@ -105,7 +105,7 @@ def inference(text, text_lang, "batch_size":int(batch_size), "speed_factor":float(speed_factor), "split_bucket":split_bucket, - "return_fragment":False + "return_fragment":True } for item in tts_pipline.run(inputs): @@ -186,7 +186,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app: with gr.Row(): with gr.Column(): - batch_size = gr.Slider(minimum=1,maximum=200,step=1,label=i18n("batch_size"),value=1,interactive=True) + batch_size = gr.Slider(minimum=1,maximum=200,step=1,label=i18n("batch_size"),value=20,interactive=True) speed_factor = gr.Slider(minimum=0.25,maximum=4,step=0.05,label="speed_factor",value=1.0,interactive=True) top_k = gr.Slider(minimum=1,maximum=100,step=1,label=i18n("top_k"),value=5,interactive=True) top_p = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("top_p"),value=1,interactive=True) From 46826b28b0377d9a1f7f0625c12b40f36f911b37 Mon Sep 17 00:00:00 2001 From: chasonjiang <1440499136@qq.com> Date: Mon, 11 Mar 2024 17:43:31 +0800 Subject: [PATCH 2/6] =?UTF-8?q?=E8=AE=BE=E7=BD=AE=E9=BB=98=E8=AE=A4?= =?UTF-8?q?=E8=BF=94=E5=9B=9E=E4=B8=80=E6=95=B4=E7=AF=87=E6=96=87=E7=AB=A0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- GPT_SoVITS/inference_webui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/GPT_SoVITS/inference_webui.py b/GPT_SoVITS/inference_webui.py index 9c8c83d7..2308b389 100644 --- a/GPT_SoVITS/inference_webui.py +++ b/GPT_SoVITS/inference_webui.py @@ -105,7 +105,7 @@ def inference(text, text_lang, "batch_size":int(batch_size), "speed_factor":float(speed_factor), "split_bucket":split_bucket, - "return_fragment":True + "return_fragment":False } for item in tts_pipline.run(inputs): From bfd72860687be2a628e54f20574c6fc28371cd57 Mon Sep 17 00:00:00 2001 From: chasonjiang <1440499136@qq.com> Date: Mon, 11 Mar 2024 19:35:55 +0800 Subject: [PATCH 3/6] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E4=BA=86,=E4=B8=AD?= =?UTF-8?q?=E8=8B=B1=E6=96=87=E6=B7=B7=E5=90=88=E6=96=87=E6=9C=AC=E5=90=88?= =?UTF-8?q?=E6=88=90=E8=8B=B1=E6=96=87=E6=97=B6,=20=E5=87=BA=E7=8E=B0?= =?UTF-8?q?=E7=A9=BA=E5=AD=97=E7=AC=A6=E6=8A=A5=E9=94=99=E7=9A=84=E9=97=AE?= =?UTF-8?q?=E9=A2=98=20=E4=BC=98=E5=8C=96=E4=BA=86=E4=BB=A3=E7=A0=81,=20?= =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E4=BA=86=E5=81=A5=E5=A3=AE=E6=80=A7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- GPT_SoVITS/TTS_infer_pack/TTS.py | 85 +++++++++++++++---- GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py | 28 ++++-- 2 files changed, 88 insertions(+), 25 deletions(-) diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py index 249e945d..7912ddf6 100644 --- a/GPT_SoVITS/TTS_infer_pack/TTS.py +++ b/GPT_SoVITS/TTS_infer_pack/TTS.py @@ -173,35 +173,36 @@ class TTS: self.stop_flag:bool = False + self.precison:torch.dtype = torch.float16 if self.configs.is_half else torch.float32 def _init_models(self,): self.init_t2s_weights(self.configs.t2s_weights_path) self.init_vits_weights(self.configs.vits_weights_path) self.init_bert_weights(self.configs.bert_base_path) self.init_cnhuhbert_weights(self.configs.cnhuhbert_base_path) + self.enable_half_precision(self.configs.is_half) def init_cnhuhbert_weights(self, base_path: str): + print(f"Loading CNHuBERT weights from {base_path}") self.cnhuhbert_model = CNHubert(base_path) self.cnhuhbert_model=self.cnhuhbert_model.eval() - if self.configs.is_half == True: - self.cnhuhbert_model = self.cnhuhbert_model.half() self.cnhuhbert_model = self.cnhuhbert_model.to(self.configs.device) def init_bert_weights(self, base_path: str): + print(f"Loading BERT weights from {base_path}") self.bert_tokenizer = AutoTokenizer.from_pretrained(base_path) self.bert_model = AutoModelForMaskedLM.from_pretrained(base_path) self.bert_model=self.bert_model.eval() - if self.configs.is_half: - self.bert_model = self.bert_model.half() self.bert_model = self.bert_model.to(self.configs.device) def init_vits_weights(self, weights_path: str): + print(f"Loading VITS weights from {weights_path}") self.configs.vits_weights_path = weights_path self.configs.save_configs() dict_s2 = torch.load(weights_path, map_location=self.configs.device) @@ -224,8 +225,6 @@ class TTS: if hasattr(vits_model, "enc_q"): del vits_model.enc_q - if self.configs.is_half: - vits_model = vits_model.half() vits_model = vits_model.to(self.configs.device) vits_model = vits_model.eval() vits_model.load_state_dict(dict_s2["weight"], strict=False) @@ -233,6 +232,7 @@ class TTS: def init_t2s_weights(self, weights_path: str): + print(f"Loading Text2Semantic weights from {weights_path}") self.configs.t2s_weights_path = weights_path self.configs.save_configs() self.configs.hz = 50 @@ -242,12 +242,60 @@ class TTS: t2s_model = Text2SemanticLightningModule(config, "****", is_train=False, flash_attn_enabled=self.configs.flash_attn_enabled) t2s_model.load_state_dict(dict_s1["weight"]) - if self.configs.is_half: - t2s_model = t2s_model.half() t2s_model = t2s_model.to(self.configs.device) t2s_model = t2s_model.eval() self.t2s_model = t2s_model + def enable_half_precision(self, enable: bool = True): + ''' + To enable half precision for the TTS model. + Args: + enable: bool, whether to enable half precision. + + ''' + if self.configs.device == "cpu": + print("Half precision is not supported on CPU.") + return + + self.configs.is_half = enable + self.precison = torch.float16 if enable else torch.float32 + self.configs.save_configs() + if enable: + if self.t2s_model is not None: + self.t2s_model =self.t2s_model.half() + if self.vits_model is not None: + self.vits_model = self.vits_model.half() + if self.bert_model is not None: + self.bert_model =self.bert_model.half() + if self.cnhuhbert_model is not None: + self.cnhuhbert_model = self.cnhuhbert_model.half() + else: + if self.t2s_model is not None: + self.t2s_model = self.t2s_model.float() + if self.vits_model is not None: + self.vits_model = self.vits_model.float() + if self.bert_model is not None: + self.bert_model = self.bert_model.float() + if self.cnhuhbert_model is not None: + self.cnhuhbert_model = self.cnhuhbert_model.float() + + def set_device(self, device: torch.device): + ''' + To set the device for all models. + Args: + device: torch.device, the device to use for all models. + ''' + self.configs.device = device + self.configs.save_configs() + if self.t2s_model is not None: + self.t2s_model = self.t2s_model.to(device) + if self.vits_model is not None: + self.vits_model = self.vits_model.to(device) + if self.bert_model is not None: + self.bert_model = self.bert_model.to(device) + if self.cnhuhbert_model is not None: + self.cnhuhbert_model = self.cnhuhbert_model.to(device) + def set_ref_audio(self, ref_audio_path:str): ''' To set the reference audio for the TTS model, @@ -347,7 +395,7 @@ class TTS: pos_end = min(pos+batch_size,index_and_len_list.shape[0]) while pos < pos_end: batch=index_and_len_list[pos:pos_end, 1].astype(np.float32) - score=batch[(pos_end-pos)//2]/batch.mean() + score=batch[(pos_end-pos)//2]/(batch.mean()+1e-8) if (score>=threshold) or (pos_end-pos==1): batch_index=index_and_len_list[pos:pos_end, 0].tolist() batch_index_list_len += len(batch_index) @@ -379,13 +427,13 @@ class TTS: for item in item_list: if prompt_data is not None: all_bert_features = torch.cat([prompt_data["bert_features"], item["bert_features"]], 1)\ - .to(dtype=torch.float32 if not self.configs.is_half else torch.float16) + .to(dtype=self.precison) all_phones = torch.LongTensor(prompt_data["phones"]+item["phones"]) phones = torch.LongTensor(item["phones"]) # norm_text = prompt_data["norm_text"]+item["norm_text"] else: all_bert_features = item["bert_features"]\ - .to(dtype=torch.float32 if not self.configs.is_half else torch.float16) + .to(dtype=self.precison) phones = torch.LongTensor(item["phones"]) all_phones = phones # norm_text = item["norm_text"] @@ -405,7 +453,7 @@ class TTS: # phones_batch = self.batch_sequences(phones_list, axis=0, pad_value=0, max_length=max_len) all_phones_batch = self.batch_sequences(all_phones_list, axis=0, pad_value=0, max_length=max_len) # all_bert_features_batch = all_bert_features_list - all_bert_features_batch = torch.zeros(len(item_list), 1024, max_len, dtype=torch.float32) + all_bert_features_batch = torch.zeros(len(item_list), 1024, max_len, dtype=self.precison) for idx, item in enumerate(all_bert_features_list): all_bert_features_batch[idx, :, : item.shape[-1]] = item @@ -535,6 +583,11 @@ class TTS: ###### text preprocessing ######## data = self.text_preprocessor.preprocess(text, text_lang, text_split_method) + if len(data) == 0: + yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate * 0.3), + dtype=np.int16) + return + t1 = ttime() data, batch_index_list = self.to_batch(data, prompt_data=self.prompt_cache if not no_prompt_text else None, @@ -587,10 +640,8 @@ class TTS: t4 = ttime() t_34 += t4 - t3 - refer_audio_spepc:torch.Tensor = self.prompt_cache["refer_spepc"].to(self.configs.device) - if self.configs.is_half: - refer_audio_spepc = refer_audio_spepc.half() - + refer_audio_spepc:torch.Tensor = self.prompt_cache["refer_spepc"]\ + .to(dtype=self.precison, device=self.configs.device) batch_audio_fragment = [] @@ -672,7 +723,7 @@ class TTS: split_bucket:bool=True)->tuple[int, np.ndarray]: zero_wav = torch.zeros( int(self.configs.sampling_rate * 0.3), - dtype=torch.float16 if self.configs.is_half else torch.float32, + dtype=self.precison, device=self.configs.device ) diff --git a/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py b/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py index ee61bc30..58b2678c 100644 --- a/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py +++ b/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py @@ -59,6 +59,8 @@ class TextPreprocessor: print(i18n("############ 提取文本Bert特征 ############")) for text in tqdm(texts): phones, bert_features, norm_text = self.segment_and_extract_feature_for_text(text, lang) + if phones is None: + continue res={ "phones": phones, "bert_features": bert_features, @@ -79,12 +81,10 @@ class TextPreprocessor: while "\n\n" in text: text = text.replace("\n\n", "\n") - print(i18n("实际输入的目标文本(切句后):")) - print(text) + _texts = text.split("\n") _texts = merge_short_text_in_array(_texts, 5) texts = [] - for text in _texts: @@ -94,15 +94,21 @@ class TextPreprocessor: if (text[-1] not in splits): text += "。" if lang != "en" else "." # 解决句子过长导致Bert报错的问题 - texts.extend(split_big_text(text)) - + if (len(text) > 510): + texts.extend(split_big_text(text)) + else: + texts.append(text) + print(i18n("实际输入的目标文本(切句后):")) + print(texts) return texts def segment_and_extract_feature_for_text(self, texts:list, language:str)->Tuple[list, torch.Tensor, str]: textlist, langlist = self.seg_text(texts, language) - phones, bert_features, norm_text = self.extract_bert_feature(textlist, langlist) + if len(textlist) == 0: + return None, None, None + phones, bert_features, norm_text = self.extract_bert_feature(textlist, langlist) return phones, bert_features, norm_text @@ -113,6 +119,8 @@ class TextPreprocessor: if language in ["auto", "zh", "ja"]: LangSegment.setfilters(["zh","ja","en","ko"]) for tmp in LangSegment.getTexts(text): + if tmp["text"] == "": + continue if tmp["lang"] == "ko": langlist.append("zh") elif tmp["lang"] == "en": @@ -126,14 +134,18 @@ class TextPreprocessor: formattext = " ".join(tmp["text"] for tmp in LangSegment.getTexts(text)) while " " in formattext: formattext = formattext.replace(" ", " ") - textlist.append(formattext) - langlist.append("en") + if formattext != "": + textlist.append(formattext) + langlist.append("en") elif language in ["all_zh","all_ja"]: + formattext = text while " " in formattext: formattext = formattext.replace(" ", " ") language = language.replace("all_","") + if text == "": + return [],[] textlist.append(formattext) langlist.append(language) From 511b99e4a96aafae856ab0fa5bcb22dee5c55a6c Mon Sep 17 00:00:00 2001 From: chasonjiang <1440499136@qq.com> Date: Tue, 12 Mar 2024 15:30:08 +0800 Subject: [PATCH 4/6] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E4=BA=86TTS=5FConfig?= =?UTF-8?q?=E7=B1=BB=E7=9A=84=E5=81=A5=E5=A3=AE=E6=80=A7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- GPT_SoVITS/TTS_infer_pack/TTS.py | 94 ++++++++++++++++++++------------ GPT_SoVITS/inference_webui.py | 1 + 2 files changed, 61 insertions(+), 34 deletions(-) diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py index 7912ddf6..694d4a7d 100644 --- a/GPT_SoVITS/TTS_infer_pack/TTS.py +++ b/GPT_SoVITS/TTS_infer_pack/TTS.py @@ -1,3 +1,4 @@ +from copy import deepcopy import math import os, sys import random @@ -50,18 +51,7 @@ custom: class TTS_Config: - def __init__(self, configs: Union[dict, str]): - configs_base_path:str = "GPT_SoVITS/configs/" - os.makedirs(configs_base_path, exist_ok=True) - self.configs_path:str = os.path.join(configs_base_path, "tts_infer.yaml") - if isinstance(configs, str): - self.configs_path = configs - configs:dict = self._load_configs(configs) - - # assert isinstance(configs, dict) - self.default_configs:dict = configs.get("default", None) - if self.default_configs is None: - self.default_configs={ + default_configs={ "device": "cpu", "is_half": False, "t2s_weights_path": "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt", @@ -70,15 +60,54 @@ class TTS_Config: "bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large", "flash_attn_enabled": True } - self.configs:dict = configs.get("custom", self.default_configs) + configs:dict = None + def __init__(self, configs: Union[dict, str]=None): - self.device = self.configs.get("device") - self.is_half = self.configs.get("is_half") - self.t2s_weights_path = self.configs.get("t2s_weights_path") - self.vits_weights_path = self.configs.get("vits_weights_path") - self.bert_base_path = self.configs.get("bert_base_path") - self.cnhuhbert_base_path = self.configs.get("cnhuhbert_base_path") - self.flash_attn_enabled = self.configs.get("flash_attn_enabled") + # 设置默认配置文件路径 + configs_base_path:str = "GPT_SoVITS/configs/" + os.makedirs(configs_base_path, exist_ok=True) + self.configs_path:str = os.path.join(configs_base_path, "tts_infer.yaml") + + if configs in ["", None]: + if not os.path.exists(self.configs_path): + self.save_configs() + print(f"Create default config file at {self.configs_path}") + configs:dict = {"default": deepcopy(self.default_configs)} + + if isinstance(configs, str): + self.configs_path = configs + configs:dict = self._load_configs(self.configs_path) + + assert isinstance(configs, dict) + default_configs:dict = configs.get("default", None) + if default_configs is not None: + self.default_configs = default_configs + + self.configs:dict = configs.get("custom", deepcopy(self.default_configs)) + + + self.device = self.configs.get("device", torch.device("cpu")) + self.is_half = self.configs.get("is_half", False) + self.flash_attn_enabled = self.configs.get("flash_attn_enabled", True) + self.t2s_weights_path = self.configs.get("t2s_weights_path", None) + self.vits_weights_path = self.configs.get("vits_weights_path", None) + self.bert_base_path = self.configs.get("bert_base_path", None) + self.cnhuhbert_base_path = self.configs.get("cnhuhbert_base_path", None) + + + if (self.t2s_weights_path in [None, ""]) or (not os.path.exists(self.t2s_weights_path)): + self.t2s_weights_path = self.default_configs['t2s_weights_path'] + print(f"fall back to default t2s_weights_path: {self.t2s_weights_path}") + if (self.vits_weights_path in [None, ""]) or (not os.path.exists(self.vits_weights_path)): + self.vits_weights_path = self.default_configs['vits_weights_path'] + print(f"fall back to default vits_weights_path: {self.vits_weights_path}") + if (self.bert_base_path in [None, ""]) or (not os.path.exists(self.bert_base_path)): + self.bert_base_path = self.default_configs['bert_base_path'] + print(f"fall back to default bert_base_path: {self.bert_base_path}") + if (self.cnhuhbert_base_path in [None, ""]) or (not os.path.exists(self.cnhuhbert_base_path)): + self.cnhuhbert_base_path = self.default_configs['cnhuhbert_base_path'] + print(f"fall back to default cnhuhbert_base_path: {self.cnhuhbert_base_path}") + self.update_configs() self.max_sec = None @@ -92,7 +121,7 @@ class TTS_Config: self.n_speakers:int = 300 self.langauges:list = ["auto", "en", "zh", "ja", "all_zh", "all_ja"] - print(self) + # print(self) def _load_configs(self, configs_path: str)->dict: with open(configs_path, 'r') as f: @@ -102,24 +131,18 @@ class TTS_Config: def save_configs(self, configs_path:str=None)->None: configs={ - "default": { - "device": "cpu", - "is_half": False, - "t2s_weights_path": "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt", - "vits_weights_path": "GPT_SoVITS/pretrained_models/s2G488k.pth", - "cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base", - "bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large", - "flash_attn_enabled": True - }, - "custom": self.update_configs() + "default":self.default_configs, } + if self.configs is not None: + configs["custom"] = self.update_configs() + if configs_path is None: configs_path = self.configs_path with open(configs_path, 'w') as f: yaml.dump(configs, f) def update_configs(self): - config = { + self.config = { "device" : str(self.device), "is_half" : self.is_half, "t2s_weights_path" : self.t2s_weights_path, @@ -128,7 +151,7 @@ class TTS_Config: "cnhuhbert_base_path": self.cnhuhbert_base_path, "flash_attn_enabled" : self.flash_attn_enabled } - return config + return self.config def __str__(self): self.configs = self.update_configs() @@ -137,6 +160,9 @@ class TTS_Config: string += f"{str(k).ljust(20)}: {str(v)}\n" string += "-" * 100 + '\n' return string + + def __repr__(self): + return self.__str__() class TTS: @@ -253,7 +279,7 @@ class TTS: enable: bool, whether to enable half precision. ''' - if self.configs.device == "cpu": + if self.configs.device == "cpu" and enable: print("Half precision is not supported on CPU.") return diff --git a/GPT_SoVITS/inference_webui.py b/GPT_SoVITS/inference_webui.py index 2308b389..bc680315 100644 --- a/GPT_SoVITS/inference_webui.py +++ b/GPT_SoVITS/inference_webui.py @@ -80,6 +80,7 @@ if cnhubert_base_path is not None: if bert_path is not None: tts_config.bert_base_path = bert_path +print(tts_config) tts_pipline = TTS(tts_config) gpt_path = tts_config.t2s_weights_path sovits_path = tts_config.vits_weights_path From 345f3203f84d6017151f1075bed0e917ac784130 Mon Sep 17 00:00:00 2001 From: chasonjiang <1440499136@qq.com> Date: Tue, 12 Mar 2024 16:08:50 +0800 Subject: [PATCH 5/6] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E4=BA=86=E7=83=AD?= =?UTF-8?q?=E5=88=87=E6=8D=A2=E6=A8=A1=E5=9E=8B=E6=97=B6=EF=BC=8C=E7=B2=BE?= =?UTF-8?q?=E5=BA=A6=E4=B8=8D=E5=8C=B9=E9=85=8D=E5=AF=BC=E8=87=B4=E7=9A=84?= =?UTF-8?q?=E9=94=99=E8=AF=AF=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- GPT_SoVITS/TTS_infer_pack/TTS.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py index 694d4a7d..61ba7be1 100644 --- a/GPT_SoVITS/TTS_infer_pack/TTS.py +++ b/GPT_SoVITS/TTS_infer_pack/TTS.py @@ -206,7 +206,7 @@ class TTS: self.init_vits_weights(self.configs.vits_weights_path) self.init_bert_weights(self.configs.bert_base_path) self.init_cnhuhbert_weights(self.configs.cnhuhbert_base_path) - self.enable_half_precision(self.configs.is_half) + # self.enable_half_precision(self.configs.is_half) @@ -215,6 +215,8 @@ class TTS: self.cnhuhbert_model = CNHubert(base_path) self.cnhuhbert_model=self.cnhuhbert_model.eval() self.cnhuhbert_model = self.cnhuhbert_model.to(self.configs.device) + if self.configs.is_half: + self.cnhuhbert_model = self.cnhuhbert_model.half() @@ -224,6 +226,8 @@ class TTS: self.bert_model = AutoModelForMaskedLM.from_pretrained(base_path) self.bert_model=self.bert_model.eval() self.bert_model = self.bert_model.to(self.configs.device) + if self.configs.is_half: + self.bert_model = self.bert_model.half() @@ -255,6 +259,8 @@ class TTS: vits_model = vits_model.eval() vits_model.load_state_dict(dict_s2["weight"], strict=False) self.vits_model = vits_model + if self.configs.is_half: + self.vits_model = self.vits_model.half() def init_t2s_weights(self, weights_path: str): @@ -271,6 +277,8 @@ class TTS: t2s_model = t2s_model.to(self.configs.device) t2s_model = t2s_model.eval() self.t2s_model = t2s_model + if self.configs.is_half: + self.t2s_model = self.t2s_model.half() def enable_half_precision(self, enable: bool = True): ''' From f2cbc826c7b8767e2dfec9399f3350c094ce3c34 Mon Sep 17 00:00:00 2001 From: chasonjiang <1440499136@qq.com> Date: Tue, 12 Mar 2024 16:31:21 +0800 Subject: [PATCH 6/6] =?UTF-8?q?=E6=B3=A8=E9=87=8A=E4=BA=86inference=5Fwebu?= =?UTF-8?q?i=E4=B8=ADmps=E7=9A=84=E9=80=89=E9=A1=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- GPT_SoVITS/inference_webui.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/GPT_SoVITS/inference_webui.py b/GPT_SoVITS/inference_webui.py index bc680315..4c9d90f2 100644 --- a/GPT_SoVITS/inference_webui.py +++ b/GPT_SoVITS/inference_webui.py @@ -45,8 +45,8 @@ os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1' # 确保直接启动推理UI时 if torch.cuda.is_available(): device = "cuda" -elif torch.backends.mps.is_available(): - device = "mps" +# elif torch.backends.mps.is_available(): +# device = "mps" else: device = "cpu"