import os
import warnings
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple, Union
from uuid import UUID

from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.messages import BaseMessage, ChatMessage
from langchain_core.outputs import Generation, LLMResult


class LabelStudioMode(Enum):
    """Labeling setups understood by the Label Studio callback handler.

    The member value is the string users may pass as ``mode`` and is also the
    key used to pick a default labeling config.
    """

    # Plain-text prompts shown in a <Text> tag.
    PROMPT = "prompt"
    # Chat message lists shown as a <Paragraphs layout="dialogue"> tag.
    CHAT = "chat"


def get_default_label_configs(
    mode: Union[str, LabelStudioMode],
) -> Tuple[str, LabelStudioMode]:
    """Get default Label Studio configs for the given mode.

    Parameters:
        mode: Label Studio mode ("prompt" or "chat"), given either as its
            string value or as a ``LabelStudioMode`` member.

    Returns: Tuple of the default Label Studio labeling config (XML) for the
        mode, and the mode normalized to a ``LabelStudioMode`` member.

    Raises:
        ValueError: if ``mode`` is not a valid ``LabelStudioMode`` value.
    """
    # Default labeling configs keyed by mode value. Both contain a TextArea
    # named "response", which the callback handler uses to attach
    # predictions.
    _default_label_configs = {
        LabelStudioMode.PROMPT.value: """
<View>
<Style>
    .prompt-box {
        background-color: white;
        border-radius: 10px;
        box-shadow: 0px 4px 6px rgba(0, 0, 0, 0.1);
        padding: 20px;
    }
</Style>
<View className="root">
    <View className="prompt-box">
        <Text name="prompt" value="$prompt"/>
    </View>
    <TextArea name="response" toName="prompt"
              maxSubmissions="1" editable="true"
              required="true"/>
</View>
<Header value="Rate the response:"/>
<Rating name="rating" toName="prompt"/>
</View>""",
        LabelStudioMode.CHAT.value: """
<View>
<View className="root">
     <Paragraphs name="dialogue"
               value="$prompt"
               layout="dialogue"
               textKey="content"
               nameKey="role"
               granularity="sentence"/>
  <Header value="Final response:"/>
    <TextArea name="response" toName="dialogue"
              maxSubmissions="1" editable="true"
              required="true"/>
</View>
<Header value="Rate the response:"/>
<Rating name="rating" toName="dialogue"/>
</View>""",
    }

    # ``LabelStudioMode(...)`` accepts both a raw value ("prompt") and an
    # existing member, and raises a clear ValueError for anything else
    # (including None) instead of failing later with an AttributeError
    # on ``.value``.
    mode = LabelStudioMode(mode)

    return _default_label_configs[mode.value], mode


