diff --git a/.gitignore b/.gitignore index 43a9d49..1ca6115 100644 --- a/.gitignore +++ b/.gitignore @@ -4,4 +4,8 @@ # Binaries for installation *.deb -*.whl \ No newline at end of file +*.whl + +# Test data +data/pgvector.sql +backend.Dockerfile diff --git a/data/dataset.jsonl b/data/dataset.jsonl index de990bb..518dc58 100644 --- a/data/dataset.jsonl +++ b/data/dataset.jsonl @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:f43890db5669aa2af6193ef7c9a32e4126b8c9271f8055b307dd18ac1a54a6b9 -size 12942390 +oid sha256:74ecc3c0ee4a211bf498ef676b034b9fb85f8cf2e67eeca1b698d319e5e2a2a3 +size 15167159 diff --git a/data/init.sql b/data/init.sql new file mode 100644 index 0000000..76571af --- /dev/null +++ b/data/init.sql @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:458fca4a56476e81656f8e486ee37dee786dc3e6269f11819f38d44ea1ca687e +size 224905425 diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..b332faa --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,12 @@ +name: xkcd-finder +services: + pgvector: + ports: + - 5432:5432 + environment: + - POSTGRES_USER=postgres + - POSTGRES_PASSWORD=postgres + - POSTGRES_DB=postgres + image: ankane/pgvector + volumes: + - ./data/init.sql:/docker-entrypoint-initdb.d/init.sql diff --git a/hybrid-retrieval.png b/hybrid-retrieval.png index 8d446ad..9332102 100644 Binary files a/hybrid-retrieval.png and b/hybrid-retrieval.png differ diff --git a/nlp.ipynb b/nlp.ipynb index cabe4a7..6f41d96 100644 --- a/nlp.ipynb +++ b/nlp.ipynb @@ -258,7 +258,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "705b147162734470908bbf2ab6d45db3", + "model_id": "10c0a04365264393851ab80b8bd3d3ed", "version_major": 2, "version_minor": 0 }, @@ -316,7 +316,7 @@ { "data": { "text/plain": [ - "<haystack.core.pipeline.pipeline.Pipeline object at 0x7ef6cfa25900>\n", + "<haystack.core.pipeline.pipeline.Pipeline object at 0x7f5dd4730e80>\n", "🚅 Components\n", " - 
text_embedder: SentenceTransformersTextEmbedder\n", " - embedding_retriever: InMemoryEmbeddingRetriever\n", @@ -371,29 +371,29 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Results for 'Give an example for dependency'\n", + "Results for 'dependency'\n", "\n", - "| ID | Title | Link |\n", - "|----------|------------------|---------------------------|\n", - "| [ 1579 ] | Tech Loops | https://www.xkcd.com/1579 |\n", - "| [ 1906 ] | Making Progress | https://www.xkcd.com/1906 |\n", - "| [ 2102 ] | Internet Archive | https://www.xkcd.com/2102 |\n", - "| [ 2347 ] | Dependency | https://www.xkcd.com/2347 |\n", - "| [ 2166 ] | Stack | https://www.xkcd.com/2166 |\n" + "| ID | Title | Link |\n", + "|----------|--------------------------|---------------------------|\n", + "| [ 1579 ] | Tech Loops | https://www.xkcd.com/1579 |\n", + "| [ 2102 ] | Internet Archive | https://www.xkcd.com/2102 |\n", + "| [ 1906 ] | Making Progress | https://www.xkcd.com/1906 |\n", + "| [ 2347 ] | Dependency | https://www.xkcd.com/2347 |\n", + "| [ 1654 ] | Universal Install Script | https://www.xkcd.com/1654 |\n" ] } ], "source": [ "from tabulate import tabulate\n", "\n", - "query = \"Give an example for dependency\"\n", + "query = \"dependency\"\n", "\n", "result = hybrid_retrieval.run(\n", " data={\n", diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..31e8f33 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,10 @@ +haystack-ai~=2.7.0 +torch~=2.0.0 +gevent~=24.11.1 +Flask~=3.1.0 +pgvector-haystack~=1.2.0 +tabulate~=0.9.0 +transformers~=4.46.3 +accelerate~=1.1.1 +numpy<2 +sentence-transformers>=3.0.0 \ No newline at end of file diff --git a/server-implementation/__init__.py b/server-implementation/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/server-implementation/__init__.py @@ -0,0 +1 @@ + diff --git 
# Flask application for the xkcd search UI; templates are resolved relative to
# this file so the server can be started from any working directory.
app = Flask("xkcd_retriever", template_folder=pathlib.Path(__file__).parent / "templates")

print("Initializing backend")
backend = AIBackend()

# backend.warmup() # Only needed when there is no data in the pgvector database
# Indexing is skipped here, so the readiness flag that warmup() would normally
# set is flipped by hand; otherwise backend.query() refuses to run.
backend._ready = True
print("AI backend initialized ...")


class BasicForm(Form):
    """Search form: one required text field (``ids``) and a submit button."""

    ids = StringField("ID", validators=[DataRequired()])
    submit = SubmitField("Submit")


@app.route("/", methods=["POST", "GET"])
def main():
    """Render the landing page with an empty search form."""
    return render_template("index.html", form=BasicForm())


@app.route("/search", methods=["POST", "GET"])
def results():
    """Search for the submitted topic and render the result page.

    The topic is taken from the ``ids`` form field. When it is missing or
    blank (e.g. the route is opened with a plain GET), the search form is
    rendered again with HTTP 400 instead of sending the literal text
    "Give an example for None" to the retrieval backend.
    """
    topic = (request.form.get("ids") or "").strip()
    if not topic:
        return render_template("index.html", form=BasicForm()), 400

    result = backend.query(f"Give an example for {topic}")

    # Log the ranked hits to stdout as a GitHub-style table.
    headers = ["ID", "Title", "Link"]
    print(tabulate(backend.format_result(result), headers=headers, tablefmt="github"))

    res = backend.format_for_api(result)
    return render_template("results.html", results=res, topic=topic)
class AIBackend:
    """Hybrid xkcd retrieval on top of a pgvector document store.

    Two haystack pipelines are built at construction time:

    * an indexing pipeline: split documents -> embed -> write to the store
    * a query pipeline: embedding retrieval + keyword retrieval -> join -> rerank

    ``query()`` only works once ``_ready`` is true, which ``warmup()`` sets
    after indexing (callers with a pre-populated store may set it themselves).
    """

    model_embeddings = "BAAI/bge-base-en-v1.5"
    model_ranker = "BAAI/bge-reranker-base"
    _ready = False

    # Declarative only; the values are assigned per instance in __init__.
    query_pipeline: "Pipeline"
    index_pipeline: "Pipeline"
    document_store: "PgvectorDocumentStore"
    documents: "list[Document]"

    def __init__(self):
        """Load the dataset (if configured), connect the store, build pipelines."""
        # Device handed to every model component. None means "let haystack
        # pick its default device", which is what omitting the argument does.
        self.gpu = None
        if config_backend.needs_torch:
            get_torch_info()
            if torch.cuda.is_available():
                self.gpu = ComponentDevice.from_str("cuda:0")
            else:
                # Requesting cuda:0 without a usable GPU fails at model load.
                print("No CUDA/ROCm device available, falling back to CPU")
                self.gpu = ComponentDevice.from_str("cpu")

        dataset = pathlib.Path(__file__).parents[1] / "data" / "dataset.jsonl"
        # Always a fresh per-instance list (never a shared class-level one).
        self.documents = []
        if config_backend.load_dataset:
            self.documents = [
                Document(content=d["text"], meta=d["meta"]) for d in load_data(dataset)
            ]

        self.document_store = PgvectorDocumentStore(
            embedding_dimension=768,
            vector_function="cosine_similarity",
            recreate_table=False,
            search_strategy="hnsw",
        )

        self.index_pipeline = self._create_indexing_pipeline()
        self.query_pipeline = self._create_query_pipeline()

    def warmup(self):
        """Embed ``self.documents`` into the store and mark the backend ready."""
        print("Running warmup routine ...")
        print("Launching indexing pipeline to generate document embeddings")
        res = self.index_pipeline.run({"document_splitter": {"documents": self.documents}})
        print(f"Finished running indexing pipeline\nDocument Store: Wrote {res['document_writer']['documents_written']} documents")
        self._ready = True
        print("'.query(\"text\")' is now ready to be used")

    def _create_indexing_pipeline(self) -> "Pipeline":
        """Build the split -> embed -> write pipeline used by ``warmup()``."""
        print("Creating indexing pipeline ...")
        document_splitter = DocumentSplitter(split_by="word", split_length=128, split_overlap=4)
        document_embedder = SentenceTransformersDocumentEmbedder(
            model=self.model_embeddings, device=self.gpu
        )
        document_writer = DocumentWriter(document_store=self.document_store)

        indexing_pipeline = Pipeline()
        indexing_pipeline.add_component("document_splitter", document_splitter)
        indexing_pipeline.add_component("document_embedder", document_embedder)
        indexing_pipeline.add_component("document_writer", document_writer)

        indexing_pipeline.connect("document_splitter", "document_embedder")
        indexing_pipeline.connect("document_embedder", "document_writer")

        return indexing_pipeline

    def _create_query_pipeline(self) -> "Pipeline":
        """Build the hybrid retrieval pipeline used by ``query()``.

        The query text is embedded for vector retrieval and also passed to the
        keyword retriever; both result lists are joined and then reranked.
        """
        print("Creating hybrid retrival pipeline ...")
        text_embedder = SentenceTransformersTextEmbedder(
            model=self.model_embeddings, device=self.gpu
        )
        ranker = TransformersSimilarityRanker(model=self.model_ranker, device=self.gpu)
        embedding_retriever = PgvectorEmbeddingRetriever(document_store=self.document_store)
        keyword_retriever = PgvectorKeywordRetriever(document_store=self.document_store)
        document_joiner = DocumentJoiner()

        hybrid_retrieval = Pipeline()
        hybrid_retrieval.add_component("text_embedder", text_embedder)
        hybrid_retrieval.add_component("embedding_retriever", embedding_retriever)
        hybrid_retrieval.add_component("keyword_retriever", keyword_retriever)
        hybrid_retrieval.add_component("document_joiner", document_joiner)
        hybrid_retrieval.add_component("ranker", ranker)

        hybrid_retrieval.connect("text_embedder", "embedding_retriever")
        hybrid_retrieval.connect("keyword_retriever", "document_joiner")
        hybrid_retrieval.connect("embedding_retriever", "document_joiner")
        hybrid_retrieval.connect("document_joiner", "ranker")

        return hybrid_retrieval

    def query(self, query: str, top_k: int = 5):
        """Run the hybrid retrieval pipeline for ``query``.

        Args:
            query: Free-text search string; fed to the text embedder, the
                keyword retriever and the ranker.
            top_k: Number of documents the ranker is asked to return
                (defaults to 5, the previously hard-coded value).

        Returns:
            The raw pipeline result; ranked documents are found under
            ``result["ranker"]["documents"]``.

        Raises:
            SystemError: If the backend has not been marked ready yet.
        """
        if not self._ready:
            raise SystemError("Cannot query when warmup hasn't been run yet")

        return self.query_pipeline.run(
            data={
                "text_embedder": {"text": query},
                "keyword_retriever": {"query": query},
                "ranker": {"query": query, "top_k": top_k},
            }
        )

    @staticmethod
    def _ranked_documents(result):
        """Return the ranker's documents from ``result``, skipping ``None`` entries."""
        return [doc for doc in result["ranker"]["documents"] if doc is not None]

    @staticmethod
    def format_result(result):
        """Return ``[id_cell, title, url]`` rows suitable for ``tabulate``.

        The id is rendered as ``"[ <id> ]"`` padded to a width of 4.
        """
        return [
            [f"[ {doc.meta['id']:4} ]", doc.meta["title"], doc.meta["url"]]
            for doc in AIBackend._ranked_documents(result)
        ]

    @staticmethod
    def format_for_api(result):
        """Return one ``{"id", "title", "url"}`` dict per ranked document."""
        return [
            {
                "id": doc.meta["id"],
                "title": doc.meta["title"],
                "url": doc.meta["url"],
            }
            for doc in AIBackend._ranked_documents(result)
        ]


def load_data(dataset_path):
    """Read a JSON-lines dataset and add ``text`` and ``meta`` to every record.

    Each non-blank line must be a JSON object with the keys ``title``,
    ``transcript``, ``explanation``, ``url``, ``image_url`` and ``id``.

    Args:
        dataset_path: Path of the ``.jsonl`` file.

    Returns:
        A list of the parsed records, each extended with
        ``text`` ("title | transcript | explanation") and ``meta``
        (title, url, image_url, id).
    """
    data = []

    # JSON text is UTF-8; do not depend on the platform's locale encoding.
    with open(dataset_path, "r", encoding="utf-8") as f:
        # Iterate lazily instead of materialising the whole file first.
        for line in f:
            if not line.strip():
                # Tolerate blank lines (e.g. a trailing newline at EOF).
                continue
            j: dict = json.loads(line)
            j.update(
                {
                    "text": f"{j['title']} | {j['transcript']} | {j['explanation']}",
                    "meta": {
                        "title": j["title"],
                        "url": j["url"],
                        "image_url": j["image_url"],
                        "id": j["id"],
                    },
                }
            )
            data.append(j)
    return data


def get_torch_info():
    """Print whether CUDA/ROCm is available and list the visible devices."""
    print("---------- Getting information about pytorch setup ----------")
    print(f"Is CUDA or ROCm available? { 'Yes' if torch.cuda.is_available() else 'No'}")
    print("Available devices:")
    for i in range(torch.cuda.device_count()):
        dev = torch.cuda.get_device_properties(i)
        print(f"- [{i}] {dev.name} [ {dev.multi_processor_count} processors, {dev.total_memory / 1_000_000_000:.2f} GB ]")
    print("-------------------------------------------------------------")
border-color: #b86cff; + height: 2rem; + width: 100%; + font-size: 2rem; + text-align: center; + } + + form > input[type=submit] { + font-weight: 500; + font-size: 1.5rem; + color: #eee; + margin: 1rem 0px 0px; + padding: 0.5rem 0.75rem; + text-decoration: none; + background-color: #333; + border-radius: 10px; + border-style: solid; + border-width: 2px; + border-color: #333; + text-transform: capitalize; + text-align: center; + width: 50%; + } + + form > input[type=submit]:hover { + border-color: #b86cff; + } + + .head { + margin: 0 .5rem; + hyphens: auto; + padding-top: 2rem; + font-size: 3.25rem; + font-weight: 700; + } + + .sub { + margin: 0 .5rem; + font-size: 1.2rem; + font-weight: 500; + color: #eee; + padding-bottom: 1rem; + } + + .head { + color: #b86cff; + } + +</style> \ No newline at end of file diff --git a/server-implementation/templates/results.html b/server-implementation/templates/results.html new file mode 100644 index 0000000..1a09d22 --- /dev/null +++ b/server-implementation/templates/results.html @@ -0,0 +1,151 @@ +<!doctype html> +<head> + <meta charset="UTF-8" /> + <meta name="viewport" content="width=device-width" /> + <title>xkcd-finder - results for {{ topic }}</title> +</head> +<body> + <section> + <h1 class="head">Results</h1> + <span class="sub">These are the top 5 results for the topic "{{ topic }}"</span> + <ul> + <li> + <span class="title">{{ results[0]["title"] }}</span> + <span class="id">xkdc {{ results[0]["id"] }}</span> + <a target="_blank" rel="noopener noreferrer" class="xkcdlink" href={{ results[0]["url"] }}>Link >></a> + </li> + <li> + <span class="title">{{ results[1]["title"] }}</span> + <span class="id">xkdc {{ results[1]["id"] }}</span> + <a target="_blank" rel="noopener noreferrer" class="xkcdlink" href={{ results[1]["url"] }}>Link >></a> + </li> + <li> + <span class="title">{{ results[2]["title"] }}</span> + <span class="id">xkdc {{ results[2]["id"] }}</span> + <a target="_blank" rel="noopener noreferrer" 
class="xkcdlink" href={{ results[2]["url"] }}>Link >></a> + </li> + <li> + <span class="title">{{ results[3]["title"] }}</span> + <span class="id">xkdc {{ results[3]["id"] }}</span> + <a target="_blank" rel="noopener noreferrer" class="xkcdlink" href={{ results[3]["url"] }}>Link >></a> + </li> + <li> + <span class="title">{{ results[4]["title"] }}</span> + <span class="id">xkdc {{ results[4]["id"] }}</span> + <a target="_blank" rel="noopener noreferrer" class="xkcdlink" href={{ results[4]["url"] }}>Link >></a> + </li> + </ul> + <a href="/" class="backbutton"><< Back to Home</a> +</section> +</body> + +<style> + * { + font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; + } + + body { + background: #1b1b1b; + margin: 0px; + width: auto; + } + + section { + display: flex; + flex-direction: column; + align-items: center; + } + + ul { + display: flex; + flex-direction: row; + width: 100%; + height: auto; + align-items: center; + justify-content: center; + text-decoration: none; + list-style: none; + padding: 0; + } + + li { + display: flex; + flex-direction: column; + position: relative; + width: 220px; + height: 250px; + border: 2px solid #b86cff; + margin: 1rem; + justify-items: flex-start; + align-items: center; + border-radius: 10px; + } + + span { + text-wrap: wrap; + } + + .title { + padding-top: 10px; + padding-left: 5px; + padding-right: 5px; + color: #eee; + font-size: 2.1rem; + font-weight: 600; + text-align: center; + } + + .id { + position: absolute; + bottom: 56px; + width: 70%; + margin-left: auto; + margin-right: auto; + text-align: center; + font-size: 1.2rem; + color: #eee; + } + + .xkcdlink { + position: absolute; + bottom: 10px; + width: 70%; + } + + a { + font-weight: 500; + color: #eee; + padding: 0.5rem 0.75rem; + text-decoration: none; + background-color: #333; + border-radius: 10px; + border-style: solid; + border-width: 2px; + border-color: #333; + text-transform: capitalize; + text-align: center; + } + + a:hover { + border-color: 
#b86cff; + } + + .head { + margin: 0 .5rem; + hyphens: auto; + padding-top: 5rem; + font-size: 3.25rem; + font-weight: 700; + } + + .sub { + margin: 0 .5rem; + font-size: 1.2rem; + font-weight: 500; + color: #eee; + } + + .head { + color: #b86cff; + } +</style>