Integrating Elasticsearch

Dependencies:
pip install elasticsearch -i https://pypi.tuna.tsinghua.edu.cn/simple
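The script below also relies on transformers, torch, and pandas; if they are not already installed, something along these lines should cover them (package versions left unpinned, adjust as needed):

pip install transformers torch pandas -i https://pypi.tuna.tsinghua.edu.cn/simple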
from elasticsearch import Elasticsearch
from transformers import BertTokenizer, BertModel
import torch
import pandas as pd


def embeddings_doc(doc, tokenizer, model, max_length=300):
    # Tokenize and pad/truncate the text to a fixed length
    encoded_dict = tokenizer.encode_plus(
        doc,
        add_special_tokens=True,
        max_length=max_length,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt'
    )

    input_id = encoded_dict['input_ids']
    attention_mask = encoded_dict['attention_mask']

    # Forward pass without gradients; use the [CLS] token vector as the sentence embedding
    with torch.no_grad():
        outputs = model(input_id, attention_mask=attention_mask)
        last_hidden_state = outputs.last_hidden_state
        cls_embeddings = last_hidden_state[:, 0, :]

    return cls_embeddings[0]


def add_doc(index_name, id, embedding_ask, ask, answer, es):
    # Store the question vector together with the original question and answer
    body = {
        "ask_vector": embedding_ask.tolist(),
        "ask": ask,
        "answer": answer
    }
    # create() fails if a document with the same id already exists;
    # doc_type is omitted because _doc is the default in recent Elasticsearch versions
    result = es.create(index=index_name, id=id, body=body)
    return result


def main():
    # Local path to the chinese-roberta-wwm-ext-large model
    model_name = r'D:\AIGC\model\chinese-roberta-wwm-ext-large'
    # Elasticsearch connection settings
    es_host = "http://127.0.0.1"
    es_port = 9200
    es_user = "elastic"
    es_password = "elastic"
    index_name = "medical_index"
    # Chinese medical dialogue dataset (internal medicine subset)
    path = r"D:\AIGC\dataset\Chinese-medical-dialogue-data\Chinese-medical-dialogue-data\Data_数据\IM_内科\内科5000-33000.csv"

    # Load the tokenizer and model
    tokenizer = BertTokenizer.from_pretrained(model_name)
    model = BertModel.from_pretrained(model_name)

    # Connect to Elasticsearch (7.x-style client arguments)
    es = Elasticsearch(
        [es_host],
        port=es_port,
        http_auth=(es_user, es_password)
    )

    # Read the CSV; 'ANSI' here is the local Windows code page (this dataset is typically GBK)
    data = pd.read_csv(path, encoding='ANSI')

    # Index the first 500 question-answer pairs
    for index, row in data.iterrows():
        if index >= 500:
            break
        ask = row["ask"]
        answer = row["answer"]
        embedding_ask = embeddings_doc(ask, tokenizer, model)
        result = add_doc(index_name, index, embedding_ask, ask, answer, es)
        print(result)


if __name__ == '__main__':
    main()
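The add_doc call assumes that the medical_index index already exists. For the stored ask_vector field to be usable for vector similarity search later on, the index would normally be created with a dense_vector mapping beforehand. A minimal sketch, assuming an elasticsearch-py 7.x client and a 1024-dimensional embedding (the hidden size of chinese-roberta-wwm-ext-large); this mapping step is not part of the original script:

from elasticsearch import Elasticsearch

es = Elasticsearch(["http://127.0.0.1"], port=9200, http_auth=("elastic", "elastic"))

# Assumed mapping: "ask_vector" as a dense_vector so the embeddings can later be
# used in similarity queries; "ask" and "answer" as plain text fields.
mapping = {
    "mappings": {
        "properties": {
            "ask_vector": {"type": "dense_vector", "dims": 1024},
            "ask": {"type": "text"},
            "answer": {"type": "text"}
        }
    }
}

# Create the index only if it does not exist yet
if not es.indices.exists(index="medical_index"):
    es.indices.create(index="medical_index", body=mapping)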