stream_v2pro: 新增 --lang 参数提示参考文字的语言类型

This commit is contained in:
csh 2025-08-18 16:26:29 +08:00
parent 60f07ea36e
commit d313fbc740

View File

@ -444,6 +444,7 @@ def export_prov2(
output_path,
device="cpu",
is_half=True,
lang="auto",
):
if export_torch_script.sv_cn_model == None:
init_sv_cn(device,is_half)
@ -454,7 +455,7 @@ def export_prov2(
print(f"device: {device}")
ref_seq_id, ref_bert_T, ref_norm_text = get_phones_and_bert(
ref_text, "all_zh", "v2"
ref_text, lang, "v2"
)
ref_seq = torch.LongTensor([ref_seq_id]).to(device)
ref_bert = ref_bert_T.T
@ -503,9 +504,9 @@ def export_prov2(
stream_t2s = torch.jit.script(stream_t2s)
ref_audio_sr = resamplex(ref_audio, 16000, 32000)
ref_audio_sr = ref_audio_sr.to(device)
if is_half:
ref_audio_sr = ref_audio_sr.half()
ref_audio_sr = ref_audio_sr.to(device)
top_k = 15
@ -588,6 +589,7 @@ if __name__ == "__main__":
parser.add_argument("--device", help="Device to use", default="cuda" if torch.cuda.is_available() else "cpu")
parser.add_argument("--version", help="version of the model", default="v2Pro")
parser.add_argument("--no-half", action="store_true", help = "Do not use half precision for model weights")
parser.add_argument("--lang", default="auto", help="Language for text processing (default: auto)")
args = parser.parse_args()
@ -605,4 +607,5 @@ if __name__ == "__main__":
output_path=args.output_path,
device=args.device,
is_half=is_half,
lang=args.lang,
)