diff --git a/Ref_Audio_Selector/tool/text_comparison/text_comparison.py b/Ref_Audio_Selector/tool/text_comparison/text_comparison.py index edb0c7d..156fa53 100644 --- a/Ref_Audio_Selector/tool/text_comparison/text_comparison.py +++ b/Ref_Audio_Selector/tool/text_comparison/text_comparison.py @@ -8,14 +8,19 @@ bert_path = os.environ.get( "bert_path", "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large" ) +# Set device to GPU if available, else CPU +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +print(f'使用计算设备: {device}') + tokenizer = AutoTokenizer.from_pretrained(bert_path) -model = AutoModel.from_pretrained(bert_path) +model = AutoModel.from_pretrained(bert_path).to(device) def calculate_similarity(text1, text2, max_length=512): # 预处理文本,设置最大长度 - inputs1 = tokenizer(text1, padding=True, truncation=True, max_length=max_length, return_tensors='pt') - inputs2 = tokenizer(text2, padding=True, truncation=True, max_length=max_length, return_tensors='pt') + inputs1 = tokenizer(text1, padding=True, truncation=True, max_length=max_length, return_tensors='pt').to(device) + inputs2 = tokenizer(text2, padding=True, truncation=True, max_length=max_length, return_tensors='pt').to(device) # 获取句子向量(这里是取CLS token的向量并展平为一维) with torch.no_grad():