import json
import pathlib

import torch

from haystack import Document, Pipeline
from haystack.components.embedders import (
    SentenceTransformersDocumentEmbedder,
    SentenceTransformersTextEmbedder,
)
from haystack.components.joiners import DocumentJoiner
from haystack.components.rankers import TransformersSimilarityRanker
from haystack.components.writers import DocumentWriter
from haystack.utils import ComponentDevice

from haystack_integrations.components.retrievers.pgvector import PgvectorEmbeddingRetriever, PgvectorKeywordRetriever
from haystack_integrations.document_stores.pgvector import PgvectorDocumentStore

class AIBackend:
    """Hybrid retrieval backend: embedding + keyword search over pgvector, reranked."""

    model_embeddings = "BAAI/bge-base-en-v1.5"
    model_ranker = "BAAI/bge-reranker-base"

    query_pipeline: Pipeline
    index_pipeline: Pipeline
    document_store: PgvectorDocumentStore
    documents: list[Document]
    def __init__(self, load_dataset: bool = False):
        get_torch_info()
        self._ready = False

        # ComponentDevice.from_str only parses the device string, so check CUDA
        # availability explicitly instead of relying on an exception.
        if torch.cuda.is_available():
            self.gpu = ComponentDevice.from_str("cuda:0")
        else:
            self.gpu = None
            print("No CUDA GPU device found, falling back to the default device")

        # Per-instance list instead of a shared mutable class attribute.
        self.documents = []
        if load_dataset:
            dataset = pathlib.Path(__file__).parents[1] / "data" / "dataset.jsonl"
            self.documents = [Document(content=d["text"], meta=d["meta"]) for d in load_data(dataset)]

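        # No connection_string is passed here, so the store is expected to read it
        # from the PG_CONN_STR environment variable (the pgvector integration's
        # default). embedding_dimension=768 matches bge-base-en-v1.5.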
        self.document_store = PgvectorDocumentStore(
            embedding_dimension=768,
            vector_function="cosine_similarity",
            recreate_table=False,
            search_strategy="hnsw",
        )

        self.index_pipeline = self._create_indexing_pipeline()
        self.query_pipeline = self._create_query_pipeline()

    def warmup(self):
        """Embed and index the loaded documents; must be called once before query()."""
        print("Running warmup routine ...")
        print("Launching indexing pipeline to generate document embeddings")
        res = self.index_pipeline.run({"document_embedder": {"documents": self.documents}})
        print(f"Finished running indexing pipeline\nDocument Store: Wrote {res['document_writer']['documents_written']} documents")
        self._ready = True
        print("'.query(\"text\")' is now ready to be used")

    def _create_indexing_pipeline(self):
        """Build the pipeline that embeds documents and writes them to the document store."""
        print("Creating indexing pipeline ...")
        document_embedder = SentenceTransformersDocumentEmbedder(model=self.model_embeddings, device=self.gpu)
        document_writer = DocumentWriter(document_store=self.document_store)

        indexing_pipeline = Pipeline()
        indexing_pipeline.add_component("document_embedder", document_embedder)
        indexing_pipeline.add_component("document_writer", document_writer)

        indexing_pipeline.connect("document_embedder", "document_writer")

        return indexing_pipeline
    
    def _create_query_pipeline(self):
        """Build the hybrid retrieval pipeline: embedding and keyword retrieval, joined and reranked."""
        print("Creating hybrid retrieval pipeline ...")
        text_embedder = SentenceTransformersTextEmbedder(model=self.model_embeddings, device=self.gpu)
        ranker = TransformersSimilarityRanker(model=self.model_ranker, device=self.gpu)

        embedding_retriever = PgvectorEmbeddingRetriever(document_store=self.document_store)
        keyword_retriever = PgvectorKeywordRetriever(document_store=self.document_store)

        document_joiner = DocumentJoiner()

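        # Retrieval graph wired up below:
        #   text_embedder -> embedding_retriever \
        #                                         -> document_joiner -> ranker
        #   keyword_retriever ------------------ /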
        hybrid_retrieval = Pipeline()
        hybrid_retrieval.add_component("text_embedder", text_embedder)
        hybrid_retrieval.add_component("embedding_retriever", embedding_retriever)
        hybrid_retrieval.add_component("keyword_retriever", keyword_retriever)
        hybrid_retrieval.add_component("document_joiner", document_joiner)
        hybrid_retrieval.add_component("ranker", ranker)

        hybrid_retrieval.connect("text_embedder", "embedding_retriever")
        hybrid_retrieval.connect("keyword_retriever", "document_joiner")
        hybrid_retrieval.connect("embedding_retriever", "document_joiner")
        hybrid_retrieval.connect("document_joiner", "ranker")

        return hybrid_retrieval

    def query(self, query: str):
        """Run the hybrid retrieval pipeline; ranked hits end up under result["ranker"]["documents"]."""
        if not self._ready:
            raise RuntimeError("Cannot query before warmup() has been run")

        return self.query_pipeline.run(
            data={
                "text_embedder": { "text": query },
                "keyword_retriever": { "query": query },
                "ranker": { "query": query, "top_k": 5 }
            }
        )
    
    @staticmethod
    def format_result(result):
        result_table = []
        x: Document | None
        for x in result["ranker"]["documents"]:
            if x is None:
                continue
            result_table.append([f"[ {x.meta['id']:4} ]", x.meta["title"], x.meta["url"]])
        return result_table
    
    @staticmethod
    def format_for_api(result):
        results = []
        x: Document | None
        for x in result["ranker"]["documents"]:
            if x is None:
                continue
            results.append({
                "id": x.meta["id"],
                "title": x.meta["title"], 
                "url": x.meta["url"],
                "image_url": x.meta["image_url"]
            })
        return results


def load_data(dataset_path):
    """Load the JSONL dataset and derive the `text` and `meta` fields used for indexing."""
    data = []

    with open(dataset_path, "r", encoding="utf-8") as f:
        for line in f:
            j: dict = json.loads(line)
            j.update({
                "text": f"{j['title']} | {j['transcript']} | {j['explanation']}",
                "meta": {"title": j["title"], "url": j["url"], "image_url": j["image_url"], "id": j["id"]},
            })
            data.append(j)
    return data
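
# Illustrative shape of a single dataset.jsonl record consumed by load_data().
# Only the keys are taken from the code above; the values here are made up:
#   {"id": 123, "title": "Some comic", "transcript": "...", "explanation": "...",
#    "url": "https://example.com/123", "image_url": "https://example.com/123.png"}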

def get_torch_info():
    """Print a short summary of the PyTorch compute devices that are available."""
    print("---------- Getting information about pytorch setup ----------")
    print(f"Is CUDA or ROCm available? {'Yes' if torch.cuda.is_available() else 'No'}")
    print("Available devices:")
    for i in range(torch.cuda.device_count()):
        dev = torch.cuda.get_device_properties(i)
        print(f"- [{i}] {dev.name} [ {dev.multi_processor_count} processors, {dev.total_memory / 1_000_000_000:.2f} GB ]")
    print("-------------------------------------------------------------")