from __future__ import annotations

import asyncio
from asyncio import InvalidStateError, Task
from typing import (
    TYPE_CHECKING,
    AsyncIterator,
    Iterator,
    List,
    Optional,
    Sequence,
    Tuple,
)

from langchain_core.stores import ByteStore

from langchain_community.utilities.cassandra import SetupMode, aexecute_cql

if TYPE_CHECKING:
    from cassandra.cluster import Session
    from cassandra.query import PreparedStatement

# CQL templates used by CassandraByteStore. ``{keyspace}`` and ``{table}`` are
# substituted with str.format; the ``?`` markers are bind placeholders that are
# filled when the statement is executed as a prepared statement.
CREATE_TABLE_CQL_TEMPLATE = (
    "\n"
    "    CREATE TABLE IF NOT EXISTS {keyspace}.{table} \n"
    "    (row_id TEXT, body_blob BLOB, PRIMARY KEY (row_id));\n"
)
SELECT_TABLE_CQL_TEMPLATE = (
    "SELECT row_id, body_blob FROM  {keyspace}.{table} "
    "WHERE row_id IN ?;"
)
SELECT_ALL_TABLE_CQL_TEMPLATE = (
    "SELECT row_id, body_blob FROM  {keyspace}.{table};"
)
INSERT_TABLE_CQL_TEMPLATE = (
    "INSERT INTO {keyspace}.{table} (row_id, body_blob) "
    "VALUES (?, ?);"
)
DELETE_TABLE_CQL_TEMPLATE = (
    "DELETE FROM {keyspace}.{table} "
    "WHERE row_id IN ?;"
)


