import asyncio
import json
import logging
from typing import Any, Dict, List, Optional

from langchain_core.documents import Document

from langchain_community.document_loaders.base import BaseLoader

# Module-scoped logger, named after this module so records can be filtered per module.
logger: logging.Logger = logging.getLogger(name=__name__)


class SurrealDBLoader(BaseLoader):
    """Load records from a SurrealDB table as LangChain ``Document`` objects.

    Each record becomes one ``Document``:

    * ``page_content`` is the whole record serialized with ``json.dumps``.
    * ``metadata`` merges the record's ``id``, the contents of its
      ``metadata`` field (if present), and the ``ns``/``db``/``table`` the
      record was read from. Later keys win.

    Args:
        filter_criteria: Optional mapping of field name to required value.
            All entries are ANDed into a ``WHERE`` clause. Values are sent as
            query parameters, but keys are written into the query text
            verbatim (and reused as parameter names), so they must be trusted,
            identifier-like field names. The key ``table`` is reserved.
        **kwargs: Connection options consumed by the loader:

            * ``dburl``: server URL, default ``"ws://localhost:8000/rpc"``.
              It must start with ``ws`` (websocket).
            * ``ns``: namespace, default ``"langchain"``.
            * ``db``: database, default ``"database"``.
            * ``table``: table to read, default ``"documents"``.
            * ``db_user`` / ``db_pass``: if both are given, ``initialize()``
              signs in with them.

            Any remaining kwargs are kept on ``self.kwargs``.

    Raises:
        ImportError: If the ``surrealdb`` package is not installed.
        ValueError: If ``dburl`` is not a websocket URL, or if
            ``filter_criteria`` contains the reserved key ``table``.
    """

    def __init__(
        self,
        filter_criteria: Optional[Dict] = None,
        **kwargs: Any,
    ) -> None:
        try:
            from surrealdb import Surreal
        except ImportError as e:
            raise ImportError(
                """Cannot import from surrealdb.
                please install with `pip install surrealdb`."""
            ) from e

        self.dburl = kwargs.pop("dburl", "ws://localhost:8000/rpc")
        if not self.dburl.startswith("ws"):
            raise ValueError("Only websocket connections are supported at this time.")

        self.filter_criteria = filter_criteria or {}
        # `table` is reserved. It is bound as the `$table` query parameter in
        # `aload`, and a filter with the same name would collide with it.
        if "table" in self.filter_criteria:
            raise ValueError(
                "key `table` is not a valid criteria for `filter_criteria` argument."
            )

        self.ns = kwargs.pop("ns", "langchain")
        self.db = kwargs.pop("db", "database")
        self.table = kwargs.pop("table", "documents")
        # A single client instance. The original code created this twice.
        self.sdb = Surreal(self.dburl)
        self.kwargs = kwargs

    async def initialize(self) -> None:
        """Connect to SurrealDB, optionally sign in, and select ``ns``/``db``.

        Sign-in happens only when both ``db_user`` and ``db_pass`` were
        passed to the constructor.
        """
        await self.sdb.connect()
        if "db_user" in self.kwargs and "db_pass" in self.kwargs:
            user = self.kwargs.get("db_user")
            password = self.kwargs.get("db_pass")
            await self.sdb.signin({"user": user, "pass": password})

        await self.sdb.use(self.ns, self.db)

    def load(self) -> List[Document]:
        """Synchronously connect, load all matching records, then disconnect.

        Runs on a fresh event loop via ``asyncio.run``. The connection is
        closed before that loop is torn down, so it does not leak.
        """

        async def _load() -> List[Document]:
            try:
                await self.initialize()
                return await self.aload()
            finally:
                # asyncio.run() closes its loop on return. The websocket opened
                # by initialize() has to be closed while that loop is still alive.
                await self.sdb.close()

        return asyncio.run(_load())

    async def aload(self) -> List[Document]:
        """Query the configured table and convert each record to a ``Document``.

        Assumes ``initialize()`` has already been awaited on the current loop.
        """
        query = "SELECT * FROM type::table($table)"
        if self.filter_criteria:
            # Each key is both the field name and its `$parameter` name.
            conditions = " AND ".join(f"{key} = ${key}" for key in self.filter_criteria)
            query += f" WHERE {conditions}"

        source_metadata = {
            "ns": self.ns,
            "db": self.db,
            "table": self.table,
        }
        results = await self.sdb.query(
            query, {"table": self.table, **self.filter_criteria}
        )

        return [
            Document(
                page_content=json.dumps(record),
                metadata={
                    "id": record["id"],
                    # Records without a `metadata` field, or with a null one,
                    # previously raised KeyError/TypeError here.
                    **(record.get("metadata") or {}),
                    **source_metadata,
                },
            )
            for record in results[0]["result"]
        ]
