# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Credits
    This code is modified from https://github.com/GitYCC/g2pW
"""

from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple

import numpy as np

from .utils import tokenize_and_map

ANCHOR_CHAR = "▁"


def _encode_tokens(tokenizer, tokens: List[str]) -> Tuple[List[int], List[int], List[int]]:
    """Wrap *tokens* in ``[CLS]``/``[SEP]`` and build the three per-token model inputs.

    Returns:
        ``(input_id, token_type_id, attention_mask)``; the token type ids are
        all 0 and the attention mask is all 1, each as long as the wrapped
        token sequence.
    """
    processed_tokens = ["[CLS]"] + tokens + ["[SEP]"]
    input_id = list(tokenizer.convert_tokens_to_ids(processed_tokens))
    token_type_id = [0] * len(processed_tokens)
    attention_mask = [1] * len(processed_tokens)
    return input_id, token_type_id, attention_mask


def prepare_onnx_input(
    tokenizer,
    labels: List[str],
    char2phonemes: Dict[str, List[int]],
    chars: List[str],
    texts: List[str],
    query_ids: List[int],
    use_mask: bool = False,
    window_size: Optional[int] = None,
    max_len: int = 512,
    char2id: Optional[Dict[str, int]] = None,
    char_phoneme_masks: Optional[Dict[str, List[int]]] = None,
) -> Dict[str, np.ndarray]:
    """Build the batched numpy inputs for one query character per text.

    Args:
        tokenizer: Object passed to ``tokenize_and_map`` and providing
            ``convert_tokens_to_ids`` (presumably a BERT-style tokenizer --
            verify against the caller).
        labels: All phoneme labels; its length is the width of each phoneme mask.
        char2phonemes: Maps a character to the indices (into ``labels``) it may take.
        chars: Characters used to build ``char2id`` when it is not supplied.
        texts: Input sentences; each is lower-cased before tokenization.
        query_ids: For each text, the character index of the query character.
        use_mask: If true, restrict each query to the labels allowed for its
            character; otherwise every label is allowed.
        window_size: If set (and non-zero), each text is first cut to a
            character window around its query position.
        max_len: Maximum number of tokens per sequence, including the
            ``[CLS]`` and ``[SEP]`` tokens.
        char2id: Optional precomputed ``char -> index`` mapping.
        char_phoneme_masks: Optional precomputed ``char -> 0/1 mask`` mapping,
            only consulted when ``use_mask`` is true.

    Returns:
        A dict with keys ``input_ids``, ``token_type_ids``, ``attention_masks``
        (int64, zero-padded to the longest sequence in the batch),
        ``phoneme_masks`` (float32), ``char_ids`` and ``position_ids`` (int64).
        An empty dict is returned if tokenizing any text fails.
    """
    # Truthiness (not ``is not None``) decides whether the window is applied, so
    # ``window_size=0`` means "no windowing" consistently.
    if window_size:
        source_texts, source_query_ids = _truncate_texts(
            window_size=window_size, texts=texts, query_ids=query_ids
        )
    else:
        source_texts, source_query_ids = texts, query_ids

    if char2id is None:
        char2id = {char: idx for idx, char in enumerate(chars)}

    full_phoneme_mask = None
    if use_mask:
        if char_phoneme_masks is None:
            num_labels = len(labels)
            char_phoneme_masks = {}
            for char, phoneme_ids in char2phonemes.items():
                allowed = set(phoneme_ids)  # O(1) membership per label
                char_phoneme_masks[char] = [1 if i in allowed else 0 for i in range(num_labels)]
    else:
        full_phoneme_mask = [1] * len(labels)

    input_ids = []
    token_type_ids = []
    attention_masks = []
    phoneme_masks = []
    char_ids = []
    position_ids = []

    # text -> (tokens, text2token, token2text, encoded); ``encoded`` is None when
    # the text is too long and must be re-truncated around every query position.
    tokenized_cache = {}

    for idx in range(len(texts)):
        text = source_texts[idx].lower()
        query_id = source_query_ids[idx]

        cached = tokenized_cache.get(text)
        if cached is None:
            try:
                tokens, text2token, token2text = tokenize_and_map(tokenizer=tokenizer, text=text)
            except Exception:
                # Deliberate best-effort contract: one untokenizable text
                # aborts the whole batch with an empty result.
                print(f'warning: text "{text}" is invalid')
                return {}

            fits = len(tokens) <= max_len - 2  # leave room for [CLS] and [SEP]
            encoded = _encode_tokens(tokenizer, tokens) if fits else None
            cached = (tokens, text2token, token2text, encoded)
            tokenized_cache[text] = cached

        tokens, text2token, token2text, encoded = cached
        if encoded is not None:
            text_for_query = text
            query_id_for_query = query_id
            text2token_for_query = text2token
        else:
            (
                text_for_query,
                query_id_for_query,
                tokens_for_query,
                text2token_for_query,
                _token2text_for_query,
            ) = _truncate(
                max_len=max_len,
                text=text,
                query_id=query_id,
                tokens=tokens,
                text2token=text2token,
                token2text=token2text,
            )
            encoded = _encode_tokens(tokenizer, tokens_for_query)
        input_id, token_type_id, attention_mask = encoded

        query_char = text_for_query[query_id_for_query]
        phoneme_mask = char_phoneme_masks[query_char] if use_mask else full_phoneme_mask
        char_id = char2id[query_char]
        # +1 because the [CLS] token occupies the first position.
        position_id = text2token_for_query[query_id_for_query] + 1

        input_ids.append(input_id)
        token_type_ids.append(token_type_id)
        attention_masks.append(attention_mask)
        phoneme_masks.append(phoneme_mask)
        char_ids.append(char_id)
        position_ids.append(position_id)

    # ``default=0`` keeps an empty batch from raising ValueError.
    max_token_length = max((len(seq) for seq in input_ids), default=0)

    def _pad_sequences(sequences, pad_value=0):
        # Builds new lists, so cached (shared) sequences are never mutated.
        return [seq + [pad_value] * (max_token_length - len(seq)) for seq in sequences]

    outputs = {
        "input_ids": np.array(_pad_sequences(input_ids, pad_value=0)).astype(np.int64),
        "token_type_ids": np.array(_pad_sequences(token_type_ids, pad_value=0)).astype(np.int64),
        "attention_masks": np.array(_pad_sequences(attention_masks, pad_value=0)).astype(np.int64),
        "phoneme_masks": np.array(phoneme_masks).astype(np.float32),
        "char_ids": np.array(char_ids).astype(np.int64),
        "position_ids": np.array(position_ids).astype(np.int64),
    }
    return outputs


def _truncate_texts(window_size: int, texts: List[str], query_ids: List[int]) -> Tuple[List[str], List[int]]:
    """Cut each text to a character window around its query position.

    The window spans ``window_size // 2`` characters before the query position
    and up to ``window_size // 2`` positions from it (exclusive end), clipped to
    the text. The query character itself is always kept, so ``window_size=1``
    yields just that character instead of an empty string.

    Returns:
        ``(truncated_texts, truncated_query_ids)`` where each query id is
        re-based to index into the corresponding truncated text.
    """
    half_window = window_size // 2
    truncated_texts = []
    truncated_query_ids = []
    for text, query_id in zip(texts, query_ids):
        start = max(0, query_id - half_window)
        # The end is exclusive; never let it stop before the query character.
        end = min(len(text), max(query_id + half_window, query_id + 1))
        truncated_texts.append(text[start:end])
        truncated_query_ids.append(query_id - start)
    return truncated_texts, truncated_query_ids


def _truncate(
    max_len: int,
    text: str,
    query_id: int,
    tokens: List[str],
    text2token: List[Optional[int]],
    token2text: List[Tuple[int, int]],
) -> Tuple[str, int, List[str], List[Optional[int]], List[Tuple[int, int]]]:
    """Cut a tokenized text to at most ``max_len - 2`` tokens around the query.

    Args:
        max_len: Token budget including two special tokens added by the caller.
        text: The text that was tokenized.
        query_id: Character index of the query within ``text``.
        tokens: Tokens of ``text``.
        text2token: Per-character token index (``None`` entries are passed through).
        token2text: Per-token ``(start, end)`` character span in ``text``.

    Returns:
        ``(text, query_id, tokens, text2token, token2text)`` restricted to the
        selected token window, with all indices re-based to it. Inputs are
        returned unchanged when they already fit.
    """
    truncate_len = max_len - 2
    if len(tokens) <= truncate_len:
        return (text, query_id, tokens, text2token, token2text)

    # Start with a window roughly centred on the query token ...
    token_position = text2token[query_id]
    token_start = token_position - truncate_len // 2
    token_end = token_start + truncate_len

    # ... then slide it back inside [0, len(tokens)] if it sticks out.
    front_exceed_dist = -token_start
    back_exceed_dist = token_end - len(tokens)
    if front_exceed_dist > 0:
        token_start += front_exceed_dist
        token_end += front_exceed_dist
    elif back_exceed_dist > 0:
        token_start -= back_exceed_dist
        token_end -= back_exceed_dist

    # Character span covered by the kept tokens.
    start = token2text[token_start][0]
    end = token2text[token_end - 1][1]

    return (
        text[start:end],
        query_id - start,
        tokens[token_start:token_end],
        [i - token_start if i is not None else None for i in text2token[start:end]],
        [(s - start, e - start) for s, e in token2text[token_start:token_end]],
    )


def get_phoneme_labels(polyphonic_chars: List[List[str]]) -> Tuple[List[str], Dict[str, List[int]]]:
    """Collect the sorted unique phonemes and each character's phoneme indices.

    Args:
        polyphonic_chars: ``(char, phoneme)`` pairs.

    Returns:
        ``(labels, char2phonemes)`` where ``labels`` is the sorted list of
        distinct phonemes and ``char2phonemes[char]`` lists, in input order,
        the index into ``labels`` of every phoneme paired with ``char``.
    """
    labels = sorted({phoneme for _char, phoneme in polyphonic_chars})
    # One dict lookup per pair instead of a linear ``labels.index`` scan.
    label2idx = {label: idx for idx, label in enumerate(labels)}
    char2phonemes = {}
    for char, phoneme in polyphonic_chars:
        char2phonemes.setdefault(char, []).append(label2idx[phoneme])
    return labels, char2phonemes


def get_char_phoneme_labels(polyphonic_chars: List[List[str]]) -> Tuple[List[str], Dict[str, List[int]]]:
    """Like ``get_phoneme_labels`` but with one label per ``"char phoneme"`` pair.

    Args:
        polyphonic_chars: ``(char, phoneme)`` pairs.

    Returns:
        ``(labels, char2phonemes)`` where ``labels`` is the sorted list of
        distinct ``f"{char} {phoneme}"`` strings and ``char2phonemes[char]``
        lists, in input order, the index into ``labels`` of each of its pairs.
    """
    labels = sorted({f"{char} {phoneme}" for char, phoneme in polyphonic_chars})
    # One dict lookup per pair instead of a linear ``labels.index`` scan.
    label2idx = {label: idx for idx, label in enumerate(labels)}
    char2phonemes = {}
    for char, phoneme in polyphonic_chars:
        char2phonemes.setdefault(char, []).append(label2idx[f"{char} {phoneme}"])
    return labels, char2phonemes