import itertools
import logging
import sys
from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Tuple

from langchain_core.documents import Document

from langchain_community.document_loaders.base import BaseLoader

# Name the logger after the dotted module path (not the file path) so it takes
# part in the normal package logger hierarchy.
logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    from pyspark.sql import SparkSession


class PySparkDataFrameLoader(BaseLoader):
    """Load `PySpark` DataFrames.

    Each DataFrame row becomes one ``Document``: the value found in
    ``page_content_column`` is used as the page content and every other
    column is stored in the document metadata, keyed by column name.
    """

    def __init__(
        self,
        spark_session: Optional["SparkSession"] = None,
        df: Optional[Any] = None,
        page_content_column: str = "text",
        fraction_of_memory: float = 0.1,
    ):
        """Initialize with a Spark DataFrame object.

        Args:
            spark_session: The SparkSession object. When not provided, one is
             obtained through ``SparkSession.builder.getOrCreate()``.
            df: The Spark DataFrame object.
            page_content_column: The name of the column containing the page content.
             Defaults to "text".
            fraction_of_memory: The fraction of memory to use. Defaults to 0.1.

        Raises:
            ImportError: If ``pyspark`` is not installed, or if ``psutil`` is
             not installed (it is needed by ``get_num_rows``, which runs here).
            ValueError: If ``df`` is not a PySpark ``DataFrame``.
        """
        try:
            from pyspark.sql import DataFrame, SparkSession
        except ImportError as e:
            # Chain the original error, consistent with the psutil import below.
            raise ImportError(
                "pyspark is not installed. "
                "Please install it with `pip install pyspark`"
            ) from e

        self.spark = (
            spark_session if spark_session else SparkSession.builder.getOrCreate()
        )

        if not isinstance(df, DataFrame):
            raise ValueError(
                f"Expected data_frame to be a PySpark DataFrame, got {type(df)}"
            )
        self.df = df
        self.page_content_column = page_content_column
        self.fraction_of_memory = fraction_of_memory
        # num_rows: how many rows load() will return at most.
        # max_num_rows: the memory-derived cap on the number of rows.
        self.num_rows, self.max_num_rows = self.get_num_rows()
        # Rows are converted to plain lists so they can be paired with column names.
        self.rdd_df = self.df.rdd.map(list)
        self.column_names = self.df.columns

    def get_num_rows(self) -> Tuple[int, int]:
        """Gets the number of "feasible" rows for the DataFrame.

        The size of the first row is used as a per-row estimate, and the cap is
        ``available_memory / row_size * fraction_of_memory``.

        Returns:
            A tuple ``(num_rows, max_num_rows)`` where ``max_num_rows`` is the
            memory-derived cap and ``num_rows`` is the smaller of that cap and
            the DataFrame's row count. An empty DataFrame yields ``(0, 0)``.

        Raises:
            ImportError: If ``psutil`` is not installed.
        """
        try:
            import psutil
        except ImportError as e:
            raise ImportError(
                "psutil not installed. Please install it with `pip install psutil`."
            ) from e
        sample = self.df.limit(1).collect()
        if not sample:
            # Empty DataFrame: there is no row to size and nothing to load.
            return 0, 0
        # NOTE: sys.getsizeof is shallow (it does not follow references into
        # the row's values), so this is only a rough per-row estimate.
        estimated_row_size = sys.getsizeof(sample[0])
        available_memory = psutil.virtual_memory().available
        max_num_rows = int(
            (available_memory / estimated_row_size) * self.fraction_of_memory
        )
        return min(max_num_rows, self.df.count()), max_num_rows

    def lazy_load(self) -> Iterator[Document]:
        """A lazy loader for document content.

        Yields:
            One ``Document`` per row obtained from ``toLocalIterator()``; the
            page content column is removed from the metadata.

        Raises:
            KeyError: If ``page_content_column`` is not one of the columns.
        """
        for row in self.rdd_df.toLocalIterator():
            metadata = dict(zip(self.column_names, row))
            # pop() both fetches the page content and drops it from metadata.
            text = metadata.pop(self.page_content_column)
            yield Document(page_content=text, metadata=metadata)

    def load(self) -> List[Document]:
        """Load from the dataframe.

        Returns:
            At most ``self.num_rows`` documents. A warning is logged when the
            DataFrame holds more rows than ``self.max_num_rows``.
        """
        # count() triggers a Spark job, so evaluate it only once.
        total_rows = self.df.count()
        if total_rows > self.max_num_rows:
            logger.warning(
                "The number of DataFrame rows is %s, "
                "but we will only include the amount "
                "of rows that can reasonably fit in memory: %s.",
                total_rows,
                self.num_rows,
            )
        return list(itertools.islice(self.lazy_load(), self.num_rows))
