"""Callback handler for promptlayer."""

from __future__ import annotations

import datetime
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
from uuid import UUID

from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.messages import (
    AIMessage,
    BaseMessage,
    ChatMessage,
    HumanMessage,
    SystemMessage,
)
from langchain_core.outputs import (
    ChatGeneration,
    LLMResult,
)

if TYPE_CHECKING:
    import promptlayer


def _lazy_import_promptlayer() -> promptlayer:
    """Import the ``promptlayer`` package on demand and return the module.

    The import happens inside the function, so this module can be imported
    even when ``promptlayer`` is not installed; the failure only surfaces
    when this function is called.

    Returns:
        The imported ``promptlayer`` module.

    Raises:
        ImportError: If ``promptlayer`` cannot be imported. The message
            carries a ``pip install`` hint and the original error is attached
            as ``__cause__``.
    """
    try:
        import promptlayer
    except ImportError as err:
        # Chain explicitly so the traceback shows the real import failure as
        # the direct cause instead of an unrelated-looking secondary error.
        raise ImportError(
            "The PromptLayerCallbackHandler requires the promptlayer package. "
            " Please install it with `pip install promptlayer`."
        ) from err
    return promptlayer


class PromptLayerCallbackHandler(BaseCallbackHandler):
    """Callback handler that logs LLM and chat-model runs to PromptLayer.

    The ``on_*_start`` callbacks record a run's inputs, invocation parameters
    and start time in ``self.runs`` keyed by ``run_id``. ``on_llm_end`` then
    submits one PromptLayer request per prompt (or per message list) of that
    run and forgets the run.
    """

    def __init__(
        self,
        pl_id_callback: Optional[Callable[..., Any]] = None,
        pl_tags: Optional[List[str]] = None,
    ) -> None:
        """Initialize the PromptLayerCallbackHandler.

        Args:
            pl_id_callback: Optional callable invoked with the value returned
                by ``promptlayer_api_request`` after each logged request.
            pl_tags: Tags forwarded to PromptLayer with every request.
                Defaults to an empty list.

        Raises:
            ImportError: If the ``promptlayer`` package is not installed.
        """
        # Fail at construction time rather than on the first callback.
        _lazy_import_promptlayer()
        self.pl_id_callback = pl_id_callback
        self.pl_tags = pl_tags or []
        # In-flight runs: run_id -> data captured by the start callbacks.
        self.runs: Dict[UUID, Dict[str, Any]] = {}

    def on_chat_model_start(
        self,
        serialized: Dict[str, Any],
        messages: List[List[BaseMessage]],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> Any:
        """Record the inputs of a chat-model run.

        Args:
            serialized: Serialized model; its ``"id"`` entries are joined with
                ``"."`` to form the logged name.
            messages: One list of messages per generation request.
            run_id: Key under which the run is stored.
            parent_run_id: Unused here.
            tags: Run tags, stored and reported in the request metadata.
            **kwargs: ``invocation_params`` is read from here when present.
        """
        self.runs[run_id] = {
            "messages": [self._create_message_dicts(m)[0] for m in messages],
            "invocation_params": kwargs.get("invocation_params", {}),
            "name": ".".join(serialized["id"]),
            "request_start_time": datetime.datetime.now().timestamp(),
            "tags": tags,
        }

    def on_llm_start(
        self,
        serialized: Dict[str, Any],
        prompts: List[str],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> Any:
        """Record the inputs of a text-completion run.

        Args:
            serialized: Serialized model; its ``"id"`` entries are joined with
                ``"."`` to form the logged name.
            prompts: The prompt strings sent to the model.
            run_id: Key under which the run is stored.
            parent_run_id: Unused here.
            tags: Run tags, stored and reported in the request metadata.
            **kwargs: ``invocation_params`` is read from here when present.
        """
        self.runs[run_id] = {
            "prompts": prompts,
            "invocation_params": kwargs.get("invocation_params", {}),
            "name": ".".join(serialized["id"]),
            "request_start_time": datetime.datetime.now().timestamp(),
            "tags": tags,
        }

    def on_llm_error(
        self,
        error: BaseException,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Discard the bookkeeping of a run that ended with an error.

        Nothing is sent to PromptLayer; the entry is only dropped so that
        failed runs do not accumulate in ``self.runs``.
        """
        self.runs.pop(run_id, None)

    def on_llm_end(
        self,
        response: LLMResult,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> None:
        """Send the finished run to PromptLayer, one request per input.

        Does nothing when ``run_id`` was never recorded by a start callback.
        Only the first generation of each entry in ``response.generations``
        is logged. If ``pl_id_callback`` is set it is called with the return
        value of every ``promptlayer_api_request`` call.

        Args:
            response: The model result for the run.
            run_id: Identifies the run recorded by the start callback.
            parent_run_id: Reported in the request metadata.
            **kwargs: Ignored.
        """
        from promptlayer.utils import get_api_key, promptlayer_api_request

        # Pop instead of get: a finished run is never looked up again, so
        # leaving it behind would make ``self.runs`` grow without bound.
        run_info = self.runs.pop(run_id, None)
        if not run_info:
            return
        request_end_time = datetime.datetime.now().timestamp()
        run_info["request_end_time"] = request_end_time

        # Everything below is identical for every generation of the run.
        model_params = run_info.get("invocation_params", {})
        is_chat_model = run_info.get("messages") is not None
        recorded_inputs = run_info.get(
            "messages" if is_chat_model else "prompts", []
        )
        want_pl_id = self.pl_id_callback is not None
        # "tags" is always stored (possibly as None), so normalise here to
        # report an empty list rather than the string "None".
        run_tags = run_info.get("tags") or []

        for i, candidates in enumerate(response.generations):
            generation = candidates[0]

            # Chat inputs are already a list of message dicts; a plain prompt
            # is wrapped in a one-element list.
            model_input = (
                recorded_inputs[i] if is_chat_model else [recorded_inputs[i]]
            )
            if is_chat_model and isinstance(generation, ChatGeneration):
                model_response: Any = [
                    self._convert_message_to_dict(generation.message)
                ]
            else:
                model_response = {
                    "text": generation.text,
                    "llm_output": response.llm_output,
                }

            pl_request_id = promptlayer_api_request(
                run_info.get("name"),
                "langchain",
                model_input,
                model_params,
                self.pl_tags,
                model_response,
                run_info.get("request_start_time"),
                request_end_time,
                get_api_key(),
                return_pl_id=want_pl_id,
                metadata={
                    "_langchain_run_id": str(run_id),
                    "_langchain_parent_run_id": str(parent_run_id),
                    "_langchain_tags": str(run_tags),
                },
            )

            if self.pl_id_callback:
                self.pl_id_callback(pl_request_id)

    def _convert_message_to_dict(self, message: BaseMessage) -> Dict[str, Any]:
        """Convert a message into a ``{"role", "content"[, "name"]}`` dict.

        Raises:
            ValueError: If ``message`` is not a human, AI, system or chat
                message.
        """
        message_dict: Dict[str, Any]
        if isinstance(message, HumanMessage):
            message_dict = {"role": "user", "content": message.content}
        elif isinstance(message, AIMessage):
            message_dict = {"role": "assistant", "content": message.content}
        elif isinstance(message, SystemMessage):
            message_dict = {"role": "system", "content": message.content}
        elif isinstance(message, ChatMessage):
            message_dict = {"role": message.role, "content": message.content}
        else:
            raise ValueError(f"Got unknown type {message}")
        if "name" in message.additional_kwargs:
            message_dict["name"] = message.additional_kwargs["name"]
        return message_dict

    def _create_message_dicts(
        self, messages: List[BaseMessage]
    ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
        """Convert ``messages`` to dicts.

        Returns:
            A ``(message_dicts, params)`` tuple; ``params`` is always an
            empty dict.
        """
        params: Dict[str, Any] = {}
        message_dicts = [self._convert_message_to_dict(m) for m in messages]
        return message_dicts, params
