赞
踩
前面我们已完成在Qdrant创建了startups集合,导入了startups_demo.json数据,让我们开始构建神经搜索类。
为了处理传入请求,神经搜索需要两件事:1)将查询转换为向量的模型,2)Qdrant 客户端来执行搜索查询。
- from qdrant_client import QdrantClient
- from sentence_transformers import SentenceTransformer
-
-
- class NeuralSearcher:
- def __init__(self, collection_name):
- self.collection_name = collection_name
- # Initialize encoder model
- self.model = SentenceTransformer("all-MiniLM-L6-v2", device="cpu")
- # initialize Qdrant client
- self.qdrant_client = QdrantClient("http://localhost:6333")
- def search(self, text: str):
- vector = self.model.encode(text).tolist()
-
- search_result = self.qdrant_client.search(
- collection_name=self.collection_name,
- query_vector=vector,
- query_filter=None,
- limit=5
- )
-
- payloads = [hit.payload for hit in search_result]
- return payloads
现在已经创建了一个用于神经搜索查询的类。现在将其包装到服务中
要构建该服务,您将使用 FastAPI 框架。
要安装它,请使用命令
pip install fastapi uvicorn
创建一个名为的文件service.py并指定以下内容。
该服务只有一个 API 端点,如下所示:
- from fastapi import FastAPI
-
- # The file where NeuralSearcher is stored
- from neural_searcher import NeuralSearcher
-
- app = FastAPI()
-
- # Create a neural searcher instance
- neural_searcher = NeuralSearcher(collection_name='startups')
-
- @app.get("/api/search")
- def search_startup(q: str):
- return {
- "result": neural_searcher.search(text=q)
- # "result": neural_searcher.async_search(text=q) # 异步非阻塞
- }
-
-
- if __name__ == "__main__":
- import uvicorn
- uvicorn.run(app, host="0.0.0.0", port=8000)
python service.py
打开浏览器http://localhost:8000/docs
就可以看到服务的调试界面
请随意使用它,查询语料库中的公司,并查看结果
http://127.0.0.1:8000/api/search?q=Artificial%20intelligence%20machine%20learning
调用QdrantClient替换为AsyncQdrantClient,AsyncQdrantClient提供与同步对应项相同的方法QdrantClient
注:AsyncQdrantClient提供与同步对应项相同的方法QdrantClient,异步是在qdrant-client1.6.1版本中引入
- from qdrant_client import QdrantClient, AsyncQdrantClient
-
- class NeuralSearcher:
- # .../ init()
- # 异步查询
- async def async_search(self, text: str):
- # AsyncQdrantClient提供与同步对应项相同的方法QdrantClient,异步客户端是在qdrant-client1.6.1版本中引入
- client = AsyncQdrantClient("http://localhost:6333")
-
- vector = self.model.encode(text).tolist()
-
- search_result = await client.search(
- collection_name=self.collection_name,
- query_vector=vector,
- query_filter=None,
- limit=5
- )
-
- payloads = [hit.payload for hit in search_result]
- return payloads
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。