156 lines
6 KiB
Python
156 lines
6 KiB
Python
import json
|
|
import pathlib
|
|
|
|
import torch
|
|
|
|
from haystack import Document
|
|
from haystack.utils import ComponentDevice
|
|
from haystack import Pipeline
|
|
|
|
from haystack.components.embedders import SentenceTransformersDocumentEmbedder
|
|
from haystack.components.writers import DocumentWriter
|
|
|
|
from haystack_integrations.document_stores.pgvector import PgvectorDocumentStore
|
|
from haystack_integrations.components.retrievers.pgvector import PgvectorEmbeddingRetriever, PgvectorKeywordRetriever
|
|
|
|
from haystack.components.embedders import SentenceTransformersTextEmbedder
|
|
from haystack.components.joiners import DocumentJoiner
|
|
from haystack.components.rankers import TransformersSimilarityRanker
|
|
|
|
class AIBackend:
    """Hybrid (embedding + keyword) document search on top of a pgvector store.

    Constructing an instance builds the document store, an indexing pipeline
    (document embedder -> document writer) and a query pipeline (text embedder
    + embedding retriever and keyword retriever -> joiner -> ranker).
    ``warmup()`` has to be called once before ``query()`` may be used.
    """

    model_embeddings = "BAAI/bge-base-en-v1.5"
    model_ranker = "BAAI/bge-reranker-base"
    _ready = False

    query_pipeline: Pipeline
    index_pipeline: Pipeline
    document_store: PgvectorDocumentStore
    # Assigned per instance in __init__; a class-level ``[]`` default would be
    # one list object shared by every instance.
    documents: list[Document]

    def __init__(self, load_dataset=False):
        """Set up the device, optional dataset, document store and pipelines.

        Args:
            load_dataset: When true, read ``data/dataset.jsonl`` (resolved
                relative to the parent directory of this file's directory)
                and turn every record into a ``Document``.
        """
        get_torch_info()

        # ComponentDevice.from_str() only parses the device string; it does not
        # check that the device exists. Ask torch explicitly so that hosts
        # without a GPU really fall back to the components' default device.
        if torch.cuda.is_available():
            self.gpu = ComponentDevice.from_str("cuda:0")
        else:
            self.gpu = None
            print("No CUDA gpu device found")

        self.documents = []
        if load_dataset:
            dataset = pathlib.Path(__file__).parents[1] / "data" / "dataset.jsonl"
            self.documents = [
                Document(content=d["text"], meta=d["meta"])
                for d in load_data(dataset)
            ]

        self.document_store = PgvectorDocumentStore(
            # NOTE(review): presumably has to match the output size of
            # ``model_embeddings`` - confirm when changing the model.
            embedding_dimension=768,
            vector_function="cosine_similarity",
            recreate_table=False,
            search_strategy="hnsw",
        )

        self.index_pipeline = self._create_indexing_pipeline()
        self.query_pipeline = self._create_query_pipeline()

    def warmup(self):
        """Run the indexing pipeline over ``self.documents`` and enable queries.

        Prints progress and the number of documents written, then marks the
        backend as ready so that ``query()`` stops raising.
        """
        # NOTE(review): the store is created with recreate_table=False, so a
        # second run may meet documents that already exist - confirm how the
        # writer's duplicate policy handles that.
        print("Running warmup routine ...")
        print("Launching indexing pipeline to generate document embeddings")
        res = self.index_pipeline.run({"document_embedder": {"documents": self.documents}})
        print(f"Finished running indexing pipeline\nDocument Store: Wrote {res['document_writer']['documents_written']} documents")
        self._ready = True
        print("'.query(\"text\")' is now ready to be used")

    def _create_indexing_pipeline(self):
        """Build the pipeline that embeds documents and writes them to the store."""
        print("Creating indexing pipeline ...")
        document_embedder = SentenceTransformersDocumentEmbedder(model=self.model_embeddings, device=self.gpu)
        document_writer = DocumentWriter(document_store=self.document_store)

        indexing_pipeline = Pipeline()
        indexing_pipeline.add_component("document_embedder", document_embedder)
        indexing_pipeline.add_component("document_writer", document_writer)

        indexing_pipeline.connect("document_embedder", "document_writer")

        return indexing_pipeline

    def _create_query_pipeline(self):
        """Build the hybrid retrieval pipeline.

        The query text is embedded for the embedding retriever and passed as-is
        to the keyword retriever; both result sets are joined and then ranked.
        """
        print("Creating hybrid retrival pipeline ...")
        text_embedder = SentenceTransformersTextEmbedder(model=self.model_embeddings, device=self.gpu)
        ranker = TransformersSimilarityRanker(model=self.model_ranker, device=self.gpu)

        embedding_retriever = PgvectorEmbeddingRetriever(document_store=self.document_store)
        keyword_retriever = PgvectorKeywordRetriever(document_store=self.document_store)

        document_joiner = DocumentJoiner()

        hybrid_retrieval = Pipeline()
        hybrid_retrieval.add_component("text_embedder", text_embedder)
        hybrid_retrieval.add_component("embedding_retriever", embedding_retriever)
        hybrid_retrieval.add_component("keyword_retriever", keyword_retriever)
        hybrid_retrieval.add_component("document_joiner", document_joiner)
        hybrid_retrieval.add_component("ranker", ranker)

        hybrid_retrieval.connect("text_embedder", "embedding_retriever")
        hybrid_retrieval.connect("keyword_retriever", "document_joiner")
        hybrid_retrieval.connect("embedding_retriever", "document_joiner")
        hybrid_retrieval.connect("document_joiner", "ranker")

        return hybrid_retrieval

    def query(self, query: str):
        """Run the hybrid retrieval pipeline for ``query``.

        Args:
            query: Search text; fed to the text embedder, the keyword
                retriever and the ranker (which is asked for the top 5).

        Returns:
            The raw result of the query pipeline run; the formatters below
            read the ranked documents from ``result["ranker"]["documents"]``.

        Raises:
            SystemError: If ``warmup()`` has not been run yet.
        """
        if not self._ready:
            raise SystemError("Cannot query when warmup hasn't been run yet")

        return self.query_pipeline.run(
            data={
                "text_embedder": {"text": query},
                "keyword_retriever": {"query": query},
                "ranker": {"query": query, "top_k": 5},
            }
        )

    @staticmethod
    def format_result(result):
        """Turn a query result into table rows ``[id label, title, url]``.

        ``None`` entries in the ranked document list are skipped.
        """
        return [
            [f"[ {doc.meta['id']:4} ]", doc.meta["title"], doc.meta["url"]]
            for doc in result["ranker"]["documents"]
            if doc is not None
        ]

    @staticmethod
    def format_for_api(result):
        """Turn a query result into a list of dicts for an API response.

        Each dict carries ``id``, ``title``, ``url`` and ``image_url`` taken
        from the document metadata; ``None`` entries are skipped.
        """
        return [
            {
                "id": doc.meta["id"],
                "title": doc.meta["title"],
                "url": doc.meta["url"],
                "image_url": doc.meta["image_url"],
            }
            for doc in result["ranker"]["documents"]
            if doc is not None
        ]
