make sure ort providers available

This commit is contained in:
KamioRinn 2025-06-27 02:32:19 +08:00
parent ed89a02337
commit 2b85240c9c

View File

@ -93,13 +93,13 @@ class G2PWOnnxConverter:
sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
sess_options.execution_mode = onnxruntime.ExecutionMode.ORT_SEQUENTIAL sess_options.execution_mode = onnxruntime.ExecutionMode.ORT_SEQUENTIAL
sess_options.intra_op_num_threads = 2 if torch.cuda.is_available() else 0 sess_options.intra_op_num_threads = 2 if torch.cuda.is_available() else 0
try: if "CUDAExecutionProvider" in onnxruntime.get_available_providers():
self.session_g2pW = onnxruntime.InferenceSession( self.session_g2pW = onnxruntime.InferenceSession(
os.path.join(uncompress_path, "g2pW.onnx"), os.path.join(uncompress_path, "g2pW.onnx"),
sess_options=sess_options, sess_options=sess_options,
providers=["CUDAExecutionProvider", "CPUExecutionProvider"], providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
) )
except: else:
self.session_g2pW = onnxruntime.InferenceSession( self.session_g2pW = onnxruntime.InferenceSession(
os.path.join(uncompress_path, "g2pW.onnx"), os.path.join(uncompress_path, "g2pW.onnx"),
sess_options=sess_options, sess_options=sess_options,