Mirror of https://github.com/RVC-Boss/GPT-SoVITS.git (synced 2025-10-06 14:40:00 +08:00)

Merge 27664703d2fb3c86504d8168ae79639b784c56f7 into b7a904a67153170d334fdc0d7fbae220ee21f0e9
This commit is contained in: commit d174fefdc6
GPT_SoVITS/TTS_infer_pack/TTS.py

@@ -550,6 +550,7 @@ class TTS:
             all_phones_len_list = []
             all_bert_features_list = []
             norm_text_batch = []
+            origin_text_batch = []
             all_bert_max_len = 0
             all_phones_max_len = 0
             for item in item_list:
@@ -575,6 +576,7 @@ class TTS:
                 all_phones_len_list.append(all_phones.shape[-1])
                 all_bert_features_list.append(all_bert_features)
                 norm_text_batch.append(item["norm_text"])
+                origin_text_batch.append(item["origin_text"])

             phones_batch = phones_list
             all_phones_batch = all_phones_list
@@ -606,6 +608,7 @@ class TTS:
                 "all_phones_len": torch.LongTensor(all_phones_len_list).to(device),
                 "all_bert_features": all_bert_features_batch,
                 "norm_text": norm_text_batch,
+                "origin_text": origin_text_batch,
                 "max_len": max_len,
             }
             _data.append(batch)
@@ -658,6 +661,7 @@ class TTS:
             "batch_threshold": 0.75, # float. threshold for batch splitting.
             "split_bucket: True, # bool. whether to split the batch into multiple buckets.
             "return_fragment": False, # bool. step by step return the audio fragment.
+            "return_with_srt": "", # str. return with or without("") subtitles, using "orig"inal or "norm"alized text
             "speed_factor":1.0, # float. control the speed of the synthesized audio.
             "fragment_interval":0.3, # float. to control the interval of the audio fragment.
             "seed": -1, # int. random seed for reproducibility.
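The new "return_with_srt" key is the pipeline-level switch: an empty string keeps the old audio-only behavior, while a value starting with "orig" or "norm" selects which text (original or normalized) goes into the subtitles. A minimal sketch of driving TTS.run with it; `tts_pipeline` and all paths/texts here are placeholders, not values from this commit:

    # Minimal sketch: request subtitles from TTS.run() (inputs are invented).
    inputs = {
        "text": "Hello world.",
        "text_lang": "en",
        "ref_audio_path": "ref.wav",        # hypothetical path
        "prompt_text": "reference prompt",  # hypothetical
        "prompt_lang": "en",
        "return_with_srt": "orig",          # "" (off), "orig"..., or "norm"...
    }
    for result in tts_pipeline.run(inputs):
        if len(result) == 3:                # subtitles requested
            sr, audio, srt = result
            for start, end, text in srt:
                print(f"{start:.2f}-{end:.2f}s: {text}")
        else:                               # legacy 2-tuple
            sr, audio = result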
@@ -685,6 +689,7 @@ class TTS:
         split_bucket = inputs.get("split_bucket", True)
         return_fragment = inputs.get("return_fragment", False)
         fragment_interval = inputs.get("fragment_interval", 0.3)
+        return_with_srt = inputs.get("return_with_srt", "")
         seed = inputs.get("seed", -1)
         seed = -1 if seed in ["", None] else seed
         actual_seed = set_seed(seed)
@@ -704,6 +709,9 @@ class TTS:
             split_bucket = False
             print(i18n("分段返回模式不支持分桶处理,已自动关闭分桶处理"))

+        ret_width = 3 if return_with_srt else 2 # return (sr, audio, srt) or (sr, audio)
+        srt_text = "norm_text" if return_with_srt.startswith("norm") else "origin_text"
+
         if split_bucket and speed_factor==1.0:
             print(i18n("分桶处理模式已开启"))
         elif speed_factor!=1.0:
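(The i18n message above says fragment-return mode does not support bucket splitting, so bucketing is disabled automatically.) `ret_width` and the `[:ret_width]` slices introduced below are what preserve the old two-tuple contract: every yield site builds the full (sr, audio, srt) triple and truncates it when subtitles were not requested. The trick in isolation:

    # Slicing a tuple to choose the return arity, as the patch does with ret_width.
    def make_result(sr, audio, srt, with_srt):
        ret_width = 3 if with_srt else 2
        return (sr, audio, srt)[:ret_width]

    print(make_result(32000, b"...", [], with_srt=False))  # (32000, b'...')
    print(make_result(32000, b"...", [], with_srt=True))   # (32000, b'...', [])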
@@ -773,8 +781,7 @@ class TTS:
             if not return_fragment:
                 data = self.text_preprocessor.preprocess(text, text_lang, text_split_method, self.configs.version)
                 if len(data) == 0:
-                    yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate),
-                                                               dtype=np.int16)
+                    yield self.audio_failure()[:ret_width]
                     return

             batch_index_list:list = None
@@ -806,6 +813,7 @@ class TTS:
                         "phones": phones,
                         "bert_features": bert_features,
                         "norm_text": norm_text,
+                        "origin_text": text,
                     }
                     batch_data.append(res)
                 if len(batch_data) == 0:
@@ -841,10 +849,11 @@ class TTS:
                 all_phoneme_ids:torch.LongTensor = item["all_phones"]
                 all_phoneme_lens:torch.LongTensor = item["all_phones_len"]
                 all_bert_features:torch.LongTensor = item["all_bert_features"]
-                norm_text:str = item["norm_text"]
+                # norm_text:List[str] = item["norm_text"]
+                # origin_text:List[str] = item["origin_text"]
                 max_len = item["max_len"]

-                print(i18n("前端处理后的文本(每句):"), norm_text)
+                print(i18n("前端处理后的文本(每批):"), item["norm_text"])
                 if no_prompt_text :
                     prompt = None
                 else:
@@ -915,39 +924,38 @@ class TTS:
                 if return_fragment:
                     print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t4 - t3, t5 - t4))
                     yield self.audio_postprocess([batch_audio_fragment],
+                                                 [item[srt_text]],
                                                  self.configs.sampling_rate,
                                                  None,
                                                  speed_factor,
                                                  False,
                                                  fragment_interval
-                                                 )
+                                                 )[:ret_width]
                 else:
                     audio.append(batch_audio_fragment)

                 if self.stop_flag:
-                    yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate),
-                                                               dtype=np.int16)
+                    yield self.audio_failure()[:ret_width]
                     return

             if not return_fragment:
                 print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t_34, t_45))
                 if len(audio) == 0:
-                    yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate),
-                                                               dtype=np.int16)
+                    yield self.audio_failure()[:ret_width]
                     return
                 yield self.audio_postprocess(audio,
+                                             [v[srt_text] for v in data],
                                              self.configs.sampling_rate,
                                              batch_index_list,
                                              speed_factor,
                                              split_bucket,
                                              fragment_interval
-                                             )
+                                             )[:ret_width]

         except Exception as e:
             traceback.print_exc()
             # 必须返回一个空音频, 否则会导致显存不释放。
-            yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate),
-                                                       dtype=np.int16)
+            yield self.audio_failure()[:ret_width]
             # 重置模型, 否则会导致显存释放不完全。
             del self.t2s_model
             del self.vits_model
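(The two retained Chinese comments note that an empty audio must be yielded on failure, otherwise GPU memory is not released, and that the models must be reset afterwards, otherwise the memory is not freed completely.) In fragment mode each yielded chunk now carries its own subtitle triples, with timestamps relative to that chunk; a consumer sketch under the same placeholder assumptions as the earlier example:

    # Hypothetical consumer of fragment mode with subtitles; `inputs` and
    # `tts_pipeline` are the placeholders from the earlier sketch.
    fragments, subtitles = [], []
    for sr, audio, srt in tts_pipeline.run({**inputs, "return_fragment": True,
                                            "return_with_srt": "orig"}):
        fragments.append(audio)   # one audio chunk per processed batch
        subtitles.extend(srt)     # (start, end, text), relative to this chunk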
@@ -968,15 +976,19 @@ class TTS:
                 torch.mps.empty_cache()
             except:
                 pass

+    def audio_failure(self):
+        return self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate), dtype=np.int16), []
+
     def audio_postprocess(self,
                           audio:List[torch.Tensor],
+                          texts:List[List[str]],
                           sr:int,
                           batch_index_list:list=None,
                           speed_factor:float=1.0,
                           split_bucket:bool=True,
                           fragment_interval:float=0.3
-                          )->Tuple[int, np.ndarray]:
+                          )->Tuple[int, np.ndarray, List]:
         zero_wav = torch.zeros(
                         int(self.configs.sampling_rate * fragment_interval),
                         dtype=self.precision,
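`audio_failure` centralizes the one-second-of-silence fallback and always returns the full triple, so every error path can reuse the same `[:ret_width]` slice. The new `texts` parameter deliberately mirrors the nesting of `audio` (a list of per-batch lists), so the same reorder and flatten operations apply to both; a toy illustration with stand-in values:

    # Toy illustration: texts mirrors audio's per-batch nesting, so the same
    # flatten (and recovery_order) operations apply to both.
    batch_audio = [["a0", "a1"], ["a2"]]       # stand-ins for tensors
    batch_texts = [["t0", "t1"], ["t2"]]
    flat_audio = sum(batch_audio, [])          # ['a0', 'a1', 'a2']
    flat_texts = sum(batch_texts, [])          # ['t0', 't1', 't2']
    assert len(flat_audio) == len(flat_texts)  # one subtitle entry per segment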
@@ -993,11 +1005,17 @@ class TTS:

         if split_bucket:
             audio = self.recovery_order(audio, batch_index_list)
+            texts = self.recovery_order(texts, batch_index_list)
         else:
             # audio = [item for batch in audio for item in batch]
             audio = sum(audio, [])
+            texts = sum(texts, [])

+        # 按顺序计算每段语音的起止时间,并与文字一一对应,用于生成字幕
+        from itertools import accumulate
+        stamps = [0.0] + [x/sr for x in accumulate([v.size for v in audio])]
+        srts = list(zip(stamps[:-1], stamps[1:], texts)) # time start, end, text

         audio = np.concatenate(audio, 0)
         audio = (audio * 32768).astype(np.int16)

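(The added comment translates to: compute each segment's start and end times in order and pair them with its text, for generating subtitles.) The timestamp math is a running sum: each boundary is the cumulative sample count so far divided by the sample rate, and consecutive boundaries zip into (start, end, text) triples. A self-contained rerun with invented numbers:

    # Worked example of the accumulate-based timestamps (sr and lengths invented).
    from itertools import accumulate
    sr = 32000
    seg_lengths = [16000, 48000, 32000]            # samples per segment
    stamps = [0.0] + [n / sr for n in accumulate(seg_lengths)]
    srts = list(zip(stamps[:-1], stamps[1:], ["seg one", "seg two", "seg three"]))
    # [(0.0, 0.5, 'seg one'), (0.5, 2.0, 'seg two'), (2.0, 3.0, 'seg three')]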
@@ -1007,7 +1025,7 @@ class TTS:
         # except Exception as e:
         #     print(f"Failed to change speed of audio: \n{e}")

-        return sr, audio
+        return sr, audio, srts

GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py

@@ -69,6 +69,7 @@ class TextPreprocessor:
                 "phones": phones,
                 "bert_features": bert_features,
                 "norm_text": norm_text,
+                "origin_text": text,
             }
             result.append(res)
         return result
api_v2.py (58 lines changed)
@@ -36,6 +36,7 @@ POST:
     "split_bucket: True, # bool. whether to split the batch into multiple buckets.
     "speed_factor":1.0, # float. control the speed of the synthesized audio.
     "streaming_mode": False, # bool. whether to return a streaming response.
+    "with_srt_format": "", # str. ""(no srt) or "raw" or "srt", "lrc", "vtt", ... formats (not implemented yet)
     "seed": -1, # int. random seed for reproducibility.
     "parallel_infer": True, # bool. whether to use parallel inference.
     "repetition_penalty": 1.35 # float. repetition penalty for T2S model.
@@ -98,7 +99,7 @@ RESP:
 import os
 import sys
 import traceback
-from typing import Generator
+from typing import Generator, List, Union

 now_dir = os.getcwd()
 sys.path.append(now_dir)
@@ -162,6 +163,7 @@ class TTS_Request(BaseModel):
     seed:int = -1
     media_type:str = "wav"
     streaming_mode:bool = False
+    with_srt_format:str = ""
     parallel_infer:bool = True
     repetition_penalty:float = 1.35

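With the new field on TTS_Request, a POST body only needs one extra key to opt in. A hypothetical request; host and port follow the api_v2.py defaults, and the file name is a placeholder:

    # Hypothetical POST body exercising the new field.
    import requests  # third-party; pip install requests

    payload = {
        "text": "Hello world.",
        "text_lang": "en",
        "ref_audio_path": "ref.wav",
        "prompt_lang": "en",
        "media_type": "wav",
        "with_srt_format": "raw",   # "" keeps the old audio-only behavior
    }
    resp = requests.post("http://127.0.0.1:9880/tts", json=payload)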
@@ -211,7 +213,38 @@ def pack_audio(io_buffer:BytesIO, data:np.ndarray, rate:int, media_type:str):
     io_buffer.seek(0)
     return io_buffer

+def pack_srt(srt:List, fmt:str):
+    if fmt == "raw":
+        return srt
+    # TODO: support formats like "srt", "lrc", "vtt", ...
+    return srt
+
+def load_base64_audio(audio):
+    import base64
+    if isinstance(audio, (bytes, bytearray)):
+        audio = bytes(audio)
+    elif hasattr(audio, 'read'): # file-like obj
+        audio = audio.read()
+    else: # path-like
+        audio = open(audio, 'rb').read()
+    return base64.b64encode(audio).decode('ascii')
+
+_base64_audio_cache = {}
+def save_base64_audio(b64str:str):
+    import filetype, base64, uuid
+    global _base64_audio_cache
+    if b64str in _base64_audio_cache:
+        return _base64_audio_cache[b64str]
+    savedir = 'TEMP/upload'
+    data = base64.b64decode(b64str)
+    ft = filetype.guess(data)
+    ext = f'.{ft.extension}' if ft else ''
+    os.makedirs(savedir, exist_ok=True)
+    saveto = f'{savedir}/{uuid.uuid1()}{ext}'
+    with open(saveto, 'wb') as outf:
+        outf.write(data)
+    _base64_audio_cache[b64str] = saveto
+    return saveto
+
 # from https://huggingface.co/spaces/coqui/voice-chat-with-mistral/blob/main/app.py
 def wave_header_chunk(frame_input=b"", channels=1, sample_width=2, sample_rate=32000):
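pack_srt is a pass-through for now: only "raw" (start, end, text) triples come back, and the listed formats remain TODO. As a hedged sketch of what an "srt" renderer for those triples could look like (my own helper, not part of this commit):

    # Hypothetical SRT renderer for the raw (start, end, text) triples; the
    # commit itself leaves non-"raw" formats unimplemented.
    def to_srt(triples):
        def ts(sec):
            ms = int(round(sec * 1000))
            h, ms = divmod(ms, 3600_000)
            m, ms = divmod(ms, 60_000)
            s, ms = divmod(ms, 1000)
            return f"{h:02}:{m:02}:{s:02},{ms:03}"
        blocks = [f"{i}\n{ts(a)} --> {ts(b)}\n{text}\n"
                  for i, (a, b, text) in enumerate(triples, 1)]
        return "\n".join(blocks)

    print(to_srt([(0.0, 1.25, "Hello"), (1.25, 2.5, "world")]))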
@@ -277,7 +310,7 @@ async def tts_handle(req:dict):
     {
         "text": "", # str.(required) text to be synthesized
         "text_lang: "", # str.(required) language of the text to be synthesized
-        "ref_audio_path": "", # str.(required) reference audio path
+        "ref_audio_path": "", # str.(required) reference audio path ; allow data of format base64:xxxxxx
         "aux_ref_audio_paths": [], # list.(optional) auxiliary reference audio paths for multi-speaker synthesis
         "prompt_text": "", # str.(optional) prompt text for the reference audio
         "prompt_lang": "", # str.(required) language of the prompt text for the reference audio
@@ -293,6 +326,7 @@ async def tts_handle(req:dict):
         "seed": -1, # int. random seed for reproducibility.
         "media_type": "wav", # str. media type of the output audio, support "wav", "raw", "ogg", "aac".
         "streaming_mode": False, # bool. whether to return a streaming response.
+        "with_srt_format": "", # str. ""(no srt) or "raw" or "srt", "lrc", "vtt", ... formats (not implemented yet)
         "parallel_infer": True, # bool.(optional) whether to use parallel inference.
         "repetition_penalty": 1.35 # float.(optional) repetition penalty for T2S model.
     }
@@ -303,6 +337,10 @@ async def tts_handle(req:dict):
     streaming_mode = req.get("streaming_mode", False)
     return_fragment = req.get("return_fragment", False)
     media_type = req.get("media_type", "wav")
+    with_srt_format = req.get("with_srt_format", "")
+    ref_audio_path = req.get("ref_audio_path", "")
+    if ref_audio_path.startswith("base64:"):
+        req['ref_audio_path'] = ref_audio_path = save_base64_audio(ref_audio_path[len("base64:"):])

     check_res = check_params(req)
     if check_res is not None:
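On the client side, the base64: prefix lets callers inline reference audio instead of pointing at a server-side path; save_base64_audio then decodes it into TEMP/upload once and caches the mapping, so repeated requests with the same payload reuse the file. A hypothetical sender:

    # Hypothetical client inlining reference audio via the base64: prefix.
    import base64

    with open("ref.wav", "rb") as f:                  # placeholder file name
        b64 = base64.b64encode(f.read()).decode("ascii")
    payload = {
        "text": "Hi there.",
        "text_lang": "en",
        "prompt_lang": "en",
        "ref_audio_path": "base64:" + b64,            # decoded and cached server-side
    }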
@@ -310,7 +348,10 @@ async def tts_handle(req:dict):

     if streaming_mode or return_fragment:
         req["return_fragment"] = True
+
+    if streaming_mode: with_srt_format = "" # streaming not support srt
+    req["return_with_srt"] = "orig" if with_srt_format else ""

     try:
         tts_generator=tts_pipeline.run(req)

@@ -324,6 +365,16 @@ async def tts_handle(req:dict):
             # _media_type = f"audio/{media_type}" if not (streaming_mode and media_type in ["wav", "raw"]) else f"audio/x-{media_type}"
             return StreamingResponse(streaming_generator(tts_generator, media_type, ), media_type=f"audio/{media_type}")

+        elif with_srt_format:
+            output = []
+            for sr, audio_data, srt_data in tts_generator:
+                audio_data = pack_audio(BytesIO(), audio_data, sr, media_type).getvalue()
+                output.append({
+                    "audio": load_base64_audio(audio_data), "media_type": f"audio/{media_type}",
+                    "srt": pack_srt(srt_data, with_srt_format), "srt_fmt": with_srt_format,
+                })
+            return { "message":"succeed", "output":output } # Jsonresponse(status_code=200, content=...)
+
         else:
             sr, audio_data = next(tts_generator)
             audio_data = pack_audio(BytesIO(), audio_data, sr, media_type).getvalue()
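A client consuming this branch receives JSON rather than raw media: each output entry pairs base64-encoded audio with its subtitle triples. A hypothetical decoder, continuing from the earlier requests.post sketch (`resp`):

    # Hypothetical decoder for the JSON branch above.
    import base64

    data = resp.json()
    for item in data["output"]:
        audio_bytes = base64.b64decode(item["audio"])     # packed media bytes
        with open("out.wav", "wb") as f:                  # media_type "wav" assumed
            f.write(audio_bytes)
        for start, end, text in item["srt"]:              # "raw" triples
            print(f"{start:.2f} --> {end:.2f}: {text}")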
@@ -364,6 +415,7 @@ async def tts_get_endpoint(
                         seed:int = -1,
                         media_type:str = "wav",
                         streaming_mode:bool = False,
+                        with_srt_format:str = "",
                         parallel_infer:bool = True,
                         repetition_penalty:float = 1.35
                         ):
requirements.txt

@@ -34,3 +34,4 @@ opencc; sys_platform != 'linux'
 opencc==1.1.1; sys_platform == 'linux'
 python_mecab_ko; sys_platform != 'win32'
 fastapi<0.112.2
+filetype