|
|
|
|
|
|
def load_data(dataset_path):
    """Load a JSON Lines dataset and derive ``text`` and ``meta`` per record.

    Every non-blank line is parsed as one JSON object. Two keys are added to
    each object: ``text``, the title, transcript and explanation joined with
    ``" | "``, and ``meta``, a dict holding ``title``, ``url``, ``image_url``
    and ``id``.

    Args:
        dataset_path: Path of the ``.jsonl`` file to read.

    Returns:
        A list with one dict per record, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record lacks one of the fields used above.
    """
    data = []

    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's default encoding.
    with open(dataset_path, "r", encoding="utf-8") as f:
        # Iterate the file lazily instead of materialising every line first.
        for line in f:
            if not line.strip():
                # Blank lines (e.g. a trailing newline at the end of the
                # file) are not records and would make json.loads fail.
                continue
            record: dict = json.loads(line)
            record.update(
                {
                    "text": f"{record['title']} | {record['transcript']} | {record['explanation']}",
                    "meta": {
                        "title": record["title"],
                        "url": record["url"],
                        "image_url": record["image_url"],
                        "id": record["id"],
                    },
                }
            )
            data.append(record)
    return data
|
|
|
|
def get_torch_info():
    """Print a short report of the GPU setup as seen by PyTorch.

    The report states whether ``torch.cuda`` is available and lists, for every
    visible device, its index, name, multiprocessor count and total memory in
    GB (10^9 bytes).
    """
    print("---------- Getting information about pytorch setup ----------")
    availability = "Yes" if torch.cuda.is_available() else "No"
    print(f"Is CUDA or ROCm available? {availability}")
    print("Available devices:")
    for index in range(torch.cuda.device_count()):
        props = torch.cuda.get_device_properties(index)
        memory_gb = props.total_memory / 1_000_000_000
        print(f"- [{index}] {props.name} [ {props.multi_processor_count} processors, {memory_gb:.2f} GB ]")
    print("-------------------------------------------------------------")
|