class LabelStudioCallbackHandler(BaseCallbackHandler):
    """Label Studio callback handler.
    Provides the ability to send predictions to Label Studio
    for human evaluation, feedback and annotation.

    When an LLM or chat-model run finishes, one Label Studio task is imported
    per prompt. The model's generations are attached as a pre-filled
    ``textarea`` prediction on the project's first TextArea tag.

    Parameters:
        api_key: Label Studio API key. Falls back to the
            ``LABEL_STUDIO_API_KEY`` environment variable.
        url: Label Studio URL. Falls back to the ``LABEL_STUDIO_URL``
            environment variable, then to the SDK's default URL.
        project_id: Label Studio project ID. Falls back to the
            ``LABEL_STUDIO_PROJECT_ID`` environment variable.
        project_name: Label Studio project name. It is formatted with
            ``datetime.strftime`` to build the project title when no project
            ID is available.
        project_config: Label Studio project config (XML). Used only when a
            new project has to be created. When omitted, the default config
            for ``mode`` is used.
        mode: Label Studio mode ("prompt" or "chat")

    Raises:
        ImportError: if the ``label_studio_sdk`` package is not installed.
        ValueError: if no API key is available, or if the project's labeling
            config has no TextArea tag.

    Examples:
        >>> from langchain_community.llms import OpenAI
        >>> from langchain_community.callbacks import LabelStudioCallbackHandler
        >>> handler = LabelStudioCallbackHandler(
        ...             api_key='<your_key_here>',
        ...             url='http://localhost:8080',
        ...             project_name='LangChain-%Y-%m-%d',
        ...             mode='prompt'
        ... )
        >>> llm = OpenAI(callbacks=[handler])
        >>> llm.invoke('Tell me a story about a dog.')
    """

    DEFAULT_PROJECT_NAME: str = "LangChain-%Y-%m-%d"

    def __init__(
        self,
        api_key: Optional[str] = None,
        url: Optional[str] = None,
        project_id: Optional[int] = None,
        project_name: str = DEFAULT_PROJECT_NAME,
        project_config: Optional[str] = None,
        mode: Union[str, LabelStudioMode] = LabelStudioMode.PROMPT,
    ) -> None:
        super().__init__()

        # Import the LabelStudio SDK lazily so this module stays importable
        # without the optional dependency.
        try:
            import label_studio_sdk as ls
        except ImportError as e:
            raise ImportError(
                f"You're using {self.__class__.__name__} in your code,"
                f" but you don't have the LabelStudio SDK "
                f"Python package installed or upgraded to the latest version. "
                f"Please run `pip install -U label-studio-sdk`"
                f" before using this callback."
            ) from e

        # Resolve the API key: explicit argument first, then environment.
        if not api_key:
            env_api_key = os.getenv("LABEL_STUDIO_API_KEY")
            if not env_api_key:
                raise ValueError(
                    f"You're using {self.__class__.__name__} in your code,"
                    f" Label Studio API key is not provided. "
                    f"Please provide Label Studio API key: "
                    f"go to the Label Studio instance, navigate to "
                    f"Account & Settings -> Access Token and copy the key. "
                    f"Use the key as a parameter for the callback: "
                    f"{self.__class__.__name__}"
                    f"(api_key='<your_key_here>', ...) or "
                    f"set the environment variable LABEL_STUDIO_API_KEY=<your_key_here>"
                )
            api_key = env_api_key
        self.api_key = api_key

        # Resolve the URL: explicit argument, then environment, then the
        # SDK default (with a warning so the fallback is visible).
        if not url:
            env_url = os.getenv("LABEL_STUDIO_URL")
            if env_url:
                url = env_url
            else:
                warnings.warn(
                    f"Label Studio URL is not provided, "
                    f"using default URL: {ls.LABEL_STUDIO_DEFAULT_URL}. "
                    f"If you want to provide your own URL, use the parameter: "
                    f"{self.__class__.__name__}"
                    f"(url='<your_url_here>', ...) "
                    f"or set the environment variable LABEL_STUDIO_URL=<your_url_here>",
                    stacklevel=2,
                )
                url = ls.LABEL_STUDIO_DEFAULT_URL
        self.url = url

        # Maps str(run_id) to the prompts (and callback kwargs) recorded at
        # run start; entries are removed when the run ends or errors.
        self.payload: Dict[str, Dict] = {}

        self.ls_client = ls.Client(url=self.url, api_key=self.api_key)
        self.project_name = project_name
        if project_config:
            # A custom config has no associated mode.
            self.project_config = project_config
            self.mode = None
        else:
            self.project_config, self.mode = get_default_label_configs(mode)

        # Use an explicit/env project ID if given; otherwise reuse the first
        # project whose title matches the formatted name, or create one.
        self.project_id = project_id or os.getenv("LABEL_STUDIO_PROJECT_ID")
        if self.project_id is not None:
            self.ls_project = self.ls_client.get_project(int(self.project_id))
        else:
            project_title = datetime.today().strftime(self.project_name)
            existing_projects = self.ls_client.get_projects(title=project_title)
            if existing_projects:
                self.ls_project = existing_projects[0]
            else:
                self.ls_project = self.ls_client.create_project(
                    title=project_title, label_config=self.project_config
                )
            self.project_id = self.ls_project.id
        self.parsed_label_config = self.ls_project.parsed_label_config

        # Find the first TextArea tag
        # "from_name", "to_name", "value" will be used to create predictions
        self.from_name: Optional[str] = None
        self.to_name: Optional[str] = None
        self.value: Optional[str] = None
        self.input_type: Optional[str] = None
        for tag_name, tag_info in self.parsed_label_config.items():
            if tag_info["type"] == "TextArea":
                self.from_name = tag_name
                self.to_name = tag_info["to_name"][0]
                self.value = tag_info["inputs"][0]["value"]
                self.input_type = tag_info["inputs"][0]["type"]
                break
        if not self.from_name:
            error_message = (
                f'Label Studio project "{self.project_name}" '
                f"does not have a TextArea tag. "
                f"Please add a TextArea tag to the project."
            )
            if self.mode == LabelStudioMode.PROMPT:
                error_message += (
                    "\nHINT: go to project Settings -> "
                    "Labeling Interface -> Browse Templates"
                    ' and select "Generative AI -> '
                    'Supervised Language Model Fine-tuning" template.'
                )
            else:
                error_message += (
                    "\nHINT: go to project Settings -> "
                    "Labeling Interface -> Browse Templates"
                    " and check available templates under "
                    '"Generative AI" section.'
                )
            raise ValueError(error_message)

    def add_prompts_generations(
        self, run_id: str, generations: List[List[Generation]]
    ) -> None:
        """Import one Label Studio task per (prompt, generations) pair.

        Parameters:
            run_id: Key into ``self.payload`` whose stored prompts are paired
                positionally with ``generations``.
            generations: Per-prompt lists of generations. Each list's texts
                become a single ``textarea`` prediction.
        """
        run_payload = self.payload[run_id]
        prompts = run_payload["prompts"]
        # The model name (if the LLM reports one) is recorded as the
        # prediction's model version.
        model_version = (
            run_payload["kwargs"].get("invocation_params", {}).get("model_name")
        )
        tasks = [
            {
                "data": {
                    self.value: prompt,
                    "run_id": run_id,
                },
                "predictions": [
                    {
                        "result": [
                            {
                                "from_name": self.from_name,
                                "to_name": self.to_name,
                                "type": "textarea",
                                "value": {"text": [g.text for g in generation]},
                            }
                        ],
                        "model_version": model_version,
                    }
                ],
            }
            for prompt, generation in zip(prompts, generations)
        ]
        self.ls_project.import_tasks(tasks)

    def on_llm_start(
        self,
        serialized: Dict[str, Any],
        prompts: List[str],
        **kwargs: Any,
    ) -> None:
        """Save the prompts in memory when an LLM starts.

        Raises:
            ValueError: if the project's TextArea is not attached to a
                ``Text`` input, which plain-text prompts require.
        """
        if self.input_type != "Text":
            raise ValueError(
                f'\nLabel Studio project "{self.project_name}" '
                f"has an input type <{self.input_type}>. "
                f'To make it work with the mode="prompt", '
                f"the input type should be <Text>.\n"
                f"Read more here https://labelstud.io/tags/text"
            )
        run_id = str(kwargs["run_id"])
        self.payload[run_id] = {"prompts": prompts, "kwargs": kwargs}

    def _get_message_role(self, message: BaseMessage) -> str:
        """Get the role of the message.

        ``ChatMessage`` carries an explicit role; other message types are
        labeled by their class name.
        """
        if isinstance(message, ChatMessage):
            return message.role
        else:
            return message.__class__.__name__

    def on_chat_model_start(
        self,
        serialized: Dict[str, Any],
        messages: List[List[BaseMessage]],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> Any:
        """Save the chat messages in memory when a chat model starts.

        Each message list is stored as a dialogue of ``{"role", "content"}``
        dicts, matching the ``nameKey``/``textKey`` of the default chat
        config.

        Raises:
            ValueError: if the project's TextArea is not attached to a
                ``Paragraphs`` input, which chat dialogues require.
        """
        if self.input_type != "Paragraphs":
            raise ValueError(
                f'\nLabel Studio project "{self.project_name}" '
                f"has an input type <{self.input_type}>. "
                f'To make it work with the mode="chat", '
                f"the input type should be <Paragraphs>.\n"
                f"Read more here https://labelstud.io/tags/paragraphs"
            )

        prompts = [
            [
                {
                    "role": self._get_message_role(message),
                    "content": message.content,
                }
                for message in message_list
            ]
            for message_list in messages
        ]
        self.payload[str(run_id)] = {
            "prompts": prompts,
            "tags": tags,
            "metadata": metadata,
            "run_id": run_id,
            "parent_run_id": parent_run_id,
            "kwargs": kwargs,
        }

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Do nothing when a new token is generated."""
        pass

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Create a new Label Studio task for each prompt and generation."""
        run_id = str(kwargs["run_id"])
        if run_id not in self.payload:
            # Nothing was recorded for this run (e.g. the start callback
            # raised before storing prompts), so there is nothing to submit.
            return

        try:
            # Submit results to Label Studio
            self.add_prompts_generations(run_id, response.generations)
        finally:
            # Always drop the run's prompts, even if the import failed, so
            # the payload map does not grow without bound.
            self.payload.pop(run_id, None)

    def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
        """Discard the prompts stored for a run that ended with an error."""
        run_id = kwargs.get("run_id")
        if run_id is not None:
            self.payload.pop(str(run_id), None)

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        """Do nothing when a chain starts."""
        pass

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        """Do nothing when a chain ends."""
        pass

    def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
        """Do nothing when LLM chain outputs an error."""
        pass

    def on_tool_start(
        self,
        serialized: Dict[str, Any],
        input_str: str,
        **kwargs: Any,
    ) -> None:
        """Do nothing when tool starts."""
        pass

    def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
        """Do nothing when agent takes a specific action."""
        pass

    def on_tool_end(
        self,
        output: str,
        observation_prefix: Optional[str] = None,
        llm_prefix: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Do nothing when tool ends."""
        pass

    def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
        """Do nothing when tool outputs an error."""
        pass

    def on_text(self, text: str, **kwargs: Any) -> None:
        """Do nothing when arbitrary text is emitted."""
        pass

    def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
        """Do nothing when an agent finishes."""
        pass
