diff --git a/.dockerignore b/.dockerignore
index dc39f76f..4eca27be 100644
--- a/.dockerignore
+++ b/.dockerignore
@@ -3,4 +3,6 @@ logs
output
reference
SoVITS_weights
-.git
\ No newline at end of file
+GPT_weights
+TEMP
+.git
diff --git a/.gitignore b/.gitignore
index 00f6bb99..96e754a9 100644
--- a/.gitignore
+++ b/.gitignore
@@ -7,5 +7,8 @@ runtime
output
logs
reference
+GPT_weights
SoVITS_weights
-GPT_weights
\ No newline at end of file
+TEMP
+
+
diff --git a/GPT_SoVITS/AR/data/bucket_sampler.py b/GPT_SoVITS/AR/data/bucket_sampler.py
index 647491f7..45f91d8e 100644
--- a/GPT_SoVITS/AR/data/bucket_sampler.py
+++ b/GPT_SoVITS/AR/data/bucket_sampler.py
@@ -1,4 +1,5 @@
-# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/bucketsampler.py
+# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/data/bucket_sampler.py
+# reference: https://github.com/lifeiteng/vall-e
import itertools
import math
import random
diff --git a/GPT_SoVITS/AR/data/data_module.py b/GPT_SoVITS/AR/data/data_module.py
index 037484a9..cb947959 100644
--- a/GPT_SoVITS/AR/data/data_module.py
+++ b/GPT_SoVITS/AR/data/data_module.py
@@ -1,4 +1,5 @@
-# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/data_module.py
+# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/data/data_module.py
+# reference: https://github.com/lifeiteng/vall-e
from pytorch_lightning import LightningDataModule
from AR.data.bucket_sampler import DistributedBucketSampler
from AR.data.dataset import Text2SemanticDataset
@@ -41,7 +42,8 @@ class Text2SemanticDataModule(LightningDataModule):
# pad_val=self.config['data']['pad_val'])
def train_dataloader(self):
- batch_size = max(min(self.config["train"]["batch_size"],len(self._train_dataset)//4),1)#防止不保存
+ batch_size = self.config["train"]["batch_size"] // 2 if self.config["train"].get("if_dpo", False) == True else self.config["train"]["batch_size"]
+ batch_size = max(min(batch_size, len(self._train_dataset) // 4), 1)  # prevent the case where nothing ever gets saved
sampler = DistributedBucketSampler(self._train_dataset, batch_size=batch_size)
return DataLoader(
self._train_dataset,
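For context, here is a minimal sketch of the batch-size logic this hunk introduces (a hypothetical helper, not part of the repository): when the optional `if_dpo` flag is set in the training config the effective batch size is halved, presumably because DPO tracks a rejected sample alongside each accepted one, and the result is still clamped so that at least one batch exists even on tiny datasets.

```python
def effective_batch_size(config: dict, dataset_len: int) -> int:
    # Illustrative helper mirroring the hunk above; `config` follows the same layout.
    train_cfg = config["train"]
    batch_size = train_cfg["batch_size"]
    if train_cfg.get("if_dpo", False):
        batch_size //= 2  # DPO roughly doubles the per-sample footprint
    # never exceed a quarter of the dataset, never drop below one batch
    return max(min(batch_size, dataset_len // 4), 1)
```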
diff --git a/GPT_SoVITS/AR/data/dataset.py b/GPT_SoVITS/AR/data/dataset.py
index b1ea69e6..1a2ffef1 100644
--- a/GPT_SoVITS/AR/data/dataset.py
+++ b/GPT_SoVITS/AR/data/dataset.py
@@ -1,4 +1,5 @@
-# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/t2s_dataset.py
+# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/data/dataset.py
+# reference: https://github.com/lifeiteng/vall-e
import pdb
import sys
diff --git a/GPT_SoVITS/AR/models/t2s_lightning_module.py b/GPT_SoVITS/AR/models/t2s_lightning_module.py
index f9dfc648..2dd3f392 100644
--- a/GPT_SoVITS/AR/models/t2s_lightning_module.py
+++ b/GPT_SoVITS/AR/models/t2s_lightning_module.py
@@ -1,4 +1,5 @@
-# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/t2s_lightning_module.py
+# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_lightning_module.py
+# reference: https://github.com/lifeiteng/vall-e
import os, sys
now_dir = os.getcwd()
@@ -11,7 +12,6 @@ from AR.models.t2s_model import Text2SemanticDecoder
from AR.modules.lr_schedulers import WarmupCosineLRSchedule
from AR.modules.optim import ScaledAdam
-
class Text2SemanticLightningModule(LightningModule):
def __init__(self, config, output_dir, is_train=True):
super().__init__()
@@ -35,7 +35,8 @@ class Text2SemanticLightningModule(LightningModule):
def training_step(self, batch: Dict, batch_idx: int):
opt = self.optimizers()
scheduler = self.lr_schedulers()
- loss, acc = self.model.forward(
+ forward = self.model.forward if self.config["train"].get("if_dpo", False) == True else self.model.forward_old
+ loss, acc = forward(
batch["phoneme_ids"],
batch["phoneme_ids_len"],
batch["semantic_ids"],
diff --git a/GPT_SoVITS/AR/models/t2s_lightning_module_onnx.py b/GPT_SoVITS/AR/models/t2s_lightning_module_onnx.py
index bb9e30b9..487edb01 100644
--- a/GPT_SoVITS/AR/models/t2s_lightning_module_onnx.py
+++ b/GPT_SoVITS/AR/models/t2s_lightning_module_onnx.py
@@ -1,4 +1,5 @@
-# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/t2s_lightning_module.py
+# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_lightning_module.py
+# reference: https://github.com/lifeiteng/vall-e
import os, sys
now_dir = os.getcwd()
diff --git a/GPT_SoVITS/AR/models/t2s_model.py b/GPT_SoVITS/AR/models/t2s_model.py
index aaeace98..c8ad3d82 100644
--- a/GPT_SoVITS/AR/models/t2s_model.py
+++ b/GPT_SoVITS/AR/models/t2s_model.py
@@ -1,4 +1,5 @@
-# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/t2s_model.py
+# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_model.py
+# reference: https://github.com/lifeiteng/vall-e
import torch
from tqdm import tqdm
@@ -337,7 +338,7 @@ class Text2SemanticDecoder(nn.Module):
# AR Decoder
y = prompts
- prefix_len = y.shape[1]
+
x_len = x.shape[1]
x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
stop = False
@@ -353,47 +354,41 @@ class Text2SemanticDecoder(nn.Module):
"first_infer": 1,
"stage": 0,
}
- for idx in tqdm(range(1500)):
- if cache["first_infer"] == 1:
- y_emb = self.ar_audio_embedding(y)
- else:
- y_emb = torch.cat(
- [cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], 1
- )
- cache["y_emb"] = y_emb
+ ################### first step ##########################
+ if y is not None:
+ y_emb = self.ar_audio_embedding(y)
+ y_len = y_emb.shape[1]
+ prefix_len = y.shape[1]
y_pos = self.ar_audio_position(y_emb)
- # x 和逐渐增长的 y 一起输入给模型
- if cache["first_infer"] == 1:
- xy_pos = torch.concat([x, y_pos], dim=1)
- else:
- xy_pos = y_pos[:, -1:]
- y_len = y_pos.shape[1]
- ###以下3个不做缓存
- if cache["first_infer"] == 1:
- x_attn_mask_pad = F.pad(
+ xy_pos = torch.concat([x, y_pos], dim=1)
+ cache["y_emb"] = y_emb
+ ref_free = False
+ else:
+ y_emb = None
+ y_len = 0
+ prefix_len = 0
+ y_pos = None
+ xy_pos = x
+ y = torch.zeros(x.shape[0], 0, dtype=torch.int, device=x.device)
+ ref_free = True
+
+ x_attn_mask_pad = F.pad(
x_attn_mask,
(0, y_len), ###xx的纯0扩展到xx纯0+xy纯1,(x,x+y)
value=True,
)
- y_attn_mask = F.pad( ###yy的右上1扩展到左边xy的0,(y,x+y)
- torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
- (x_len, 0),
- value=False,
- )
- xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(
- y.device
- )
- else:
- ###最右边一列(是错的)
- # xy_attn_mask=torch.ones((1, x_len+y_len), dtype=torch.bool,device=xy_pos.device)
- # xy_attn_mask[:,-1]=False
- ###最下面一行(是对的)
- xy_attn_mask = torch.zeros(
- (1, x_len + y_len), dtype=torch.bool, device=xy_pos.device
- )
- # pdb.set_trace()
- ###缓存重头戏
- # print(1111,xy_pos.shape,xy_attn_mask.shape,x_len,y_len)
+ y_attn_mask = F.pad(  ### yy's upper-right 1s extended leftward with xy's 0s, (y, x+y)
+ torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
+ (x_len, 0),
+ value=False,
+ )
+ xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(
+ x.device
+ )
+
+
+ for idx in tqdm(range(1500)):
+
xy_dec, _ = self.h((xy_pos, None), mask=xy_attn_mask, cache=cache)
logits = self.ar_predict_layer(
xy_dec[:, -1]
@@ -404,6 +399,10 @@ class Text2SemanticDecoder(nn.Module):
samples = sample(
logits[0], y, top_k=top_k, top_p=top_p, repetition_penalty=1.35, temperature=temperature
)[0].unsqueeze(0)
+ # the semantic_ids generated this step are appended to the previous y to form the new y
+ # print(samples.shape)#[1,1]# the first 1 is the batch size
+ y = torch.concat([y, samples], dim=1)
+
if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
print("use early stop num:", early_stop_num)
stop = True
@@ -412,13 +411,38 @@ class Text2SemanticDecoder(nn.Module):
# print(torch.argmax(logits, dim=-1)[0] == self.EOS, samples[0, 0] == self.EOS)
stop = True
if stop:
- if prompts.shape[1] == y.shape[1]:
+ # if prompts.shape[1] == y.shape[1]:
+ # y = torch.concat([y, torch.zeros_like(samples)], dim=1)
+ # print("bad zero prediction")
+ if y.shape[1]==0:
y = torch.concat([y, torch.zeros_like(samples)], dim=1)
print("bad zero prediction")
print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
break
- # 本次生成的 semantic_ids 和之前的 y 构成新的 y
- # print(samples.shape)#[1,1]#第一个1是bs
- y = torch.concat([y, samples], dim=1)
+
+ ####################### update next step ###################################
cache["first_infer"] = 0
- return y, idx
+ if cache["y_emb"] is not None:
+ y_emb = torch.cat(
+ [cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], dim = 1
+ )
+ cache["y_emb"] = y_emb
+ y_pos = self.ar_audio_position(y_emb)
+ xy_pos = y_pos[:, -1:]
+ else:
+ y_emb = self.ar_audio_embedding(y[:, -1:])
+ cache["y_emb"] = y_emb
+ y_pos = self.ar_audio_position(y_emb)
+ xy_pos = y_pos
+ y_len = y_pos.shape[1]
+
+ ### rightmost column (this is wrong)
+ # xy_attn_mask=torch.ones((1, x_len+y_len), dtype=torch.bool,device=xy_pos.device)
+ # xy_attn_mask[:,-1]=False
+ ### bottom row (this is correct)
+ xy_attn_mask = torch.zeros(
+ (1, x_len + y_len), dtype=torch.bool, device=xy_pos.device
+ )
+ if ref_free:
+ return y[:, :-1], 0
+ return y[:, :-1], idx-1
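To make the restructured `infer_panel` easier to follow, here is a stripped-down sketch of its new control flow; the names `embed`, `pos`, `step`, and `sample_token` are placeholders for the model's own modules, not real APIs. The first forward input is built once before the loop, passing no prompt switches to reference-free mode, and each iteration appends the sampled token, grows the cached audio embedding, and feeds only the newest position back in.

```python
import torch

def ar_decode_sketch(x, prompt, embed, pos, step, sample_token, eos_id, max_steps=1500):
    """Illustrative sketch of the reworked loop above, not the project's API."""
    ref_free = prompt is None
    if ref_free:
        y = torch.zeros(x.shape[0], 0, dtype=torch.long, device=x.device)
        y_emb = None
        xy_pos = x                                  # first pass sees only the text features
    else:
        y = prompt
        y_emb = embed(y)
        xy_pos = torch.cat([x, pos(y_emb)], dim=1)  # text features + positioned prompt

    idx = 0
    for idx in range(max_steps):
        logits = step(xy_pos)                # stand-in for self.h(...) plus the predict layer
        token = sample_token(logits, y)      # e.g. top-k/top-p with repetition penalty
        y = torch.cat([y, token], dim=1)     # append before the stop check, as in the patch
        if token.item() == eos_id:
            break
        # incremental step: embed only the newest token, reapply positions, keep the last slot
        new_emb = embed(y[:, -1:])
        y_emb = new_emb if y_emb is None else torch.cat([y_emb, new_emb], dim=1)
        xy_pos = pos(y_emb)[:, -1:]

    return y[:, :-1], (0 if ref_free else idx - 1)  # the patch strips the final (EOS) token
```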
diff --git a/GPT_SoVITS/AR/models/t2s_model_onnx.py b/GPT_SoVITS/AR/models/t2s_model_onnx.py
index 92f2d745..7834297d 100644
--- a/GPT_SoVITS/AR/models/t2s_model_onnx.py
+++ b/GPT_SoVITS/AR/models/t2s_model_onnx.py
@@ -1,4 +1,5 @@
-# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/t2s_model.py
+# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_model.py
+# reference: https://github.com/lifeiteng/vall-e
import torch
from tqdm import tqdm
diff --git a/GPT_SoVITS/AR/models/utils.py b/GPT_SoVITS/AR/models/utils.py
index bc5f2d0f..9678c7e1 100644
--- a/GPT_SoVITS/AR/models/utils.py
+++ b/GPT_SoVITS/AR/models/utils.py
@@ -1,4 +1,5 @@
-# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/utils.py\
+# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/utils.py
+# reference: https://github.com/lifeiteng/vall-e
import torch
import torch.nn.functional as F
from typing import Tuple
@@ -114,7 +115,8 @@ def logits_to_probs(
top_p: Optional[int] = None,
repetition_penalty: float = 1.0,
):
- previous_tokens = previous_tokens.squeeze()
+ if previous_tokens is not None:
+ previous_tokens = previous_tokens.squeeze()
# print(logits.shape,previous_tokens.shape)
# pdb.set_trace()
if previous_tokens is not None and repetition_penalty != 1.0:
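The guard above lets `previous_tokens` be `None` instead of assuming a tensor is always passed. As a hedged illustration of the kind of repetition-penalty logic being protected (a sketch, not necessarily the exact body of `logits_to_probs`):

```python
import torch

def penalize_repeats(logits, previous_tokens=None, repetition_penalty=1.35):
    # Only apply the penalty when a token history actually exists.
    if previous_tokens is not None and repetition_penalty != 1.0:
        prev = previous_tokens.reshape(-1).long()
        score = torch.gather(logits, dim=0, index=prev)
        # push already-emitted tokens away from being chosen again
        score = torch.where(score < 0, score * repetition_penalty, score / repetition_penalty)
        logits = logits.scatter(dim=0, index=prev, src=score)
    return logits
```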
diff --git a/GPT_SoVITS/AR/modules/lr_schedulers.py b/GPT_SoVITS/AR/modules/lr_schedulers.py
index 7dec462b..b8867467 100644
--- a/GPT_SoVITS/AR/modules/lr_schedulers.py
+++ b/GPT_SoVITS/AR/modules/lr_schedulers.py
@@ -1,4 +1,5 @@
-# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/lr_schedulers.py
+# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/modules/lr_schedulers.py
+# reference: https://github.com/lifeiteng/vall-e
import math
import torch
diff --git a/GPT_SoVITS/AR/modules/patched_mha_with_cache.py b/GPT_SoVITS/AR/modules/patched_mha_with_cache.py
index 57206703..7be241da 100644
--- a/GPT_SoVITS/AR/modules/patched_mha_with_cache.py
+++ b/GPT_SoVITS/AR/modules/patched_mha_with_cache.py
@@ -5,8 +5,8 @@ from torch.nn.functional import (
_none_or_dtype,
_in_projection_packed,
)
-
-# import torch
+from torch.nn import functional as F
+import torch
# Tensor = torch.Tensor
# from typing import Callable, List, Optional, Tuple, Union
@@ -448,9 +448,11 @@ def multi_head_attention_forward_patched(
k = k.view(bsz, num_heads, src_len, head_dim)
v = v.view(bsz, num_heads, src_len, head_dim)
+ # with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
attn_output = scaled_dot_product_attention(
q, k, v, attn_mask, dropout_p, is_causal
)
+
attn_output = (
attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
)
diff --git a/GPT_SoVITS/AR/text_processing/phonemizer.py b/GPT_SoVITS/AR/text_processing/phonemizer.py
index 9fcf5c09..9c5f58fb 100644
--- a/GPT_SoVITS/AR/text_processing/phonemizer.py
+++ b/GPT_SoVITS/AR/text_processing/phonemizer.py
@@ -1,4 +1,5 @@
-# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/text_processing/phonemizer.py
+# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/text_processing/phonemizer.py
+# reference: https://github.com/lifeiteng/vall-e
import itertools
import re
from typing import Dict
diff --git a/GPT_SoVITS/AR/text_processing/symbols.py b/GPT_SoVITS/AR/text_processing/symbols.py
index c57e2d41..7d754a78 100644
--- a/GPT_SoVITS/AR/text_processing/symbols.py
+++ b/GPT_SoVITS/AR/text_processing/symbols.py
@@ -1,4 +1,5 @@
-# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/text_processing/symbols.py
+# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/text_processing/symbols.py
+# reference: https://github.com/lifeiteng/vall-e
PAD = "_"
PUNCTUATION = ';:,.!?¡¿—…"«»“” '
LETTERS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
diff --git a/GPT_SoVITS/inference_webui.py b/GPT_SoVITS/inference_webui.py
index 9c5197a7..ee099627 100644
--- a/GPT_SoVITS/inference_webui.py
+++ b/GPT_SoVITS/inference_webui.py
@@ -16,6 +16,7 @@ logging.getLogger("asyncio").setLevel(logging.ERROR)
logging.getLogger("charset_normalizer").setLevel(logging.ERROR)
logging.getLogger("torchaudio._extension").setLevel(logging.ERROR)
import pdb
+import torch
if os.path.exists("./gweight.txt"):
with open("./gweight.txt", 'r', encoding="utf-8") as file:
@@ -48,11 +49,11 @@ is_share = os.environ.get("is_share", "False")
is_share = eval(is_share)
if "_CUDA_VISIBLE_DEVICES" in os.environ:
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
-is_half = eval(os.environ.get("is_half", "True"))
+is_half = eval(os.environ.get("is_half", "True")) and not torch.backends.mps.is_available()
import gradio as gr
from transformers import AutoModelForMaskedLM, AutoTokenizer
import numpy as np
-import librosa, torch
+import librosa
from feature_extractor import cnhubert
cnhubert.cnhubert_base_path = cnhubert_base_path
@@ -72,8 +73,6 @@ os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1' # 确保直接启动推理UI时
if torch.cuda.is_available():
device = "cuda"
-elif torch.backends.mps.is_available():
- device = "mps"
else:
device = "cpu"
@@ -209,50 +208,8 @@ dict_language = {
}
-def splite_en_inf(sentence, language):
- pattern = re.compile(r'[a-zA-Z ]+')
- textlist = []
- langlist = []
- pos = 0
- for match in pattern.finditer(sentence):
- start, end = match.span()
- if start > pos:
- textlist.append(sentence[pos:start])
- langlist.append(language)
- textlist.append(sentence[start:end])
- langlist.append("en")
- pos = end
- if pos < len(sentence):
- textlist.append(sentence[pos:])
- langlist.append(language)
- # Merge punctuation into previous word
- for i in range(len(textlist)-1, 0, -1):
- if re.match(r'^[\W_]+$', textlist[i]):
- textlist[i-1] += textlist[i]
- del textlist[i]
- del langlist[i]
- # Merge consecutive words with the same language tag
- i = 0
- while i < len(langlist) - 1:
- if langlist[i] == langlist[i+1]:
- textlist[i] += textlist[i+1]
- del textlist[i+1]
- del langlist[i+1]
- else:
- i += 1
-
- return textlist, langlist
-
-
def clean_text_inf(text, language):
- formattext = ""
- language = language.replace("all_","")
- for tmp in LangSegment.getTexts(text):
- if tmp["lang"] == language:
- formattext += tmp["text"] + " "
- while " " in formattext:
- formattext = formattext.replace(" ", " ")
- phones, word2ph, norm_text = clean_text(formattext, language)
+ phones, word2ph, norm_text = clean_text(text, language)
phones = cleaned_text_to_sequence(phones)
return phones, word2ph, norm_text
@@ -270,57 +227,6 @@ def get_bert_inf(phones, word2ph, norm_text, language):
return bert
-def nonen_clean_text_inf(text, language):
- if(language!="auto"):
- textlist, langlist = splite_en_inf(text, language)
- else:
- textlist=[]
- langlist=[]
- for tmp in LangSegment.getTexts(text):
- langlist.append(tmp["lang"])
- textlist.append(tmp["text"])
- print(textlist)
- print(langlist)
- phones_list = []
- word2ph_list = []
- norm_text_list = []
- for i in range(len(textlist)):
- lang = langlist[i]
- phones, word2ph, norm_text = clean_text_inf(textlist[i], lang)
- phones_list.append(phones)
- if lang == "zh":
- word2ph_list.append(word2ph)
- norm_text_list.append(norm_text)
- print(word2ph_list)
- phones = sum(phones_list, [])
- word2ph = sum(word2ph_list, [])
- norm_text = ' '.join(norm_text_list)
-
- return phones, word2ph, norm_text
-
-
-def nonen_get_bert_inf(text, language):
- if(language!="auto"):
- textlist, langlist = splite_en_inf(text, language)
- else:
- textlist=[]
- langlist=[]
- for tmp in LangSegment.getTexts(text):
- langlist.append(tmp["lang"])
- textlist.append(tmp["text"])
- print(textlist)
- print(langlist)
- bert_list = []
- for i in range(len(textlist)):
- lang = langlist[i]
- phones, word2ph, norm_text = clean_text_inf(textlist[i], lang)
- bert = get_bert_inf(phones, word2ph, norm_text, lang)
- bert_list.append(bert)
- bert = torch.cat(bert_list, dim=1)
-
- return bert
-
-
splits = {",", "。", "?", "!", ",", ".", "?", "!", "~", ":", ":", "—", "…", }
@@ -330,23 +236,63 @@ def get_first(text):
return text
-def get_cleaned_text_final(text,language):
+def get_phones_and_bert(text,language):
if language in {"en","all_zh","all_ja"}:
- phones, word2ph, norm_text = clean_text_inf(text, language)
+ language = language.replace("all_","")
+ if language == "en":
+ LangSegment.setfilters(["en"])
+ formattext = " ".join(tmp["text"] for tmp in LangSegment.getTexts(text))
+ else:
+ # Chinese and Japanese kanji cannot be distinguished, so defer to the user's language choice
+ formattext = text
+ while " " in formattext:
+ formattext = formattext.replace(" ", " ")
+ phones, word2ph, norm_text = clean_text_inf(formattext, language)
+ if language == "zh":
+ bert = get_bert_feature(norm_text, word2ph).to(device)
+ else:
+ bert = torch.zeros(
+ (1024, len(phones)),
+ dtype=torch.float16 if is_half == True else torch.float32,
+ ).to(device)
elif language in {"zh", "ja","auto"}:
- phones, word2ph, norm_text = nonen_clean_text_inf(text, language)
- return phones, word2ph, norm_text
+ textlist=[]
+ langlist=[]
+ LangSegment.setfilters(["zh","ja","en","ko"])
+ if language == "auto":
+ for tmp in LangSegment.getTexts(text):
+ if tmp["lang"] == "ko":
+ langlist.append("zh")
+ textlist.append(tmp["text"])
+ else:
+ langlist.append(tmp["lang"])
+ textlist.append(tmp["text"])
+ else:
+ for tmp in LangSegment.getTexts(text):
+ if tmp["lang"] == "en":
+ langlist.append(tmp["lang"])
+ else:
+ # Chinese and Japanese kanji cannot be distinguished, so defer to the user's language choice
+ langlist.append(language)
+ textlist.append(tmp["text"])
+ print(textlist)
+ print(langlist)
+ phones_list = []
+ bert_list = []
+ norm_text_list = []
+ for i in range(len(textlist)):
+ lang = langlist[i]
+ phones, word2ph, norm_text = clean_text_inf(textlist[i], lang)
+ bert = get_bert_inf(phones, word2ph, norm_text, lang)
+ phones_list.append(phones)
+ norm_text_list.append(norm_text)
+ bert_list.append(bert)
+ bert = torch.cat(bert_list, dim=1)
+ phones = sum(phones_list, [])
+ norm_text = ''.join(norm_text_list)
+
+ return phones,bert.to(dtype),norm_text
-def get_bert_final(phones, word2ph, text,language,device):
- if language == "en":
- bert = get_bert_inf(phones, word2ph, text, language)
- elif language in {"zh", "ja","auto"}:
- bert = nonen_get_bert_inf(text, language)
- elif language == "all_zh":
- bert = get_bert_feature(text, word2ph).to(device)
- else:
- bert = torch.zeros((1024, len(phones))).to(device)
- return bert
def merge_short_text_in_array(texts, threshold):
if (len(texts)) < 2:
@@ -365,15 +311,19 @@ def merge_short_text_in_array(texts, threshold):
result[len(result) - 1] += text
return result
-def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, how_to_cut=i18n("不切"), top_k=20, top_p=0.6, temperature=0.6):
+def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, how_to_cut=i18n("不切"), top_k=20, top_p=0.6, temperature=0.6, ref_free = False):
+ if prompt_text is None or len(prompt_text) == 0:
+ ref_free = True
t0 = ttime()
prompt_language = dict_language[prompt_language]
text_language = dict_language[text_language]
- prompt_text = prompt_text.strip("\n")
- if (prompt_text[-1] not in splits): prompt_text += "。" if prompt_language != "en" else "."
+ if not ref_free:
+ prompt_text = prompt_text.strip("\n")
+ if (prompt_text[-1] not in splits): prompt_text += "。" if prompt_language != "en" else "."
+ print(i18n("实际输入的参考文本:"), prompt_text)
text = text.strip("\n")
if (text[0] not in splits and len(get_first(text)) < 4): text = "。" + text if text_language != "en" else "." + text
- print(i18n("实际输入的参考文本:"), prompt_text)
+
print(i18n("实际输入的目标文本:"), text)
zero_wav = np.zeros(
int(hps.data.sampling_rate * 0.3),
@@ -398,11 +348,10 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
1, 2
) # .float()
codes = vq_model.extract_latent(ssl_content)
+
prompt_semantic = codes[0, 0]
t1 = ttime()
- phones1, word2ph1, norm_text1=get_cleaned_text_final(prompt_text, prompt_language)
-
if (how_to_cut == i18n("凑四句一切")):
text = cut1(text)
elif (how_to_cut == i18n("凑50字一切")):
@@ -419,7 +368,8 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
texts = text.split("\n")
texts = merge_short_text_in_array(texts, 5)
audio_opt = []
- bert1=get_bert_final(phones1, word2ph1, norm_text1,prompt_language,device).to(dtype)
+ if not ref_free:
+ phones1,bert1,norm_text1=get_phones_and_bert(prompt_text, prompt_language)
for text in texts:
# 解决输入目标文本的空行导致报错的问题
@@ -427,11 +377,15 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
continue
if (text[-1] not in splits): text += "。" if text_language != "en" else "."
print(i18n("实际输入的目标文本(每句):"), text)
- phones2, word2ph2, norm_text2 = get_cleaned_text_final(text, text_language)
- bert2 = get_bert_final(phones2, word2ph2, norm_text2, text_language, device).to(dtype)
- bert = torch.cat([bert1, bert2], 1)
+ phones2,bert2,norm_text2=get_phones_and_bert(text, text_language)
+ print(i18n("前端处理后的文本(每句):"), norm_text2)
+ if not ref_free:
+ bert = torch.cat([bert1, bert2], 1)
+ all_phoneme_ids = torch.LongTensor(phones1+phones2).to(device).unsqueeze(0)
+ else:
+ bert = bert2
+ all_phoneme_ids = torch.LongTensor(phones2).to(device).unsqueeze(0)
- all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0)
bert = bert.to(device).unsqueeze(0)
all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
prompt = prompt_semantic.unsqueeze(0).to(device)
@@ -441,7 +395,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
pred_semantic, idx = t2s_model.model.infer_panel(
all_phoneme_ids,
all_phoneme_len,
- prompt,
+ None if ref_free else prompt,
bert,
# prompt_phone_len=ph_offset,
top_k=top_k,
@@ -551,10 +505,13 @@ def cut5(inp):
# if not re.search(r'[^\w\s]', inp[-1]):
# inp += '。'
inp = inp.strip("\n")
- punds = r'[,.;?!、,。?!;:]'
+ punds = r'[,.;?!、,。?!;:…]'
items = re.split(f'({punds})', inp)
- items = ["".join(group) for group in zip(items[::2], items[1::2])]
- opt = "\n".join(items)
+ mergeitems = ["".join(group) for group in zip(items[::2], items[1::2])]
+ # keep the text intact when the sentence has no punctuation or does not end with one
+ if len(items)%2 == 1:
+ mergeitems.append(items[-1])
+ opt = "\n".join(mergeitems)
return opt
@@ -607,11 +564,14 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
gr.Markdown(value=i18n("*请上传并填写参考信息"))
with gr.Row():
inp_ref = gr.Audio(label=i18n("请上传3~10秒内参考音频,超过会报错!"), type="filepath")
- prompt_text = gr.Textbox(label=i18n("参考音频的文本"), value="")
+ with gr.Column():
+ ref_text_free = gr.Checkbox(label=i18n("开启无参考文本模式。不填参考文本亦相当于开启。"), value=False, interactive=True, show_label=True)
+ gr.Markdown(i18n("使用无参考文本模式时建议使用微调的GPT,听不清参考音频说的啥(不晓得写啥)可以开,开启后无视填写的参考文本。"))
+ prompt_text = gr.Textbox(label=i18n("参考音频的文本"), value="")
prompt_language = gr.Dropdown(
label=i18n("参考音频的语种"), choices=[i18n("中文"), i18n("英文"), i18n("日文"), i18n("中英混合"), i18n("日英混合"), i18n("多语种混合")], value=i18n("中文")
)
- gr.Markdown(value=i18n("*请填写需要合成的目标文本。中英混合选中文,日英混合选日文,中日混合暂不支持,非目标语言文本自动遗弃。"))
+ gr.Markdown(value=i18n("*请填写需要合成的目标文本和语种模式"))
with gr.Row():
text = gr.Textbox(label=i18n("需要合成的文本"), value="")
text_language = gr.Dropdown(
@@ -624,6 +584,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
interactive=True,
)
with gr.Row():
+ gr.Markdown(value=i18n("gpt采样参数(无参考文本时不要太低):"))
top_k = gr.Slider(minimum=1,maximum=100,step=1,label=i18n("top_k"),value=5,interactive=True)
top_p = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("top_p"),value=1,interactive=True)
temperature = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("temperature"),value=1,interactive=True)
@@ -632,7 +593,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
inference_button.click(
get_tts_wav,
- [inp_ref, prompt_text, prompt_language, text, text_language, how_to_cut,top_k,top_p,temperature],
+ [inp_ref, prompt_text, prompt_language, text, text_language, how_to_cut, top_k, top_p, temperature, ref_text_free],
[output],
)
@@ -650,7 +611,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
button3.click(cut3, [text_inp], [text_opt])
button4.click(cut4, [text_inp], [text_opt])
button5.click(cut5, [text_inp], [text_opt])
- gr.Markdown(value=i18n("后续将支持混合语种编码文本输入。"))
+ gr.Markdown(value=i18n("后续将支持转音素、手工修改音素、语音合成分步执行。"))
app.queue(concurrency_count=511, max_size=1022).launch(
server_name="0.0.0.0",
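Among the WebUI changes, the `cut5` fix is easy to miss: the old comprehension paired each text chunk with the punctuation that follows it, so a sentence ending without punctuation silently lost its last chunk. A small standalone sketch of the corrected behaviour (`cut5_sketch` is a hypothetical name):

```python
import re

def cut5_sketch(inp: str) -> str:
    inp = inp.strip("\n")
    punds = r'[,.;?!、,。?!;:…]'
    items = re.split(f'({punds})', inp)
    merged = ["".join(pair) for pair in zip(items[::2], items[1::2])]
    if len(items) % 2 == 1:  # odd length: a trailing piece had no punctuation after it
        merged.append(items[-1])
    return "\n".join(merged)

# cut5_sketch("第一句。没有标点的结尾") now keeps "没有标点的结尾" as its own line
```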
diff --git a/GPT_SoVITS/module/models.py b/GPT_SoVITS/module/models.py
index c99485cf..a4d22352 100644
--- a/GPT_SoVITS/module/models.py
+++ b/GPT_SoVITS/module/models.py
@@ -228,6 +228,7 @@ class TextEncoder(nn.Module):
)
y = self.ssl_proj(y * y_mask) * y_mask
+
y = self.encoder_ssl(y * y_mask, y_mask)
text_mask = torch.unsqueeze(
@@ -958,11 +959,13 @@ class SynthesizerTrn(nn.Module):
@torch.no_grad()
def decode(self, codes, text, refer, noise_scale=0.5):
- refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
- refer_mask = torch.unsqueeze(
- commons.sequence_mask(refer_lengths, refer.size(2)), 1
- ).to(refer.dtype)
- ge = self.ref_enc(refer * refer_mask, refer_mask)
+ ge = None
+ if refer is not None:
+ refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
+ refer_mask = torch.unsqueeze(
+ commons.sequence_mask(refer_lengths, refer.size(2)), 1
+ ).to(refer.dtype)
+ ge = self.ref_enc(refer * refer_mask, refer_mask)
y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device)
text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)
diff --git a/GPT_SoVITS/prepare_datasets/1-get-text.py b/GPT_SoVITS/prepare_datasets/1-get-text.py
index 88c9d858..5873164a 100644
--- a/GPT_SoVITS/prepare_datasets/1-get-text.py
+++ b/GPT_SoVITS/prepare_datasets/1-get-text.py
@@ -33,13 +33,13 @@ from time import time as ttime
import shutil
-def my_save(fea, path): #####fix issue: torch.save doesn't support chinese path
- dir = os.path.dirname(path)
- name = os.path.basename(path)
- tmp_path = "%s/%s%s.pth" % (dir, ttime(), i_part)
- torch.save(fea, tmp_path)
- shutil.move(tmp_path, "%s/%s" % (dir, name))
-
+def my_save(fea,path):#####fix issue: torch.save doesn't support chinese path
+ dir=os.path.dirname(path)
+ name=os.path.basename(path)
+ # tmp_path="%s/%s%s.pth"%(dir,ttime(),i_part)
+ tmp_path="%s%s.pth"%(ttime(),i_part)
+ torch.save(fea,tmp_path)
+ shutil.move(tmp_path,"%s/%s"%(dir,name))
txt_path = "%s/2-name2text-%s.txt" % (opt_dir, i_part)
diff --git a/GPT_SoVITS/prepare_datasets/2-get-hubert-wav32k.py b/GPT_SoVITS/prepare_datasets/2-get-hubert-wav32k.py
index b8355dd4..7607259e 100644
--- a/GPT_SoVITS/prepare_datasets/2-get-hubert-wav32k.py
+++ b/GPT_SoVITS/prepare_datasets/2-get-hubert-wav32k.py
@@ -35,7 +35,8 @@ import shutil
def my_save(fea,path):#####fix issue: torch.save doesn't support chinese path
dir=os.path.dirname(path)
name=os.path.basename(path)
- tmp_path="%s/%s%s.pth"%(dir,ttime(),i_part)
+ # tmp_path="%s/%s%s.pth"%(dir,ttime(),i_part)
+ tmp_path="%s%s.pth"%(ttime(),i_part)
torch.save(fea,tmp_path)
shutil.move(tmp_path,"%s/%s"%(dir,name))
@@ -98,7 +99,7 @@ for line in lines[int(i_part)::int(all_parts)]:
try:
# wav_name,text=line.split("\t")
wav_name, spk_name, language, text = line.split("|")
- if (inp_wav_dir !=None):
+ if (inp_wav_dir != "" and inp_wav_dir != None):
wav_name = os.path.basename(wav_name)
wav_path = "%s/%s"%(inp_wav_dir, wav_name)
diff --git a/GPT_SoVITS/process_ckpt.py b/GPT_SoVITS/process_ckpt.py
index 74833379..3a436f10 100644
--- a/GPT_SoVITS/process_ckpt.py
+++ b/GPT_SoVITS/process_ckpt.py
@@ -1,11 +1,18 @@
import traceback
from collections import OrderedDict
-
+from time import time as ttime
+import shutil,os
import torch
from tools.i18n.i18n import I18nAuto
i18n = I18nAuto()
+def my_save(fea,path):#####fix issue: torch.save doesn't support chinese path
+ dir=os.path.dirname(path)
+ name=os.path.basename(path)
+ tmp_path="%s.pth"%(ttime())
+ torch.save(fea,tmp_path)
+ shutil.move(tmp_path,"%s/%s"%(dir,name))
def savee(ckpt, name, epoch, steps, hps):
try:
@@ -17,7 +24,8 @@ def savee(ckpt, name, epoch, steps, hps):
opt["weight"][key] = ckpt[key].half()
opt["config"] = hps
opt["info"] = "%sepoch_%siteration" % (epoch, steps)
- torch.save(opt, "%s/%s.pth" % (hps.save_weight_dir, name))
+ # torch.save(opt, "%s/%s.pth" % (hps.save_weight_dir, name))
+ my_save(opt, "%s/%s.pth" % (hps.save_weight_dir, name))
return "Success."
except:
return traceback.format_exc()
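The same `my_save` workaround now appears in several files: according to the comment carried with it, `torch.save` can fail when the destination path contains non-ASCII (e.g. Chinese) characters, so the checkpoint is first written under a temporary ASCII-only name and then moved into place. A self-contained sketch of the pattern:

```python
import os
import shutil
from time import time as ttime

import torch

def my_save_sketch(obj, path):
    # Write to a timestamp-named temp file in the working directory (ASCII-safe),
    # then move it to the real destination, which may contain non-ASCII characters.
    tmp_path = "%s.pth" % ttime()
    torch.save(obj, tmp_path)
    shutil.move(tmp_path, os.path.join(os.path.dirname(path), os.path.basename(path)))
```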
diff --git a/GPT_SoVITS/s1_train.py b/GPT_SoVITS/s1_train.py
index c26302a8..fb273542 100644
--- a/GPT_SoVITS/s1_train.py
+++ b/GPT_SoVITS/s1_train.py
@@ -24,6 +24,14 @@ torch.set_float32_matmul_precision("high")
from AR.utils import get_newest_ckpt
from collections import OrderedDict
+from time import time as ttime
+import shutil
+def my_save(fea,path):#####fix issue: torch.save doesn't support chinese path
+ dir=os.path.dirname(path)
+ name=os.path.basename(path)
+ tmp_path="%s.pth"%(ttime())
+ torch.save(fea,tmp_path)
+ shutil.move(tmp_path,"%s/%s"%(dir,name))
class my_model_ckpt(ModelCheckpoint):
@@ -70,7 +78,8 @@ class my_model_ckpt(ModelCheckpoint):
to_save_od["weight"][key] = dictt[key].half()
to_save_od["config"] = self.config
to_save_od["info"] = "GPT-e%s" % (trainer.current_epoch + 1)
- torch.save(
+ # torch.save(
+ my_save(
to_save_od,
"%s/%s-e%s.ckpt"
% (
diff --git a/GPT_SoVITS/text/chinese.py b/GPT_SoVITS/text/chinese.py
index eb8a45b6..f259988d 100644
--- a/GPT_SoVITS/text/chinese.py
+++ b/GPT_SoVITS/text/chinese.py
@@ -44,6 +44,8 @@ rep_map = {
"$": ".",
"/": ",",
"—": "-",
+ "~": "…",
+ "~":"…",
}
tone_modifier = ToneSandhi()
diff --git a/GPT_SoVITS/text/english.py b/GPT_SoVITS/text/english.py
index 0a5d6c21..90f48a55 100644
--- a/GPT_SoVITS/text/english.py
+++ b/GPT_SoVITS/text/english.py
@@ -169,9 +169,9 @@ def read_dict_new():
line = line.strip()
word_split = line.split(" ")
word = word_split[0]
- if word not in g2p_dict:
- g2p_dict[word] = []
- g2p_dict[word].append(word_split[1:])
+ #if word not in g2p_dict:
+ g2p_dict[word] = []
+ g2p_dict[word].append(word_split[1:])
line_index = line_index + 1
line = f.readline()
diff --git a/GPT_SoVITS/text/tone_sandhi.py b/GPT_SoVITS/text/tone_sandhi.py
index 9f62abe9..c557dad1 100644
--- a/GPT_SoVITS/text/tone_sandhi.py
+++ b/GPT_SoVITS/text/tone_sandhi.py
@@ -672,6 +672,7 @@ class ToneSandhi:
and i + 1 < len(seg)
and seg[i - 1][0] == seg[i + 1][0]
and seg[i - 1][1] == "v"
+ and seg[i + 1][1] == "v"
):
new_seg[i - 1][0] = new_seg[i - 1][0] + "一" + new_seg[i - 1][0]
else:
diff --git a/GPT_SoVITS/text/zh_normalization/num.py b/GPT_SoVITS/text/zh_normalization/num.py
index 8a54d3e6..8ef7f48f 100644
--- a/GPT_SoVITS/text/zh_normalization/num.py
+++ b/GPT_SoVITS/text/zh_normalization/num.py
@@ -172,6 +172,21 @@ def replace_range(match) -> str:
return result
+# "~" range expressions (tilde read as 至, "to")
+RE_TO_RANGE = re.compile(
+ r'((-?)((\d+)(\.\d+)?)|(\.(\d+)))(%|°C|℃|度|摄氏度|cm2|cm²|cm3|cm³|cm|db|ds|kg|km|m2|m²|m³|m3|ml|m|mm|s)[~]((-?)((\d+)(\.\d+)?)|(\.(\d+)))(%|°C|℃|度|摄氏度|cm2|cm²|cm3|cm³|cm|db|ds|kg|km|m2|m²|m³|m3|ml|m|mm|s)')
+
+def replace_to_range(match) -> str:
+ """
+ Args:
+ match (re.Match)
+ Returns:
+ str
+ """
+ result = match.group(0).replace('~', '至')
+ return result
+
+
def _get_value(value_string: str, use_zero: bool=True) -> List[str]:
stripped = value_string.lstrip('0')
if len(stripped) == 0:
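What the new `RE_TO_RANGE` rule accomplishes, shown with a deliberately simplified stand-in pattern (the real one above covers many more units): unit-suffixed ranges joined by a tilde are rewritten with 至 before the generic normalizers run, which is why the blanket tilde replacements in `_post_replace` can be commented out further below.

```python
import re

# Simplified stand-in for RE_TO_RANGE: number + unit, a tilde, number + unit.
SIMPLE_TO_RANGE = re.compile(r'(\d+(?:\.\d+)?(?:℃|°C|%|度|km|kg|mm|m))[~~](\d+(?:\.\d+)?(?:℃|°C|%|度|km|kg|mm|m))')

def replace_to_range_demo(text: str) -> str:
    # Replace either tilde variant inside the matched range with 至 ("to").
    return SIMPLE_TO_RANGE.sub(lambda m: m.group(0).replace('~', '至').replace('~', '至'), text)

print(replace_to_range_demo("气温5℃~10℃"))  # -> 气温5℃至10℃
```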
diff --git a/GPT_SoVITS/text/zh_normalization/text_normlization.py b/GPT_SoVITS/text/zh_normalization/text_normlization.py
index 1250e96c..712537d5 100644
--- a/GPT_SoVITS/text/zh_normalization/text_normlization.py
+++ b/GPT_SoVITS/text/zh_normalization/text_normlization.py
@@ -33,6 +33,7 @@ from .num import RE_NUMBER
from .num import RE_PERCENTAGE
from .num import RE_POSITIVE_QUANTIFIERS
from .num import RE_RANGE
+from .num import RE_TO_RANGE
from .num import replace_default_num
from .num import replace_frac
from .num import replace_negative_num
@@ -40,6 +41,7 @@ from .num import replace_number
from .num import replace_percentage
from .num import replace_positive_quantifier
from .num import replace_range
+from .num import replace_to_range
from .phonecode import RE_MOBILE_PHONE
from .phonecode import RE_NATIONAL_UNIFORM_NUMBER
from .phonecode import RE_TELEPHONE
@@ -65,7 +67,7 @@ class TextNormalizer():
if lang == "zh":
text = text.replace(" ", "")
# 过滤掉特殊字符
- text = re.sub(r'[——《》【】<=>{}()()#&@“”^_|…\\]', '', text)
+ text = re.sub(r'[——《》【】<=>{}()()#&@“”^_|\\]', '', text)
text = self.SENTENCE_SPLITOR.sub(r'\1\n', text)
text = text.strip()
sentences = [sentence.strip() for sentence in re.split(r'\n+', text)]
@@ -73,8 +75,8 @@ class TextNormalizer():
def _post_replace(self, sentence: str) -> str:
sentence = sentence.replace('/', '每')
- sentence = sentence.replace('~', '至')
- sentence = sentence.replace('~', '至')
+ # sentence = sentence.replace('~', '至')
+ # sentence = sentence.replace('~', '至')
sentence = sentence.replace('①', '一')
sentence = sentence.replace('②', '二')
sentence = sentence.replace('③', '三')
@@ -128,6 +130,8 @@ class TextNormalizer():
sentence = RE_TIME_RANGE.sub(replace_time, sentence)
sentence = RE_TIME.sub(replace_time, sentence)
+ # handle the tilde "~" used as 至 ("to") in ranges
+ sentence = RE_TO_RANGE.sub(replace_to_range, sentence)
sentence = RE_TEMPERATURE.sub(replace_temperature, sentence)
sentence = replace_measure(sentence)
sentence = RE_FRAC.sub(replace_frac, sentence)
diff --git a/GPT_SoVITS/utils.py b/GPT_SoVITS/utils.py
index 0ce03b33..7984b5a8 100644
--- a/GPT_SoVITS/utils.py
+++ b/GPT_SoVITS/utils.py
@@ -64,6 +64,14 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False
)
return model, optimizer, learning_rate, iteration
+from time import time as ttime
+import shutil
+def my_save(fea,path):#####fix issue: torch.save doesn't support chinese path
+ dir=os.path.dirname(path)
+ name=os.path.basename(path)
+ tmp_path="%s.pth"%(ttime())
+ torch.save(fea,tmp_path)
+ shutil.move(tmp_path,"%s/%s"%(dir,name))
def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
logger.info(
@@ -75,7 +83,8 @@ def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path)
state_dict = model.module.state_dict()
else:
state_dict = model.state_dict()
- torch.save(
+ # torch.save(
+ my_save(
{
"model": state_dict,
"iteration": iteration,
diff --git a/GPT_SoVITS_Inference.ipynb b/GPT_SoVITS_Inference.ipynb
new file mode 100644
index 00000000..a5b55325
--- /dev/null
+++ b/GPT_SoVITS_Inference.ipynb
@@ -0,0 +1,152 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "provenance": []
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "accelerator": "GPU"
+ },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# Credits for bubarino giving me the huggingface import code (感谢 bubarino 给了我 huggingface 导入代码)"
+ ],
+ "metadata": {
+ "id": "himHYZmra7ix"
+ }
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "e9b7iFV3dm1f"
+ },
+ "source": [
+ "!git clone https://github.com/RVC-Boss/GPT-SoVITS.git\n",
+ "%cd GPT-SoVITS\n",
+ "!apt-get update && apt-get install -y --no-install-recommends tzdata ffmpeg libsox-dev parallel aria2 git git-lfs && git lfs install\n",
+ "!pip install -r requirements.txt"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# @title Download pretrained models 下载预训练模型\n",
+ "!mkdir -p /content/GPT-SoVITS/GPT_SoVITS/pretrained_models\n",
+ "!mkdir -p /content/GPT-SoVITS/tools/damo_asr/models\n",
+ "!mkdir -p /content/GPT-SoVITS/tools/uvr5\n",
+ "%cd /content/GPT-SoVITS/GPT_SoVITS/pretrained_models\n",
+ "!git clone https://huggingface.co/lj1995/GPT-SoVITS\n",
+ "%cd /content/GPT-SoVITS/tools/damo_asr/models\n",
+ "!git clone https://www.modelscope.cn/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch.git\n",
+ "!git clone https://www.modelscope.cn/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch.git\n",
+ "!git clone https://www.modelscope.cn/damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch.git\n",
+ "# @title UVR5 pretrains 安装uvr5模型\n",
+ "%cd /content/GPT-SoVITS/tools/uvr5\n",
+ "!git clone https://huggingface.co/Delik/uvr5_weights\n",
+ "!git config core.sparseCheckout true\n",
+ "!mv /content/GPT-SoVITS/GPT_SoVITS/pretrained_models/GPT-SoVITS/* /content/GPT-SoVITS/GPT_SoVITS/pretrained_models/"
+ ],
+ "metadata": {
+ "id": "0NgxXg5sjv7z",
+ "cellView": "form"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "#@title Create folder models 创建文件夹模型\n",
+ "import os\n",
+ "base_directory = \"/content/GPT-SoVITS\"\n",
+ "folder_names = [\"SoVITS_weights\", \"GPT_weights\"]\n",
+ "\n",
+ "for folder_name in folder_names:\n",
+ " if os.path.exists(os.path.join(base_directory, folder_name)):\n",
+ " print(f\"The folder '{folder_name}' already exists. (文件夹'{folder_name}'已经存在。)\")\n",
+ " else:\n",
+ " os.makedirs(os.path.join(base_directory, folder_name))\n",
+ " print(f\"The folder '{folder_name}' was created successfully! (文件夹'{folder_name}'已成功创建!)\")\n",
+ "\n",
+ "print(\"All folders have been created. (所有文件夹均已创建。)\")"
+ ],
+ "metadata": {
+ "cellView": "form",
+ "id": "cPDEH-9czOJF"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "import requests\n",
+ "import zipfile\n",
+ "import shutil\n",
+ "import os\n",
+ "\n",
+ "#@title Import model 导入模型 (HuggingFace)\n",
+ "hf_link = 'https://huggingface.co/modelloosrvcc/Nagisa_Shingetsu_GPT-SoVITS/resolve/main/Nagisa.zip' #@param {type: \"string\"}\n",
+ "\n",
+ "output_path = '/content/'\n",
+ "\n",
+ "response = requests.get(hf_link)\n",
+ "with open(output_path + 'file.zip', 'wb') as file:\n",
+ " file.write(response.content)\n",
+ "\n",
+ "with zipfile.ZipFile(output_path + 'file.zip', 'r') as zip_ref:\n",
+ " zip_ref.extractall(output_path)\n",
+ "\n",
+ "os.remove(output_path + \"file.zip\")\n",
+ "\n",
+ "source_directory = output_path\n",
+ "SoVITS_destination_directory = '/content/GPT-SoVITS/SoVITS_weights'\n",
+ "GPT_destination_directory = '/content/GPT-SoVITS/GPT_weights'\n",
+ "\n",
+ "for filename in os.listdir(source_directory):\n",
+ " if filename.endswith(\".pth\"):\n",
+ " source_path = os.path.join(source_directory, filename)\n",
+ " destination_path = os.path.join(SoVITS_destination_directory, filename)\n",
+ " shutil.move(source_path, destination_path)\n",
+ "\n",
+ "for filename in os.listdir(source_directory):\n",
+ " if filename.endswith(\".ckpt\"):\n",
+ " source_path = os.path.join(source_directory, filename)\n",
+ " destination_path = os.path.join(GPT_destination_directory, filename)\n",
+ " shutil.move(source_path, destination_path)\n",
+ "\n",
+ "print(f'Model downloaded. (模型已下载。)')"
+ ],
+ "metadata": {
+ "cellView": "form",
+ "id": "vbZY-LnM0tzq"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# @title launch WebUI 启动WebUI\n",
+ "!/usr/local/bin/pip install ipykernel\n",
+ "!sed -i '10s/False/True/' /content/GPT-SoVITS/config.py\n",
+ "%cd /content/GPT-SoVITS/\n",
+ "!/usr/local/bin/python webui.py"
+ ],
+ "metadata": {
+ "id": "4oRGUzkrk8C7",
+ "cellView": "form"
+ },
+ "execution_count": null,
+ "outputs": []
+ }
+ ]
+}
\ No newline at end of file
diff --git a/README.md b/README.md
index 72f3694f..0b0e2d44 100644
--- a/README.md
+++ b/README.md
@@ -17,14 +17,6 @@ A Powerful Few-shot Voice Conversion and Text-to-Speech WebUI.
---
-> Check out our [demo video](https://www.bilibili.com/video/BV12g4y1m7Uw) here!
-
-Unseen speakers few-shot fine-tuning demo:
-
-https://github.com/RVC-Boss/GPT-SoVITS/assets/129054828/05bee1fa-bdd8-4d85-9350-80c060ab47fb
-
-For users in China region, you can use AutoDL Cloud Docker to experience the full functionality online: https://www.codewithgpu.com/i/RVC-Boss/GPT-SoVITS/GPT-SoVITS-Official
-
## Features:
1. **Zero-shot TTS:** Input a 5-second vocal sample and experience instant text-to-speech conversion.
@@ -35,19 +27,31 @@ For users in China region, you can use AutoDL Cloud Docker to experience the ful
4. **WebUI Tools:** Integrated tools include voice accompaniment separation, automatic training set segmentation, Chinese ASR, and text labeling, assisting beginners in creating training datasets and GPT/SoVITS models.
-## Environment Preparation
+**Check out our [demo video](https://www.bilibili.com/video/BV12g4y1m7Uw) here!**
-If you are a Windows user (tested with win>=10) you can install directly via the prezip. Just download the [prezip](https://huggingface.co/lj1995/GPT-SoVITS-windows-package/resolve/main/GPT-SoVITS-beta.7z?download=true), unzip it and double-click go-webui.bat to start GPT-SoVITS-WebUI.
+Unseen speakers few-shot fine-tuning demo:
+
+https://github.com/RVC-Boss/GPT-SoVITS/assets/129054828/05bee1fa-bdd8-4d85-9350-80c060ab47fb
+
+[教程中文版](https://www.yuque.com/baicaigongchang1145haoyuangong/ib3g1e) [User guide (EN)](https://rentry.co/GPT-SoVITS-guide#/)
+
+## Installation
+
+For users in the China region, you can [click here](https://www.codewithgpu.com/i/RVC-Boss/GPT-SoVITS/GPT-SoVITS-Official) to use AutoDL Cloud Docker to experience the full functionality online.
### Tested Environments
- Python 3.9, PyTorch 2.0.1, CUDA 11
- Python 3.10.13, PyTorch 2.1.2, CUDA 12.3
-- Python 3.9, PyTorch 2.3.0.dev20240122, macOS 14.3 (Apple silicon, GPU)
+- Python 3.9, PyTorch 2.3.0.dev20240122, macOS 14.3 (Apple silicon)
-_Note: numba==0.56.4 require py<3.11_
+_Note: numba==0.56.4 requires py<3.11_
-### Quick Install with Conda
+### Windows
+
+If you are a Windows user (tested with win>=10), you can directly download the [pre-packaged distribution](https://huggingface.co/lj1995/GPT-SoVITS-windows-package/resolve/main/GPT-SoVITS-beta.7z?download=true) and double-click on _go-webui.bat_ to start GPT-SoVITS-WebUI.
+
+### Linux
```bash
conda create -n GPTSoVits python=3.9
@@ -55,15 +59,37 @@ conda activate GPTSoVits
bash install.sh
```
+### macOS
+
+Only Macs that meet the following conditions can train models:
+
+- Mac computers with Apple silicon
+- macOS 12.3 or later
+- Xcode command-line tools installed by running `xcode-select --install`
+
+**All Macs can do inference with CPU, which has been demonstrated to outperform GPU inference.**
+
+First make sure you have installed FFmpeg by running `brew install ffmpeg` or `conda install ffmpeg`, then install by using the following commands:
+
+```bash
+conda create -n GPTSoVits python=3.9
+conda activate GPTSoVits
+
+pip3 install --pre torch torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu
+pip install -r requirements.txt
+```
+
+_Note: Training models will only work if you've installed PyTorch Nightly._
+
### Install Manually
-#### Pip Packages
+#### Install Dependencies
```bash
pip install -r requirements.txt
```
-#### FFmpeg
+#### Install FFmpeg
##### Conda Users
@@ -79,57 +105,10 @@ sudo apt install libsox-dev
conda install -c conda-forge 'ffmpeg<7'
```
-##### MacOS Users
-
-```bash
-brew install ffmpeg
-```
-
##### Windows Users
Download and place [ffmpeg.exe](https://huggingface.co/lj1995/VoiceConversionWebUI/blob/main/ffmpeg.exe) and [ffprobe.exe](https://huggingface.co/lj1995/VoiceConversionWebUI/blob/main/ffprobe.exe) in the GPT-SoVITS root.
-### Pretrained Models
-
-Download pretrained models from [GPT-SoVITS Models](https://huggingface.co/lj1995/GPT-SoVITS) and place them in `GPT_SoVITS/pretrained_models`.
-
-For UVR5 (Vocals/Accompaniment Separation & Reverberation Removal, additionally), download models from [UVR5 Weights](https://huggingface.co/lj1995/VoiceConversionWebUI/tree/main/uvr5_weights) and place them in `tools/uvr5/uvr5_weights`.
-
-Users in China region can download these two models by entering the links below and clicking "Download a copy"
-
-- [GPT-SoVITS Models](https://www.icloud.com.cn/iclouddrive/056y_Xog_HXpALuVUjscIwTtg#GPT-SoVITS_Models)
-
-- [UVR5 Weights](https://www.icloud.com.cn/iclouddrive/0bekRKDiJXboFhbfm3lM2fVbA#UVR5_Weights)
-
-For Chinese ASR (additionally), download models from [Damo ASR Model](https://modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/files), [Damo VAD Model](https://modelscope.cn/models/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/files), and [Damo Punc Model](https://modelscope.cn/models/damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/files) and place them in `tools/damo_asr/models`.
-
-### For Mac Users
-
-If you are a Mac user, make sure you meet the following conditions for training and inferencing with GPU:
-
-- Mac computers with Apple silicon or AMD GPUs
-- macOS 12.3 or later
-- Xcode command-line tools installed by running `xcode-select --install`
-
-_Other Macs can do inference with CPU only._
-
-Then install by using the following commands:
-
-#### Create Environment
-
-```bash
-conda create -n GPTSoVits python=3.9
-conda activate GPTSoVits
-```
-
-#### Install Requirements
-
-```bash
-pip install -r requirements.txt
-pip uninstall torch torchaudio
-pip3 install --pre torch torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu
-```
-
### Using Docker
#### docker-compose.yaml configuration
@@ -157,6 +136,20 @@ As above, modify the corresponding parameters based on your actual situation, th
docker run --rm -it --gpus=all --env=is_half=False --volume=G:\GPT-SoVITS-DockerTest\output:/workspace/output --volume=G:\GPT-SoVITS-DockerTest\logs:/workspace/logs --volume=G:\GPT-SoVITS-DockerTest\SoVITS_weights:/workspace/SoVITS_weights --workdir=/workspace -p 9880:9880 -p 9871:9871 -p 9872:9872 -p 9873:9873 -p 9874:9874 --shm-size="16G" -d breakstring/gpt-sovits:xxxxx
```
+## Pretrained Models
+
+Download pretrained models from [GPT-SoVITS Models](https://huggingface.co/lj1995/GPT-SoVITS) and place them in `GPT_SoVITS/pretrained_models`.
+
+For UVR5 (Vocals/Accompaniment Separation & Reverberation Removal, additionally), download models from [UVR5 Weights](https://huggingface.co/lj1995/VoiceConversionWebUI/tree/main/uvr5_weights) and place them in `tools/uvr5/uvr5_weights`.
+
+Users in the China region can download these two models by entering the links below and clicking "Download a copy".
+
+- [GPT-SoVITS Models](https://www.icloud.com.cn/iclouddrive/056y_Xog_HXpALuVUjscIwTtg#GPT-SoVITS_Models)
+
+- [UVR5 Weights](https://www.icloud.com.cn/iclouddrive/0bekRKDiJXboFhbfm3lM2fVbA#UVR5_Weights)
+
+For Chinese ASR (additionally), download models from [Damo ASR Model](https://modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/files), [Damo VAD Model](https://modelscope.cn/models/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/files), and [Damo Punc Model](https://modelscope.cn/models/damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/files) and place them in `tools/damo_asr/models`.
+
## Dataset Format
The TTS annotation .list file format:
@@ -182,7 +175,7 @@ D:\GPT-SoVITS\xxx/xxx.wav|xxx|en|I like playing Genshin.
- [ ] **High Priority:**
- [x] Localization in Japanese and English.
- - [ ] User guide.
+ - [x] User guide.
- [x] Japanese and English dataset fine tune training.
- [ ] **Features:**
@@ -227,27 +220,34 @@ ASR processing is performed through Faster_Whisper(ASR marking except Chinese)
python ./tools/damo_asr/WhisperASR.py -i -o