mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-04-05 19:41:56 +08:00
文本相似度,添加GPU加速
This commit is contained in:
parent
878fef248a
commit
684e1cfd2f
@ -8,14 +8,19 @@ bert_path = os.environ.get(
|
||||
"bert_path", "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"
|
||||
)
|
||||
|
||||
# Set device to GPU if available, else CPU
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
print(f'使用计算设备: {device}')
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(bert_path)
|
||||
model = AutoModel.from_pretrained(bert_path)
|
||||
model = AutoModel.from_pretrained(bert_path).to(device)
|
||||
|
||||
|
||||
def calculate_similarity(text1, text2, max_length=512):
|
||||
# 预处理文本,设置最大长度
|
||||
inputs1 = tokenizer(text1, padding=True, truncation=True, max_length=max_length, return_tensors='pt')
|
||||
inputs2 = tokenizer(text2, padding=True, truncation=True, max_length=max_length, return_tensors='pt')
|
||||
inputs1 = tokenizer(text1, padding=True, truncation=True, max_length=max_length, return_tensors='pt').to(device)
|
||||
inputs2 = tokenizer(text2, padding=True, truncation=True, max_length=max_length, return_tensors='pt').to(device)
|
||||
|
||||
# 获取句子向量(这里是取CLS token的向量并展平为一维)
|
||||
with torch.no_grad():
|
||||
|
Loading…
x
Reference in New Issue
Block a user