白菜工厂1145号员工 ba8de9b760
优化 G2PW 的推理输入构造与多音字处理流程,减少重复计算,降低长句场景下的推理开销 (#2763)
* Enhance G2P processing by implementing batch input handling in _g2p function, improving efficiency. Update prepare_onnx_input to utilize caching for tokenization and add optional parameters for character ID mapping and phoneme masks. Refactor G2PWOnnxConverter to streamline model loading and configuration management.

* Enhance G2PW model input handling by introducing polyphonic context character support and updating the data preparation method to return additional query IDs. This improves the processing of polyphonic characters in sentences.
2026-04-18 16:52:32 +08:00

224 lines
8.2 KiB
Python

# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Credits
This code is modified from https://github.com/GitYCC/g2pW
"""
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
import numpy as np
from .utils import tokenize_and_map
# Module-level string constant; no function visible in this file references it.
# NOTE(review): the value here is the empty string -- confirm this is intended and
# not a character that was lost when the file was copied.
ANCHOR_CHAR = ""
def _encode_with_special_tokens(tokenizer, tokens: List[str]) -> Tuple[List[int], List[int], List[int]]:
    """Wrap ``tokens`` in [CLS]/[SEP] and return (input_id, token_type_id, attention_mask)."""
    wrapped = ["[CLS]"] + tokens + ["[SEP]"]
    input_id = list(tokenizer.convert_tokens_to_ids(wrapped))
    # Single-segment input: every token has type 0 and is attended to.
    token_type_id = [0] * len(wrapped)
    attention_mask = [1] * len(wrapped)
    return input_id, token_type_id, attention_mask


def prepare_onnx_input(
    tokenizer,
    labels: List[str],
    char2phonemes: Dict[str, List[int]],
    chars: List[str],
    texts: List[str],
    query_ids: List[int],
    use_mask: bool = False,
    window_size: int = None,
    max_len: int = 512,
    char2id: Optional[Dict[str, int]] = None,
    char_phoneme_masks: Optional[Dict[str, List[int]]] = None,
) -> Dict[str, np.array]:
    """Build the batched model inputs for a list of (text, query position) pairs.

    Each ``texts[i]`` is lower-cased, tokenized with ``tokenize_and_map`` and, if it
    has more than ``max_len - 2`` tokens, cut down around the query character with
    ``_truncate``. Tokenization results are cached per distinct text for the
    duration of the call, so repeated texts are tokenized only once.

    Args:
        tokenizer: Object passed to ``tokenize_and_map`` and used through its
            ``convert_tokens_to_ids`` method.
        labels: All phoneme labels; its length is the width of a full mask.
        char2phonemes: Maps a character to the label indices allowed for it. Only
            read when ``use_mask`` is true and ``char_phoneme_masks`` is ``None``.
        chars: Characters used to derive ``char2id`` when it is not supplied.
        texts: One text per query.
        query_ids: ``query_ids[i]`` is the character index of the query inside
            ``texts[i]``.
        use_mask: If true, use a per-character phoneme mask; otherwise every label
            is allowed (mask of all ones).
        window_size: If truthy, each text is first cut to a character window around
            the query with ``_truncate_texts``. ``None`` and ``0`` disable this.
        max_len: Maximum number of tokens per sequence including [CLS] and [SEP].
        char2id: Optional precomputed character-to-id mapping.
        char_phoneme_masks: Optional precomputed character-to-mask mapping.

    Returns:
        A dict with ``input_ids``, ``token_type_ids`` and ``attention_masks``
        (int64, zero-padded to the longest sequence in the batch),
        ``phoneme_masks`` (float32), ``char_ids`` and ``position_ids`` (int64).
        An empty dict is returned when ``texts`` is empty or when tokenizing any
        text raises.
    """
    # Nothing to encode: return the same "no input" value as the failure path
    # instead of failing later on max() of an empty sequence.
    if not texts:
        return {}

    # A falsy window (None or 0) means "use the texts unchanged".
    if window_size:
        active_texts, active_query_ids = _truncate_texts(
            window_size=window_size, texts=texts, query_ids=query_ids
        )
    else:
        active_texts, active_query_ids = texts, query_ids

    if char2id is None:
        char2id = {char: idx for idx, char in enumerate(chars)}

    full_phoneme_mask = None
    if use_mask:
        if char_phoneme_masks is None:
            char_phoneme_masks = {}
            for char, allowed_indices in char2phonemes.items():
                allowed = set(allowed_indices)
                char_phoneme_masks[char] = [1 if i in allowed else 0 for i in range(len(labels))]
    else:
        full_phoneme_mask = [1] * len(labels)

    input_ids = []
    token_type_ids = []
    attention_masks = []
    phoneme_masks = []
    char_ids = []
    position_ids = []
    # Per-call cache keyed by the lower-cased text.
    tokenized_cache = {}

    # Indexed by position in ``texts`` so the two input lists are read in lockstep.
    for idx in range(len(texts)):
        text = active_texts[idx].lower()
        query_id = active_query_ids[idx]

        cached = tokenized_cache.get(text)
        if cached is None:
            try:
                tokens, text2token, token2text = tokenize_and_map(tokenizer=tokenizer, text=text)
            except Exception:
                print(f'warning: text "{text}" is invalid')
                return {}
            cached = {
                "tokens": tokens,
                "text2token": text2token,
                "token2text": token2text,
                # Texts that fit are encoded once and shared by all their queries;
                # longer ones are truncated per query below.
                "encoded": None,
            }
            if len(tokens) <= max_len - 2:
                cached["encoded"] = _encode_with_special_tokens(tokenizer, tokens)
            tokenized_cache[text] = cached

        if cached["encoded"] is not None:
            text_for_query = text
            query_id_for_query = query_id
            text2token_for_query = cached["text2token"]
            input_id, token_type_id, attention_mask = cached["encoded"]
        else:
            (
                text_for_query,
                query_id_for_query,
                tokens_for_query,
                text2token_for_query,
                _token2text_for_query,
            ) = _truncate(
                max_len=max_len,
                text=text,
                query_id=query_id,
                tokens=cached["tokens"],
                text2token=cached["text2token"],
                token2text=cached["token2text"],
            )
            input_id, token_type_id, attention_mask = _encode_with_special_tokens(
                tokenizer, tokens_for_query
            )

        query_char = text_for_query[query_id_for_query]
        phoneme_mask = char_phoneme_masks[query_char] if use_mask else full_phoneme_mask

        input_ids.append(input_id)
        token_type_ids.append(token_type_id)
        attention_masks.append(attention_mask)
        phoneme_masks.append(phoneme_mask)
        char_ids.append(char2id[query_char])
        # +1 because [CLS] occupies the first token position.
        position_ids.append(text2token_for_query[query_id_for_query] + 1)

    max_token_length = max(len(seq) for seq in input_ids)

    def _pad_sequences(sequences, pad_value=0):
        # Builds new lists, so cached sequences shared between queries are not mutated.
        return [seq + [pad_value] * (max_token_length - len(seq)) for seq in sequences]

    return {
        "input_ids": np.array(_pad_sequences(input_ids, pad_value=0)).astype(np.int64),
        "token_type_ids": np.array(_pad_sequences(token_type_ids, pad_value=0)).astype(np.int64),
        "attention_masks": np.array(_pad_sequences(attention_masks, pad_value=0)).astype(np.int64),
        "phoneme_masks": np.array(phoneme_masks).astype(np.float32),
        "char_ids": np.array(char_ids).astype(np.int64),
        "position_ids": np.array(position_ids).astype(np.int64),
    }
def _truncate_texts(window_size: int, texts: List[str], query_ids: List[int]) -> Tuple[List[str], List[int]]:
truncated_texts = []
truncated_query_ids = []
for text, query_id in zip(texts, query_ids):
start = max(0, query_id - window_size // 2)
end = min(len(text), query_id + window_size // 2)
truncated_text = text[start:end]
truncated_texts.append(truncated_text)
truncated_query_id = query_id - start
truncated_query_ids.append(truncated_query_id)
return truncated_texts, truncated_query_ids
def _truncate(
max_len: int, text: str, query_id: int, tokens: List[str], text2token: List[int], token2text: List[Tuple[int]]
):
truncate_len = max_len - 2
if len(tokens) <= truncate_len:
return (text, query_id, tokens, text2token, token2text)
token_position = text2token[query_id]
token_start = token_position - truncate_len // 2
token_end = token_start + truncate_len
font_exceed_dist = -token_start
back_exceed_dist = token_end - len(tokens)
if font_exceed_dist > 0:
token_start += font_exceed_dist
token_end += font_exceed_dist
elif back_exceed_dist > 0:
token_start -= back_exceed_dist
token_end -= back_exceed_dist
start = token2text[token_start][0]
end = token2text[token_end - 1][1]
return (
text[start:end],
query_id - start,
tokens[token_start:token_end],
[i - token_start if i is not None else None for i in text2token[start:end]],
[(s - start, e - start) for s, e in token2text[token_start:token_end]],
)
def get_phoneme_labels(polyphonic_chars: List[List[str]]) -> Tuple[List[str], Dict[str, List[int]]]:
    """Build the phoneme label list and the per-character label indices.

    Args:
        polyphonic_chars: ``(char, phoneme)`` pairs.

    Returns:
        ``(labels, char2phonemes)`` where ``labels`` is the sorted list of distinct
        phonemes and ``char2phonemes[char]`` lists, in input order, the index in
        ``labels`` of each phoneme paired with ``char``.
    """
    labels = sorted({phoneme for _char, phoneme in polyphonic_chars})
    # One dict lookup per pair instead of a linear labels.index() scan.
    label2idx = {label: idx for idx, label in enumerate(labels)}
    char2phonemes: Dict[str, List[int]] = {}
    for char, phoneme in polyphonic_chars:
        char2phonemes.setdefault(char, []).append(label2idx[phoneme])
    return labels, char2phonemes
def get_char_phoneme_labels(polyphonic_chars: List[List[str]]) -> Tuple[List[str], Dict[str, List[int]]]:
    """Build character-qualified phoneme labels and the per-character label indices.

    Unlike plain phoneme labels, each label here is the string ``"{char} {phoneme}"``,
    so the same phoneme under two characters yields two distinct labels.

    Args:
        polyphonic_chars: ``(char, phoneme)`` pairs.

    Returns:
        ``(labels, char2phonemes)`` where ``labels`` is the sorted list of distinct
        ``"{char} {phoneme}"`` strings and ``char2phonemes[char]`` lists, in input
        order, the index in ``labels`` of each pair that starts with ``char``.
    """
    labels = sorted({f"{char} {phoneme}" for char, phoneme in polyphonic_chars})
    # One dict lookup per pair instead of a linear labels.index() scan.
    label2idx = {label: idx for idx, label in enumerate(labels)}
    char2phonemes: Dict[str, List[int]] = {}
    for char, phoneme in polyphonic_chars:
        char2phonemes.setdefault(char, []).append(label2idx[f"{char} {phoneme}"])
    return labels, char2phonemes