class CassandraByteStore(ByteStore):
    """A ByteStore implementation using Cassandra as the backend.

    Every key/value pair is stored as one row of ``{keyspace}.{table}``: the key
    goes in the ``row_id`` column (TEXT, primary key) and the value in the
    ``body_blob`` column (BLOB).

    Parameters:
        table: The name of the table to use.
        session: A Cassandra session object. If not provided, it will be resolved
            from the cassio config.
        keyspace: The keyspace to use. If not provided, it will be resolved
            from the cassio config.
        setup_mode: The setup mode to use. Default is SYNC (SetupMode.SYNC).
            SYNC creates the table (if missing) before the constructor returns;
            ASYNC schedules the creation as a task on the running event loop;
            OFF skips table creation and assumes the table already exists.
    """

    def __init__(
        self,
        table: str,
        *,
        session: Optional[Session] = None,
        keyspace: Optional[str] = None,
        setup_mode: SetupMode = SetupMode.SYNC,
    ) -> None:
        if not session or not keyspace:
            # Fall back to the globally configured cassio session/keyspace.
            try:
                from cassio.config import check_resolve_keyspace, check_resolve_session
            except ImportError as err:  # ModuleNotFoundError is a subclass
                raise ImportError(
                    "Could not import a recent cassio package. "
                    "Please install it with `pip install --upgrade cassio`."
                ) from err

            self.keyspace = keyspace or check_resolve_keyspace(keyspace)
            self.session = session or check_resolve_session()
        else:
            self.keyspace = keyspace
            self.session = session
        self.table = table
        # Prepared statements are created lazily on first use, then reused.
        self.select_statement: Optional[PreparedStatement] = None
        self.insert_statement: Optional[PreparedStatement] = None
        self.delete_statement: Optional[PreparedStatement] = None

        create_cql = CREATE_TABLE_CQL_TEMPLATE.format(
            keyspace=self.keyspace,
            table=self.table,
        )
        self.db_setup_task: Optional[Task[None]] = None
        if setup_mode == SetupMode.ASYNC:
            # asyncio.create_task needs a running event loop; the task is
            # awaited (or checked) before every data operation.
            self.db_setup_task = asyncio.create_task(
                aexecute_cql(self.session, create_cql)
            )
        elif setup_mode != SetupMode.OFF:
            # SetupMode.OFF means the table is provisioned externally.
            self.session.execute(create_cql)

    def ensure_db_setup(self) -> None:
        """Ensure that the DB setup is finished. If not, raise a ValueError.

        Raises:
            ValueError: If the asynchronous setup task has not completed yet.
            Exception: Whatever the setup task raised, if it failed.
        """
        if self.db_setup_task is None:
            return
        try:
            # Task.result() re-raises the task's own exception if it failed.
            self.db_setup_task.result()
        except InvalidStateError:
            raise ValueError(
                "Asynchronous setup of the DB not finished. "
                "NB: CassandraByteStore sync methods shouldn't be called from the "
                "event loop. Consider using their async equivalents."
            ) from None

    async def aensure_db_setup(self) -> None:
        """Ensure that the DB setup is finished. If not, wait for it."""
        if self.db_setup_task is not None:
            await self.db_setup_task

    def _prepare(self, template: str) -> PreparedStatement:
        """Prepare ``template`` after filling in this store's keyspace and table."""
        return self.session.prepare(
            template.format(keyspace=self.keyspace, table=self.table)
        )

    def get_select_statement(self) -> PreparedStatement:
        """Get the prepared select statement for the table.
        If not available, prepare it.

        Returns:
            PreparedStatement: The prepared statement.
        """
        if self.select_statement is None:
            self.select_statement = self._prepare(SELECT_TABLE_CQL_TEMPLATE)
        return self.select_statement

    def get_insert_statement(self) -> PreparedStatement:
        """Get the prepared insert statement for the table.
        If not available, prepare it.

        Returns:
            PreparedStatement: The prepared statement.
        """
        if self.insert_statement is None:
            self.insert_statement = self._prepare(INSERT_TABLE_CQL_TEMPLATE)
        return self.insert_statement

    def get_delete_statement(self) -> PreparedStatement:
        """Get the prepared delete statement for the table.
        If not available, prepare it.

        Returns:
            PreparedStatement: The prepared statement.
        """
        if self.delete_statement is None:
            self.delete_statement = self._prepare(DELETE_TABLE_CQL_TEMPLATE)
        return self.delete_statement

    def mget(self, keys: Sequence[str]) -> List[Optional[bytes]]:
        """Return the values for ``keys``, in order; ``None`` for missing keys."""
        from cassandra.query import ValueSequence

        self.ensure_db_setup()
        if not keys:
            # Nothing to look up: skip the round trip with an empty IN ().
            return []
        rows = self.session.execute(
            self.get_select_statement(), [ValueSequence(keys)]
        )
        docs_dict = {row.row_id: row.body_blob for row in rows}
        return [docs_dict.get(key) for key in keys]

    async def amget(self, keys: Sequence[str]) -> List[Optional[bytes]]:
        """Async version of :meth:`mget`."""
        from cassandra.query import ValueSequence

        await self.aensure_db_setup()
        if not keys:
            return []
        rows = await aexecute_cql(
            self.session, self.get_select_statement(), parameters=[ValueSequence(keys)]
        )
        docs_dict = {row.row_id: row.body_blob for row in rows}
        return [docs_dict.get(key) for key in keys]

    def mset(self, key_value_pairs: Sequence[Tuple[str, bytes]]) -> None:
        """Store each ``(key, value)`` pair, overwriting existing values.

        Pairs are written one at a time, in order, so for a repeated key the
        last pair wins.
        """
        self.ensure_db_setup()
        insert_statement = self.get_insert_statement()
        for k, v in key_value_pairs:
            self.session.execute(insert_statement, (k, v))

    async def amset(self, key_value_pairs: Sequence[Tuple[str, bytes]]) -> None:
        """Async version of :meth:`mset` (writes are awaited sequentially)."""
        await self.aensure_db_setup()
        insert_statement = self.get_insert_statement()
        for k, v in key_value_pairs:
            await aexecute_cql(self.session, insert_statement, parameters=(k, v))

    def mdelete(self, keys: Sequence[str]) -> None:
        """Delete the rows for ``keys``; missing keys are ignored."""
        from cassandra.query import ValueSequence

        self.ensure_db_setup()
        if not keys:
            return
        self.session.execute(self.get_delete_statement(), [ValueSequence(keys)])

    async def amdelete(self, keys: Sequence[str]) -> None:
        """Async version of :meth:`mdelete`."""
        from cassandra.query import ValueSequence

        await self.aensure_db_setup()
        if not keys:
            return
        await aexecute_cql(
            self.session, self.get_delete_statement(), parameters=[ValueSequence(keys)]
        )

    def _select_all_keys_cql(self) -> str:
        """Return the CQL that scans every key of the table.

        Only ``row_id`` is selected: listing keys never needs the stored blobs,
        so there is no reason to transfer them.
        """
        return "SELECT row_id FROM {keyspace}.{table};".format(
            keyspace=self.keyspace, table=self.table
        )

    def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
        """Yield the stored keys, optionally only those starting with ``prefix``.

        The prefix filter is applied client-side after a full table scan.
        """
        self.ensure_db_setup()
        for row in self.session.execute(self._select_all_keys_cql()):
            key = row.row_id
            if not prefix or key.startswith(prefix):
                yield key

    async def ayield_keys(self, *, prefix: Optional[str] = None) -> AsyncIterator[str]:
        """Async version of :meth:`yield_keys`."""
        await self.aensure_db_setup()
        for row in await aexecute_cql(self.session, self._select_all_keys_cql()):
            key = row.row_id
            if not prefix or key.startswith(prefix):
                yield key
