# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Credits
    This code is modified from https://github.com/GitYCC/g2pW
"""

from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple

import numpy as np

from .utils import tokenize_and_map

ANCHOR_CHAR = "▁"


def prepare_onnx_input(
    tokenizer,
    labels: List[str],
    char2phonemes: Dict[str, List[int]],
    chars: List[str],
    texts: List[str],
    query_ids: List[int],
    use_mask: bool = False,
    window_size: Optional[int] = None,
    max_len: int = 512,
    char2id: Optional[Dict[str, int]] = None,
    char_phoneme_masks: Optional[Dict[str, List[int]]] = None,
) -> Dict[str, np.ndarray]:
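    """Build the feature dictionary fed to the G2PW ONNX session.

    For each (text, query_id) pair the sentence is tokenized (with per-text
    caching, so a sentence holding several polyphonic characters is tokenized
    once), optionally truncated around the query character, and encoded into
    padded input_ids, token_type_ids and attention_masks, together with the
    phoneme mask, char id and position id of the queried character.
    """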
    if window_size is not None:
        truncated_texts, truncated_query_ids = _truncate_texts(
            window_size=window_size, texts=texts, query_ids=query_ids
        )
    input_ids = []
    token_type_ids = []
    attention_masks = []
    phoneme_masks = []
    char_ids = []
    position_ids = []
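    # Tokenization results are cached per (lower-cased) text so repeated
    # sentences are only tokenized once.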
    tokenized_cache = {}

    if char2id is None:
        char2id = {char: idx for idx, char in enumerate(chars)}
    if use_mask:
        if char_phoneme_masks is None:
            char_phoneme_masks = {
                char: [1 if i in char2phonemes[char] else 0 for i in range(len(labels))]
                for char in char2phonemes
            }
    else:
        full_phoneme_mask = [1] * len(labels)

    for idx in range(len(texts)):
        text = (truncated_texts if window_size else texts)[idx].lower()
        query_id = (truncated_query_ids if window_size else query_ids)[idx]

        cached = tokenized_cache.get(text)
        if cached is None:
            try:
                tokens, text2token, token2text = tokenize_and_map(tokenizer=tokenizer, text=text)
            except Exception:
                print(f'warning: text "{text}" is invalid')
                return {}
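
            # Texts that fit within max_len share one fully cached encoding;
            # longer texts cache only the tokenization and are re-truncated
            # around each query character below.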
            if len(tokens) <= max_len - 2:
                processed_tokens = ["[CLS]"] + tokens + ["[SEP]"]
                shared_input_id = list(np.array(tokenizer.convert_tokens_to_ids(processed_tokens)))
                shared_token_type_id = list(np.zeros((len(processed_tokens),), dtype=int))
                shared_attention_mask = list(np.ones((len(processed_tokens),), dtype=int))
                cached = {
                    "is_short": True,
                    "tokens": tokens,
                    "text2token": text2token,
                    "token2text": token2text,
                    "input_id": shared_input_id,
                    "token_type_id": shared_token_type_id,
                    "attention_mask": shared_attention_mask,
                }
            else:
                cached = {
                    "is_short": False,
                    "tokens": tokens,
                    "text2token": text2token,
                    "token2text": token2text,
                }
            tokenized_cache[text] = cached

        if cached["is_short"]:
            text_for_query = text
            query_id_for_query = query_id
            text2token_for_query = cached["text2token"]
            input_id = cached["input_id"]
            token_type_id = cached["token_type_id"]
            attention_mask = cached["attention_mask"]
        else:
            (
                text_for_query,
                query_id_for_query,
                tokens_for_query,
                text2token_for_query,
                _token2text_for_query,
            ) = _truncate(
                max_len=max_len,
                text=text,
                query_id=query_id,
                tokens=cached["tokens"],
                text2token=cached["text2token"],
                token2text=cached["token2text"],
            )
            processed_tokens = ["[CLS]"] + tokens_for_query + ["[SEP]"]
            input_id = list(np.array(tokenizer.convert_tokens_to_ids(processed_tokens)))
            token_type_id = list(np.zeros((len(processed_tokens),), dtype=int))
            attention_mask = list(np.ones((len(processed_tokens),), dtype=int))

        query_char = text_for_query[query_id_for_query]
        if use_mask:
            phoneme_mask = char_phoneme_masks[query_char]
        else:
            phoneme_mask = full_phoneme_mask
        char_id = char2id[query_char]
        position_id = text2token_for_query[query_id_for_query] + 1  # the [CLS] token occupies position 0

        input_ids.append(input_id)
        token_type_ids.append(token_type_id)
        attention_masks.append(attention_mask)
        phoneme_masks.append(phoneme_mask)
        char_ids.append(char_id)
        position_ids.append(position_id)
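
    # Pad every sequence to the longest length in the batch so the ONNX
    # session receives rectangular tensors.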
    max_token_length = max(len(seq) for seq in input_ids)

    def _pad_sequences(sequences, pad_value=0):
        return [seq + [pad_value] * (max_token_length - len(seq)) for seq in sequences]

    outputs = {
        "input_ids": np.array(_pad_sequences(input_ids, pad_value=0)).astype(np.int64),
        "token_type_ids": np.array(_pad_sequences(token_type_ids, pad_value=0)).astype(np.int64),
        "attention_masks": np.array(_pad_sequences(attention_masks, pad_value=0)).astype(np.int64),
        "phoneme_masks": np.array(phoneme_masks).astype(np.float32),
        "char_ids": np.array(char_ids).astype(np.int64),
        "position_ids": np.array(position_ids).astype(np.int64),
    }
    return outputs


def _truncate_texts(window_size: int, texts: List[str], query_ids: List[int]) -> Tuple[List[str], List[int]]:
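    """Crop each text to a window of roughly window_size characters centred on
    its query character, remapping each query index into the cropped text.
    """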
    truncated_texts = []
    truncated_query_ids = []
    for text, query_id in zip(texts, query_ids):
        start = max(0, query_id - window_size // 2)
        end = min(len(text), query_id + window_size // 2)
        truncated_text = text[start:end]
        truncated_texts.append(truncated_text)

        truncated_query_id = query_id - start
        truncated_query_ids.append(truncated_query_id)
    return truncated_texts, truncated_query_ids


def _truncate(
    max_len: int, text: str, query_id: int, tokens: List[str], text2token: List[Optional[int]], token2text: List[Tuple[int, int]]
):
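    """Truncate a tokenized text to max_len - 2 tokens (leaving room for
    [CLS] and [SEP]) around the query character, shifting the query index and
    the text/token alignment maps into the truncated coordinates.
    """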
    truncate_len = max_len - 2
    if len(tokens) <= truncate_len:
        return (text, query_id, tokens, text2token, token2text)

    token_position = text2token[query_id]

    token_start = token_position - truncate_len // 2
    token_end = token_start + truncate_len
    front_exceed_dist = -token_start
    back_exceed_dist = token_end - len(tokens)
    if front_exceed_dist > 0:
        token_start += front_exceed_dist
        token_end += front_exceed_dist
    elif back_exceed_dist > 0:
        token_start -= back_exceed_dist
        token_end -= back_exceed_dist

    start = token2text[token_start][0]
    end = token2text[token_end - 1][1]

    return (
        text[start:end],
        query_id - start,
        tokens[token_start:token_end],
        [i - token_start if i is not None else None for i in text2token[start:end]],
        [(s - start, e - start) for s, e in token2text[token_start:token_end]],
    )


def get_phoneme_labels(polyphonic_chars: List[List[str]]) -> Tuple[List[str], Dict[str, List[int]]]:
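    """Build the sorted list of phoneme labels and, per polyphonic character,
    the indices of its candidate phonemes within that list.
    """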
    labels = sorted(list(set([phoneme for char, phoneme in polyphonic_chars])))
    char2phonemes = {}
    for char, phoneme in polyphonic_chars:
        if char not in char2phonemes:
            char2phonemes[char] = []
        char2phonemes[char].append(labels.index(phoneme))
    return labels, char2phonemes


def get_char_phoneme_labels(polyphonic_chars: List[List[str]]) -> Tuple[List[str], Dict[str, List[int]]]:
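    """Same as get_phoneme_labels, but each label is a "char phoneme" pair, so
    the same phoneme under different characters gets a distinct label id.
    """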
    labels = sorted(list(set([f"{char} {phoneme}" for char, phoneme in polyphonic_chars])))
    char2phonemes = {}
    for char, phoneme in polyphonic_chars:
        if char not in char2phonemes:
            char2phonemes[char] = []
        char2phonemes[char].append(labels.index(f"{char} {phoneme}"))
    return labels, char2phonemes
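

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module). It assumes a
# HuggingFace BERT tokenizer and a `polyphonic_chars` list of (char, phoneme)
# pairs loaded elsewhere from the G2PW model's polyphonic-character table;
# both names are placeholders for whatever the caller actually provides.
#
#   from transformers import BertTokenizer
#
#   tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")
#   labels, char2phonemes = get_phoneme_labels(polyphonic_chars)
#   chars = sorted(char2phonemes.keys())
#   onnx_input = prepare_onnx_input(
#       tokenizer=tokenizer,
#       labels=labels,
#       char2phonemes=char2phonemes,
#       chars=chars,
#       texts=["他长大了"],  # sentence containing the polyphonic char 长
#       query_ids=[1],       # index of 长 within the sentence
#       use_mask=True,
#   )
#   # onnx_input["input_ids"], onnx_input["attention_masks"], etc. are the
#   # arrays fed to the ONNX inference session.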