From 10e589f99ec663982e6d663a2c7e8b122fc7e478 Mon Sep 17 00:00:00 2001 From: RVC-Boss <129054828+RVC-Boss@users.noreply.github.com> Date: Fri, 2 Aug 2024 15:47:05 +0800 Subject: [PATCH] support gpt-sovits v2 support gpt-sovits v2 --- GPT_SoVITS/text/g2pw/dataset.py | 166 +++++++++++++++++ GPT_SoVITS/text/g2pw/g2pw.py | 145 +++++++++++++++ GPT_SoVITS/text/g2pw/onnx_api.py | 238 +++++++++++++++++++++++++ GPT_SoVITS/text/g2pw/polyphonic.pickle | Bin 0 -> 1498 bytes GPT_SoVITS/text/g2pw/polyphonic.rep | 53 ++++++ GPT_SoVITS/text/g2pw/utils.py | 145 +++++++++++++++ 6 files changed, 747 insertions(+) create mode 100644 GPT_SoVITS/text/g2pw/dataset.py create mode 100644 GPT_SoVITS/text/g2pw/g2pw.py create mode 100644 GPT_SoVITS/text/g2pw/onnx_api.py create mode 100644 GPT_SoVITS/text/g2pw/polyphonic.pickle create mode 100644 GPT_SoVITS/text/g2pw/polyphonic.rep create mode 100644 GPT_SoVITS/text/g2pw/utils.py diff --git a/GPT_SoVITS/text/g2pw/dataset.py b/GPT_SoVITS/text/g2pw/dataset.py new file mode 100644 index 0000000..0fb28da --- /dev/null +++ b/GPT_SoVITS/text/g2pw/dataset.py @@ -0,0 +1,166 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Credits + This code is modified from https://github.com/GitYCC/g2pW +""" +from typing import Dict +from typing import List +from typing import Tuple + +import numpy as np + +from .utils import tokenize_and_map + +ANCHOR_CHAR = '▁' + + +def prepare_onnx_input(tokenizer, + labels: List[str], + char2phonemes: Dict[str, List[int]], + chars: List[str], + texts: List[str], + query_ids: List[int], + use_mask: bool=False, + window_size: int=None, + max_len: int=512) -> Dict[str, np.array]: + if window_size is not None: + truncated_texts, truncated_query_ids = _truncate_texts( + window_size=window_size, texts=texts, query_ids=query_ids) + input_ids = [] + token_type_ids = [] + attention_masks = [] + phoneme_masks = [] + char_ids = [] + position_ids = [] + + for idx in range(len(texts)): + text = (truncated_texts if window_size else texts)[idx].lower() + query_id = (truncated_query_ids if window_size else query_ids)[idx] + + try: + tokens, text2token, token2text = tokenize_and_map( + tokenizer=tokenizer, text=text) + except Exception: + print(f'warning: text "{text}" is invalid') + return {} + + text, query_id, tokens, text2token, token2text = _truncate( + max_len=max_len, + text=text, + query_id=query_id, + tokens=tokens, + text2token=text2token, + token2text=token2text) + + processed_tokens = ['[CLS]'] + tokens + ['[SEP]'] + + input_id = list( + np.array(tokenizer.convert_tokens_to_ids(processed_tokens))) + token_type_id = list(np.zeros((len(processed_tokens), ), dtype=int)) + attention_mask = list(np.ones((len(processed_tokens), ), dtype=int)) + + query_char = text[query_id] + phoneme_mask = [1 if i in char2phonemes[query_char] else 0 for i in range(len(labels))] \ + if use_mask else [1] * len(labels) + char_id = chars.index(query_char) + position_id = text2token[ + query_id] + 1 # [CLS] token locate at first place + + input_ids.append(input_id) + token_type_ids.append(token_type_id) + attention_masks.append(attention_mask) + phoneme_masks.append(phoneme_mask) + char_ids.append(char_id) + position_ids.append(position_id) + + outputs = { + 'input_ids': np.array(input_ids).astype(np.int64), + 'token_type_ids': np.array(token_type_ids).astype(np.int64), + 'attention_masks': np.array(attention_masks).astype(np.int64), + 'phoneme_masks': np.array(phoneme_masks).astype(np.float32), + 'char_ids': np.array(char_ids).astype(np.int64), + 'position_ids': np.array(position_ids).astype(np.int64), + } + return outputs + + +def _truncate_texts(window_size: int, texts: List[str], + query_ids: List[int]) -> Tuple[List[str], List[int]]: + truncated_texts = [] + truncated_query_ids = [] + for text, query_id in zip(texts, query_ids): + start = max(0, query_id - window_size // 2) + end = min(len(text), query_id + window_size // 2) + truncated_text = text[start:end] + truncated_texts.append(truncated_text) + + truncated_query_id = query_id - start + truncated_query_ids.append(truncated_query_id) + return truncated_texts, truncated_query_ids + + +def _truncate(max_len: int, + text: str, + query_id: int, + tokens: List[str], + text2token: List[int], + token2text: List[Tuple[int]]): + truncate_len = max_len - 2 + if len(tokens) <= truncate_len: + return (text, query_id, tokens, text2token, token2text) + + token_position = text2token[query_id] + + token_start = token_position - truncate_len // 2 + token_end = token_start + truncate_len + font_exceed_dist = -token_start + back_exceed_dist = token_end - len(tokens) + if font_exceed_dist > 0: + token_start += font_exceed_dist + token_end += font_exceed_dist + elif back_exceed_dist > 0: + token_start -= back_exceed_dist + token_end -= back_exceed_dist + + start = token2text[token_start][0] + end = token2text[token_end - 1][1] + + return (text[start:end], query_id - start, tokens[token_start:token_end], [ + i - token_start if i is not None else None + for i in text2token[start:end] + ], [(s - start, e - start) for s, e in token2text[token_start:token_end]]) + + +def get_phoneme_labels(polyphonic_chars: List[List[str]] + ) -> Tuple[List[str], Dict[str, List[int]]]: + labels = sorted(list(set([phoneme for char, phoneme in polyphonic_chars]))) + char2phonemes = {} + for char, phoneme in polyphonic_chars: + if char not in char2phonemes: + char2phonemes[char] = [] + char2phonemes[char].append(labels.index(phoneme)) + return labels, char2phonemes + + +def get_char_phoneme_labels(polyphonic_chars: List[List[str]] + ) -> Tuple[List[str], Dict[str, List[int]]]: + labels = sorted( + list(set([f'{char} {phoneme}' for char, phoneme in polyphonic_chars]))) + char2phonemes = {} + for char, phoneme in polyphonic_chars: + if char not in char2phonemes: + char2phonemes[char] = [] + char2phonemes[char].append(labels.index(f'{char} {phoneme}')) + return labels, char2phonemes diff --git a/GPT_SoVITS/text/g2pw/g2pw.py b/GPT_SoVITS/text/g2pw/g2pw.py new file mode 100644 index 0000000..ef6e339 --- /dev/null +++ b/GPT_SoVITS/text/g2pw/g2pw.py @@ -0,0 +1,145 @@ +# This code is modified from https://github.com/mozillazg/pypinyin-g2pW + +import pickle +import os + +from pypinyin.constants import RE_HANS +from pypinyin.core import Pinyin, Style +from pypinyin.seg.simpleseg import simple_seg +from pypinyin.converter import UltimateConverter +from pypinyin.contrib.tone_convert import to_tone +from .onnx_api import G2PWOnnxConverter + +current_file_path = os.path.dirname(__file__) +CACHE_PATH = os.path.join(current_file_path, "polyphonic.pickle") +PP_DICT_PATH = os.path.join(current_file_path, "polyphonic.rep") + + +class G2PWPinyin(Pinyin): + def __init__(self, model_dir='G2PWModel/', model_source=None, + enable_non_tradional_chinese=True, + v_to_u=False, neutral_tone_with_five=False, tone_sandhi=False, **kwargs): + self._g2pw = G2PWOnnxConverter( + model_dir=model_dir, + style='pinyin', + model_source=model_source, + enable_non_tradional_chinese=enable_non_tradional_chinese, + ) + self._converter = Converter( + self._g2pw, v_to_u=v_to_u, + neutral_tone_with_five=neutral_tone_with_five, + tone_sandhi=tone_sandhi, + ) + + def get_seg(self, **kwargs): + return simple_seg + + +class Converter(UltimateConverter): + def __init__(self, g2pw_instance, v_to_u=False, + neutral_tone_with_five=False, + tone_sandhi=False, **kwargs): + super(Converter, self).__init__( + v_to_u=v_to_u, + neutral_tone_with_five=neutral_tone_with_five, + tone_sandhi=tone_sandhi, **kwargs) + + self._g2pw = g2pw_instance + + def convert(self, words, style, heteronym, errors, strict, **kwargs): + pys = [] + if RE_HANS.match(words): + pys = self._to_pinyin(words, style=style, heteronym=heteronym, + errors=errors, strict=strict) + post_data = self.post_pinyin(words, heteronym, pys) + if post_data is not None: + pys = post_data + + pys = self.convert_styles( + pys, words, style, heteronym, errors, strict) + + else: + py = self.handle_nopinyin(words, style=style, errors=errors, + heteronym=heteronym, strict=strict) + if py: + pys.extend(py) + + return _remove_dup_and_empty(pys) + + def _to_pinyin(self, han, style, heteronym, errors, strict, **kwargs): + pinyins = [] + + if han in pp_dict: + phns = pp_dict[han] + for ph in phns: + pinyins.append([ph]) + return pinyins + + g2pw_pinyin = self._g2pw(han) + + if not g2pw_pinyin: # g2pw 不支持的汉字改为使用 pypinyin 原有逻辑 + return super(Converter, self).convert( + han, Style.TONE, heteronym, errors, strict, **kwargs) + + for i, item in enumerate(g2pw_pinyin[0]): + if item is None: # g2pw 不支持的汉字改为使用 pypinyin 原有逻辑 + py = super(Converter, self).convert( + han[i], Style.TONE, heteronym, errors, strict, **kwargs) + pinyins.extend(py) + else: + pinyins.append([to_tone(item)]) + + return pinyins + + +def _remove_dup_items(lst, remove_empty=False): + new_lst = [] + for item in lst: + if remove_empty and not item: + continue + if item not in new_lst: + new_lst.append(item) + return new_lst + + +def _remove_dup_and_empty(lst_list): + new_lst_list = [] + for lst in lst_list: + lst = _remove_dup_items(lst, remove_empty=True) + if lst: + new_lst_list.append(lst) + else: + new_lst_list.append(['']) + + return new_lst_list + + +def cache_dict(polyphonic_dict, file_path): + with open(file_path, "wb") as pickle_file: + pickle.dump(polyphonic_dict, pickle_file) + + +def get_dict(): + if os.path.exists(CACHE_PATH): + with open(CACHE_PATH, "rb") as pickle_file: + polyphonic_dict = pickle.load(pickle_file) + else: + polyphonic_dict = read_dict() + cache_dict(polyphonic_dict, CACHE_PATH) + + return polyphonic_dict + + +def read_dict(): + polyphonic_dict = {} + with open(PP_DICT_PATH) as f: + line = f.readline() + while line: + key, value_str = line.split(':') + value = eval(value_str.strip()) + polyphonic_dict[key.strip()] = value + line = f.readline() + return polyphonic_dict + + +pp_dict = get_dict() \ No newline at end of file diff --git a/GPT_SoVITS/text/g2pw/onnx_api.py b/GPT_SoVITS/text/g2pw/onnx_api.py new file mode 100644 index 0000000..ff903cc --- /dev/null +++ b/GPT_SoVITS/text/g2pw/onnx_api.py @@ -0,0 +1,238 @@ +# This code is modified from https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/paddlespeech/t2s/frontend/g2pw +# This code is modified from https://github.com/GitYCC/g2pW + +import json +import os +import zipfile,requests +from typing import Any +from typing import Dict +from typing import List +from typing import Tuple + +import numpy as np +import onnxruntime +from opencc import OpenCC +from transformers import AutoTokenizer +from pypinyin import pinyin +from pypinyin import Style + +from .dataset import get_char_phoneme_labels +from .dataset import get_phoneme_labels +from .dataset import prepare_onnx_input +from .utils import load_config +from ..zh_normalization.char_convert import tranditional_to_simplified + +model_version = '1.1' + + +def predict(session, onnx_input: Dict[str, Any], + labels: List[str]) -> Tuple[List[str], List[float]]: + all_preds = [] + all_confidences = [] + probs = session.run([], { + "input_ids": onnx_input['input_ids'], + "token_type_ids": onnx_input['token_type_ids'], + "attention_mask": onnx_input['attention_masks'], + "phoneme_mask": onnx_input['phoneme_masks'], + "char_ids": onnx_input['char_ids'], + "position_ids": onnx_input['position_ids'] + })[0] + + preds = np.argmax(probs, axis=1).tolist() + max_probs = [] + for index, arr in zip(preds, probs.tolist()): + max_probs.append(arr[index]) + all_preds += [labels[pred] for pred in preds] + all_confidences += max_probs + + return all_preds, all_confidences + + +def download_and_decompress(model_dir: str='G2PWModel/'): + if not os.path.exists(model_dir): + parent_directory = os.path.dirname(model_dir) + zip_dir = os.path.join(parent_directory,"G2PWModel_1.1.zip") + extract_dir = os.path.join(parent_directory,"G2PWModel_1.1") + extract_dir_new = os.path.join(parent_directory,"G2PWModel") + print("Downloading g2pw model...") + modelscope_url = "https://paddlespeech.bj.bcebos.com/Parakeet/released_models/g2p/G2PWModel_1.1.zip" + with requests.get(modelscope_url, stream=True) as r: + r.raise_for_status() + with open(zip_dir, 'wb') as f: + for chunk in r.iter_content(chunk_size=8192): + if chunk: + f.write(chunk) + + print("Extracting g2pw model...") + with zipfile.ZipFile(zip_dir, "r") as zip_ref: + zip_ref.extractall(parent_directory) + + os.rename(extract_dir, extract_dir_new) + + return model_dir + +class G2PWOnnxConverter: + def __init__(self, + model_dir: str='G2PWModel/', + style: str='bopomofo', + model_source: str=None, + enable_non_tradional_chinese: bool=False): + uncompress_path = download_and_decompress(model_dir) + + sess_options = onnxruntime.SessionOptions() + sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL + sess_options.execution_mode = onnxruntime.ExecutionMode.ORT_SEQUENTIAL + sess_options.intra_op_num_threads = 2 + self.session_g2pW = onnxruntime.InferenceSession( + os.path.join(uncompress_path, 'g2pW.onnx'), + sess_options=sess_options, providers=['CUDAExecutionProvider', 'CPUExecutionProvider']) + # sess_options=sess_options) + self.config = load_config( + config_path=os.path.join(uncompress_path, 'config.py'), + use_default=True) + + self.model_source = model_source if model_source else self.config.model_source + self.enable_opencc = enable_non_tradional_chinese + + self.tokenizer = AutoTokenizer.from_pretrained(self.model_source) + + polyphonic_chars_path = os.path.join(uncompress_path, + 'POLYPHONIC_CHARS.txt') + monophonic_chars_path = os.path.join(uncompress_path, + 'MONOPHONIC_CHARS.txt') + self.polyphonic_chars = [ + line.split('\t') + for line in open(polyphonic_chars_path, encoding='utf-8').read() + .strip().split('\n') + ] + self.non_polyphonic = { + '一', '不', '和', '咋', '嗲', '剖', '差', '攢', '倒', '難', '奔', '勁', '拗', + '肖', '瘙', '誒', '泊', '听', '噢' + } + self.non_monophonic = {'似', '攢'} + self.monophonic_chars = [ + line.split('\t') + for line in open(monophonic_chars_path, encoding='utf-8').read() + .strip().split('\n') + ] + self.labels, self.char2phonemes = get_char_phoneme_labels( + polyphonic_chars=self.polyphonic_chars + ) if self.config.use_char_phoneme else get_phoneme_labels( + polyphonic_chars=self.polyphonic_chars) + + self.chars = sorted(list(self.char2phonemes.keys())) + + self.polyphonic_chars_new = set(self.chars) + for char in self.non_polyphonic: + if char in self.polyphonic_chars_new: + self.polyphonic_chars_new.remove(char) + + self.monophonic_chars_dict = { + char: phoneme + for char, phoneme in self.monophonic_chars + } + for char in self.non_monophonic: + if char in self.monophonic_chars_dict: + self.monophonic_chars_dict.pop(char) + + self.pos_tags = [ + 'UNK', 'A', 'C', 'D', 'I', 'N', 'P', 'T', 'V', 'DE', 'SHI' + ] + + with open( + os.path.join(uncompress_path, + 'bopomofo_to_pinyin_wo_tune_dict.json'), + 'r', + encoding='utf-8') as fr: + self.bopomofo_convert_dict = json.load(fr) + self.style_convert_func = { + 'bopomofo': lambda x: x, + 'pinyin': self._convert_bopomofo_to_pinyin, + }[style] + + with open( + os.path.join(uncompress_path, 'char_bopomofo_dict.json'), + 'r', + encoding='utf-8') as fr: + self.char_bopomofo_dict = json.load(fr) + + if self.enable_opencc: + self.cc = OpenCC('s2tw') + + def _convert_bopomofo_to_pinyin(self, bopomofo: str) -> str: + tone = bopomofo[-1] + assert tone in '12345' + component = self.bopomofo_convert_dict.get(bopomofo[:-1]) + if component: + return component + tone + else: + print(f'Warning: "{bopomofo}" cannot convert to pinyin') + return None + + def __call__(self, sentences: List[str]) -> List[List[str]]: + if isinstance(sentences, str): + sentences = [sentences] + + if self.enable_opencc: + translated_sentences = [] + for sent in sentences: + translated_sent = self.cc.convert(sent) + assert len(translated_sent) == len(sent) + translated_sentences.append(translated_sent) + sentences = translated_sentences + + texts, query_ids, sent_ids, partial_results = self._prepare_data( + sentences=sentences) + if len(texts) == 0: + # sentences no polyphonic words + return partial_results + + onnx_input = prepare_onnx_input( + tokenizer=self.tokenizer, + labels=self.labels, + char2phonemes=self.char2phonemes, + chars=self.chars, + texts=texts, + query_ids=query_ids, + use_mask=self.config.use_mask, + window_size=None) + + preds, confidences = predict( + session=self.session_g2pW, + onnx_input=onnx_input, + labels=self.labels) + if self.config.use_char_phoneme: + preds = [pred.split(' ')[1] for pred in preds] + + results = partial_results + for sent_id, query_id, pred in zip(sent_ids, query_ids, preds): + results[sent_id][query_id] = self.style_convert_func(pred) + + return results + + def _prepare_data( + self, sentences: List[str] + ) -> Tuple[List[str], List[int], List[int], List[List[str]]]: + texts, query_ids, sent_ids, partial_results = [], [], [], [] + for sent_id, sent in enumerate(sentences): + # pypinyin works well for Simplified Chinese than Traditional Chinese + sent_s = tranditional_to_simplified(sent) + pypinyin_result = pinyin( + sent_s, neutral_tone_with_five=True, style=Style.TONE3) + partial_result = [None] * len(sent) + for i, char in enumerate(sent): + if char in self.polyphonic_chars_new: + texts.append(sent) + query_ids.append(i) + sent_ids.append(sent_id) + elif char in self.monophonic_chars_dict: + partial_result[i] = self.style_convert_func( + self.monophonic_chars_dict[char]) + elif char in self.char_bopomofo_dict: + partial_result[i] = pypinyin_result[i][0] + # partial_result[i] = self.style_convert_func(self.char_bopomofo_dict[char][0]) + else: + partial_result[i] = pypinyin_result[i][0] + + partial_results.append(partial_result) + return texts, query_ids, sent_ids, partial_results diff --git a/GPT_SoVITS/text/g2pw/polyphonic.pickle b/GPT_SoVITS/text/g2pw/polyphonic.pickle new file mode 100644 index 0000000000000000000000000000000000000000..98cba72f643b42e048b7ff32f3a551ab8d6566e5 GIT binary patch literal 1498 zcmX|BT~8ZV5UsLzEg|K-4}I-Jq)3(W`B15Up#Pz&Y`pAUAU27;O0ZOk2pMem0)Z?a zah5^?YQ6*}5P_Y9@Dutce6PL#qGx98SeCrLXYS0IGiP=|ZT{!~z@NX88z$#^Tkfot z{DEgr$F~#X&6RH@fyvqQ#2%apP>k*-#(x;)%XbNKduwigogr1oM|TpVDmoVE<*uza z#}L9zE{yY~P*Q&F+15(|tN9WYn?+_ScC>A~TT-k>#Z)y41X(+<&lcs*OgW}nlRA=g zrtX@n-)r4Mx>{b4=kM*a4M|rDP)F6YAj8flm##44gELw+-W1vfhhb+4G$dU`N4FFG zBTlpTZ|&hHNfmPWYB|2goIAb2n-hsI#;nk!4oYLz{cqf<{>k`ZxfF^zt)A^lGu{Ib zGG0Ms%hJ=~qLC`II5MgbjFd?mw*G|?e;P{CAm$~$xTP2QBSyy4cl0@Cy4}9D2IATGc{iDi0WK-odn-Od&%u|B&q-_ z=|7E{jsj={Pq75oho3*V?fq=?YS4VgftXxSBK01VVRO^~+M>Ura`W$OSsF)VnN2;5 z)a0}JnO%7$NywK}2DWC*EiAaT1Hp8hhXE)x9@tXOtlOq^<5#f=RK6-B!lXw=(D&&| zFtU69GPS$1TB^oY^QXX+&C~B5StDc3)inG;`Q}HyZeQ5;g%Ekff*PNG#dVuYcJ*Hd za}{cwTEK&h{P4{f^olsrE$X?^8(zu{3Z!^&ls(yL57tLBEb?qB0atP9Zg!dq1;|^5 z7BCMvjE28+%Wi$-i(beZFdIq5 z?Q`2$B0%3XYKrCy?!&gRd;7wp0A112uQ<-$b;et)p2XM*eFKKFy{@gV^gpj=?Qe30 z{7wt=fT=@Z@`5Un-;u>VJs&T~H8ih(4~%CMdU+WykJ084(pr;*Z^ZPJa3Xuruy4OF9TqE~cbs}0V6Mg&mlFHMBlkj7S0VcXljP 0: + match_space = re.match(r'^ +', text) + if match_space: + space_str = match_space.group(0) + index_map_from_text_to_word += [None] * len(space_str) + text = text[len(space_str):] + continue + + match_en = re.match(r'^[a-zA-Z0-9]+', text) + if match_en: + en_word = match_en.group(0) + + word_start_pos = len(index_map_from_text_to_word) + word_end_pos = word_start_pos + len(en_word) + index_map_from_word_to_text.append((word_start_pos, word_end_pos)) + + index_map_from_text_to_word += [len(words)] * len(en_word) + + words.append(en_word) + text = text[len(en_word):] + else: + word_start_pos = len(index_map_from_text_to_word) + word_end_pos = word_start_pos + 1 + index_map_from_word_to_text.append((word_start_pos, word_end_pos)) + + index_map_from_text_to_word += [len(words)] + + words.append(text[0]) + text = text[1:] + return words, index_map_from_text_to_word, index_map_from_word_to_text + + +def tokenize_and_map(tokenizer, text: str): + words, text2word, word2text = wordize_and_map(text=text) + + tokens = [] + index_map_from_token_to_text = [] + for word, (word_start, word_end) in zip(words, word2text): + word_tokens = tokenizer.tokenize(word) + + if len(word_tokens) == 0 or word_tokens == ['[UNK]']: + index_map_from_token_to_text.append((word_start, word_end)) + tokens.append('[UNK]') + else: + current_word_start = word_start + for word_token in word_tokens: + word_token_len = len(re.sub(r'^##', '', word_token)) + index_map_from_token_to_text.append( + (current_word_start, current_word_start + word_token_len)) + current_word_start = current_word_start + word_token_len + tokens.append(word_token) + + index_map_from_text_to_token = text2word + for i, (token_start, token_end) in enumerate(index_map_from_token_to_text): + for token_pos in range(token_start, token_end): + index_map_from_text_to_token[token_pos] = i + + return tokens, index_map_from_text_to_token, index_map_from_token_to_text + + +def _load_config(config_path: os.PathLike): + import importlib.util + spec = importlib.util.spec_from_file_location('__init__', config_path) + config = importlib.util.module_from_spec(spec) + spec.loader.exec_module(config) + return config + + +default_config_dict = { + 'manual_seed': 1313, + 'model_source': 'bert-base-chinese', + 'window_size': 32, + 'num_workers': 2, + 'use_mask': True, + 'use_char_phoneme': False, + 'use_conditional': True, + 'param_conditional': { + 'affect_location': 'softmax', + 'bias': True, + 'char-linear': True, + 'pos-linear': False, + 'char+pos-second': True, + 'char+pos-second_lowrank': False, + 'lowrank_size': 0, + 'char+pos-second_fm': False, + 'fm_size': 0, + 'fix_mode': None, + 'count_json': 'train.count.json' + }, + 'lr': 5e-5, + 'val_interval': 200, + 'num_iter': 10000, + 'use_focal': False, + 'param_focal': { + 'alpha': 0.0, + 'gamma': 0.7 + }, + 'use_pos': True, + 'param_pos ': { + 'weight': 0.1, + 'pos_joint_training': True, + 'train_pos_path': 'train.pos', + 'valid_pos_path': 'dev.pos', + 'test_pos_path': 'test.pos' + } +} + + +def load_config(config_path: os.PathLike, use_default: bool=False): + config = _load_config(config_path) + if use_default: + for attr, val in default_config_dict.items(): + if not hasattr(config, attr): + setattr(config, attr, val) + elif isinstance(val, dict): + d = getattr(config, attr) + for dict_k, dict_v in val.items(): + if dict_k not in d: + d[dict_k] = dict_v + return config