import json
import os
import shutil
import tempfile
from copy import deepcopy
from typing import Any, Dict, List, Optional

from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult

from langchain_community.callbacks.utils import (
    flatten_dict,
)


def save_json(data: dict, file_path: str) -> None:
    """Serialize a dictionary as JSON and write it to a local file.

    The file is opened in text write mode, so any existing file at
    ``file_path`` is replaced.

    Parameters:
        data (dict): Dictionary to serialize.
        file_path (str): Destination path on the local filesystem.
    """
    with open(file_path, mode="w") as handle:
        json.dump(obj=data, fp=handle)


class SageMakerCallbackHandler(BaseCallbackHandler):
    """Callback Handler that logs prompt artifacts and metrics to SageMaker Experiments.

    Every callback bumps counters in ``self.metrics``. Apart from the
    ``*_error`` callbacks, each one also writes a JSON record (action name,
    event payload and a snapshot of the counters) into a temporary local
    directory and logs that file to the run as an artifact.

    Parameters:
        run (sagemaker.experiments.run.Run): Run object where the experiment is
            logged. Only its ``log_file(path, name=..., is_output=...)`` method
            is used.
    """

    def __init__(self, run: Any) -> None:
        """Initialize callback handler."""
        super().__init__()

        self.run = run

        # Event counters; a snapshot is merged into every logged record.
        self.metrics = {
            "step": 0,
            "starts": 0,
            "ends": 0,
            "errors": 0,
            "text_ctr": 0,
            "chain_starts": 0,
            "chain_ends": 0,
            "llm_starts": 0,
            "llm_ends": 0,
            "llm_streams": 0,
            "tool_starts": 0,
            "tool_ends": 0,
            "agent_ends": 0,
        }

        # Local staging directory for the JSON files handed to ``run.log_file``.
        self.temp_dir = tempfile.mkdtemp()

    def _reset(self) -> None:
        """Zero every counter in ``self.metrics`` in place."""
        for key in self.metrics:
            self.metrics[key] = 0

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Run when LLM starts.

        Logs one JSON artifact per prompt, named
        ``llm_start_<n>_prompt_<i>``.
        """
        self.metrics["step"] += 1
        self.metrics["llm_starts"] += 1
        self.metrics["starts"] += 1

        llm_starts = self.metrics["llm_starts"]

        resp: Dict[str, Any] = {}
        resp.update({"action": "on_llm_start"})
        # Guard against a missing ``serialized`` payload, as done for
        # ``llm_output`` in ``on_llm_end``.
        resp.update(flatten_dict(serialized or {}))
        resp.update(self.metrics)

        for idx, prompt in enumerate(prompts):
            prompt_resp = deepcopy(resp)
            prompt_resp["prompt"] = prompt
            self.jsonf(
                prompt_resp,
                self.temp_dir,
                f"llm_start_{llm_starts}_prompt_{idx}",
            )

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Run when LLM generates a new token.

        Logs one JSON artifact per token, named ``llm_new_tokens_<n>``.
        """
        self.metrics["step"] += 1
        self.metrics["llm_streams"] += 1

        llm_streams = self.metrics["llm_streams"]

        resp: Dict[str, Any] = {}
        resp.update({"action": "on_llm_new_token", "token": token})
        resp.update(self.metrics)

        self.jsonf(resp, self.temp_dir, f"llm_new_tokens_{llm_streams}")

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Run when LLM ends running.

        Logs one JSON artifact per generation, named
        ``llm_end_<n>_generation_<i>``. Each record combines the flattened
        ``llm_output``, a snapshot of the counters and the flattened
        generation itself.
        """
        self.metrics["step"] += 1
        self.metrics["llm_ends"] += 1
        self.metrics["ends"] += 1

        llm_ends = self.metrics["llm_ends"]

        resp: Dict[str, Any] = {}
        resp.update({"action": "on_llm_end"})
        resp.update(flatten_dict(response.llm_output or {}))
        resp.update(self.metrics)

        # ``response.generations`` is a list of generation lists. Number all
        # generations with one running index so that lists from different
        # prompts do not reuse (and overwrite) the same artifact name.
        all_generations = (
            generation
            for prompt_generations in response.generations
            for generation in prompt_generations
        )
        for idx, generation in enumerate(all_generations):
            generation_resp = deepcopy(resp)
            generation_resp.update(flatten_dict(generation.dict()))

            # Log the per-generation record, not the shared base ``resp``.
            self.jsonf(
                generation_resp,
                self.temp_dir,
                f"llm_end_{llm_ends}_generation_{idx}",
            )

    def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
        """Run when LLM errors. Only the counters are updated; nothing is logged."""
        self.metrics["step"] += 1
        self.metrics["errors"] += 1

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        """Run when chain starts running.

        Logs one JSON artifact named ``chain_start_<n>``. The inputs are
        rendered as a single ``"k1=v1,k2=v2"`` string.
        """
        self.metrics["step"] += 1
        self.metrics["chain_starts"] += 1
        self.metrics["starts"] += 1

        chain_starts = self.metrics["chain_starts"]

        resp: Dict[str, Any] = {}
        resp.update({"action": "on_chain_start"})
        resp.update(flatten_dict(serialized or {}))
        resp.update(self.metrics)
        resp["inputs"] = ",".join(f"{k}={v}" for k, v in inputs.items())

        self.jsonf(resp, self.temp_dir, f"chain_start_{chain_starts}")

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        """Run when chain ends running.

        Logs one JSON artifact named ``chain_end_<n>``. The outputs are
        rendered as a single ``"k1=v1,k2=v2"`` string.
        """
        self.metrics["step"] += 1
        self.metrics["chain_ends"] += 1
        self.metrics["ends"] += 1

        chain_ends = self.metrics["chain_ends"]

        resp: Dict[str, Any] = {}
        chain_output = ",".join(f"{k}={v}" for k, v in outputs.items())
        resp.update({"action": "on_chain_end", "outputs": chain_output})
        resp.update(self.metrics)

        self.jsonf(resp, self.temp_dir, f"chain_end_{chain_ends}")

    def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
        """Run when chain errors. Only the counters are updated; nothing is logged."""
        self.metrics["step"] += 1
        self.metrics["errors"] += 1

    def on_tool_start(
        self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
    ) -> None:
        """Run when tool starts running.

        Logs one JSON artifact named ``tool_start_<n>``.
        """
        self.metrics["step"] += 1
        self.metrics["tool_starts"] += 1
        self.metrics["starts"] += 1

        tool_starts = self.metrics["tool_starts"]

        resp: Dict[str, Any] = {}
        resp.update({"action": "on_tool_start", "input_str": input_str})
        resp.update(flatten_dict(serialized or {}))
        resp.update(self.metrics)

        self.jsonf(resp, self.temp_dir, f"tool_start_{tool_starts}")

    def on_tool_end(self, output: Any, **kwargs: Any) -> None:
        """Run when tool ends running.

        Logs one JSON artifact named ``tool_end_<n>``. The output is stored
        via ``str(output)`` so that any object is JSON-serializable.
        """
        output = str(output)
        self.metrics["step"] += 1
        self.metrics["tool_ends"] += 1
        self.metrics["ends"] += 1

        tool_ends = self.metrics["tool_ends"]

        resp: Dict[str, Any] = {}
        resp.update({"action": "on_tool_end", "output": output})
        resp.update(self.metrics)

        self.jsonf(resp, self.temp_dir, f"tool_end_{tool_ends}")

    def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
        """Run when tool errors. Only the counters are updated; nothing is logged."""
        self.metrics["step"] += 1
        self.metrics["errors"] += 1

    def on_text(self, text: str, **kwargs: Any) -> None:
        """Run on arbitrary text.

        Logs one JSON artifact named ``on_text_<n>``.
        """
        self.metrics["step"] += 1
        self.metrics["text_ctr"] += 1

        text_ctr = self.metrics["text_ctr"]

        resp: Dict[str, Any] = {}
        resp.update({"action": "on_text", "text": text})
        resp.update(self.metrics)

        self.jsonf(resp, self.temp_dir, f"on_text_{text_ctr}")

    def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
        """Run when agent ends running.

        Logs one JSON artifact named ``agent_finish_<n>``. Expects
        ``finish.return_values`` to contain an ``"output"`` key; a
        ``KeyError`` is raised otherwise.
        """
        self.metrics["step"] += 1
        self.metrics["agent_ends"] += 1
        self.metrics["ends"] += 1

        agent_ends = self.metrics["agent_ends"]
        resp: Dict[str, Any] = {}
        resp.update(
            {
                "action": "on_agent_finish",
                "output": finish.return_values["output"],
                "log": finish.log,
            }
        )
        resp.update(self.metrics)

        self.jsonf(resp, self.temp_dir, f"agent_finish_{agent_ends}")

    def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
        """Run on agent action.

        Logs one JSON artifact named ``agent_action_<n>``.
        """
        self.metrics["step"] += 1
        # NOTE: agent actions share the ``tool_starts`` counter with
        # ``on_tool_start``; there is no separate agent-action counter.
        self.metrics["tool_starts"] += 1
        self.metrics["starts"] += 1

        tool_starts = self.metrics["tool_starts"]
        resp: Dict[str, Any] = {}
        resp.update(
            {
                "action": "on_agent_action",
                "tool": action.tool,
                "tool_input": action.tool_input,
                "log": action.log,
            }
        )
        resp.update(self.metrics)
        self.jsonf(resp, self.temp_dir, f"agent_action_{tool_starts}")

    def jsonf(
        self,
        data: Dict[str, Any],
        data_dir: str,
        filename: str,
        is_output: Optional[bool] = True,
    ) -> None:
        """Write ``data`` to ``<data_dir>/<filename>.json`` and log it to the run.

        Parameters:
            data (dict): Record to serialize as JSON.
            data_dir (str): Local directory for the file; created if missing.
            filename (str): File stem, also used as the artifact ``name``.
            is_output (bool): Forwarded unchanged to ``run.log_file``.
        """
        # ``flush_tracker`` deletes the staging directory. Recreate it so the
        # handler keeps working if it is reused after a flush.
        os.makedirs(data_dir, exist_ok=True)
        file_path = os.path.join(data_dir, f"{filename}.json")
        save_json(data, file_path)
        self.run.log_file(file_path, name=filename, is_output=is_output)

    def flush_tracker(self) -> None:
        """Reset the counters and delete the temporary local directory.

        Safe to call more than once. The directory is recreated on demand the
        next time a record is logged.
        """
        self._reset()
        # The directory may already be gone, e.g. after a previous flush with
        # nothing logged since.
        if os.path.isdir(self.temp_dir):
            shutil.rmtree(self.temp_dir)
