From 684e1cfd2f0dd6b8e8e389b36ae2ebd6d13a5a97 Mon Sep 17 00:00:00 2001 From: Downupanddownup Date: Fri, 26 Apr 2024 14:31:54 +0800 Subject: [PATCH] =?UTF-8?q?=E6=96=87=E6=9C=AC=E7=9B=B8=E4=BC=BC=E5=BA=A6?= =?UTF-8?q?=EF=BC=8C=E6=B7=BB=E5=8A=A0GPU=E5=8A=A0=E9=80=9F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../tool/text_comparison/text_comparison.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/Ref_Audio_Selector/tool/text_comparison/text_comparison.py b/Ref_Audio_Selector/tool/text_comparison/text_comparison.py index edb0c7d..156fa53 100644 --- a/Ref_Audio_Selector/tool/text_comparison/text_comparison.py +++ b/Ref_Audio_Selector/tool/text_comparison/text_comparison.py @@ -8,14 +8,19 @@ bert_path = os.environ.get( "bert_path", "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large" ) +# Set device to GPU if available, else CPU +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +print(f'使用计算设备: {device}') + tokenizer = AutoTokenizer.from_pretrained(bert_path) -model = AutoModel.from_pretrained(bert_path) +model = AutoModel.from_pretrained(bert_path).to(device) def calculate_similarity(text1, text2, max_length=512): # 预处理文本,设置最大长度 - inputs1 = tokenizer(text1, padding=True, truncation=True, max_length=max_length, return_tensors='pt') - inputs2 = tokenizer(text2, padding=True, truncation=True, max_length=max_length, return_tensors='pt') + inputs1 = tokenizer(text1, padding=True, truncation=True, max_length=max_length, return_tensors='pt').to(device) + inputs2 = tokenizer(text2, padding=True, truncation=True, max_length=max_length, return_tensors='pt').to(device) # 获取句子向量(这里是取CLS token的向量并展平为一维) with torch.no_grad():