from __future__ import annotations

from typing import Any, AsyncIterator, Dict, Iterator, List, Optional

from langchain_core.callbacks import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import (
    BaseChatModel,
    agenerate_from_stream,
    generate_from_stream,
)
from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    ChatMessage,
    HumanMessage,
    SystemMessage,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult

from langchain_community.llms.friendli import BaseFriendli


def get_role(message: BaseMessage) -> str:
    """Map a LangChain message object to a chat role name.

    Args:
        message (BaseMessage): The message whose role is wanted.

    Raises:
        ValueError: If the message is not an instance of any supported
            message class.

    Returns:
        str: ``"user"`` for ``ChatMessage`` and ``HumanMessage``,
        ``"assistant"`` for ``AIMessage``, ``"system"`` for ``SystemMessage``.
    """
    # Ordered (message class(es), role) pairs; the first matching entry wins.
    role_table = (
        ((ChatMessage, HumanMessage), "user"),
        (AIMessage, "assistant"),
        (SystemMessage, "system"),
    )
    for message_types, role in role_table:
        if isinstance(message, message_types):
            return role
    raise ValueError(f"Got unknown type {message}")


def get_chat_request(messages: List[BaseMessage]) -> Dict[str, Any]:
    """Build the message payload for a Friendli chat API request.

    Each message becomes a ``{"role": ..., "content": ...}`` entry, with the
    role resolved by ``get_role`` and the content taken from the message.

    Args:
        messages (List[BaseMessage]): The conversation so far, in order.

    Returns:
        Dict[str, Any]: A dict with a single ``"messages"`` key holding the
        converted entries.
    """
    entries = [{"role": get_role(msg), "content": msg.content} for msg in messages]
    return {"messages": entries}


class ChatFriendli(BaseChatModel, BaseFriendli):
    """Friendli LLM for chat.

    ``friendli-client`` package should be installed with `pip install friendli-client`.
    You must set ``FRIENDLI_TOKEN`` environment variable or provide the value of your
    personal access token for the ``friendli_token`` argument.

    Example:
        .. code-block:: python

            from langchain_community.chat_models import ChatFriendli

            chat = ChatFriendli(
                model="llama-2-13b-chat", friendli_token="YOUR FRIENDLI TOKEN"
            )
            chat.invoke("What is generative AI?")
    """

    # Name of the model passed as ``model=`` on every chat completion request.
    model: str = "llama-2-13b-chat"

    @property
    def lc_secrets(self) -> Dict[str, str]:
        """Return the mapping ``{"friendli_token": "FRIENDLI_TOKEN"}``."""
        return {"friendli_token": "FRIENDLI_TOKEN"}

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling Friendli completions API.

        A new dict is built on every access, so callers may mutate the result.
        The attributes read here are not declared in this class (presumably
        they are inherited from ``BaseFriendli``).
        """
        return {
            "frequency_penalty": self.frequency_penalty,
            "presence_penalty": self.presence_penalty,
            "max_tokens": self.max_tokens,
            "stop": self.stop,
            "temperature": self.temperature,
            "top_p": self.top_p,
        }

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters: the model name plus the defaults."""
        return {"model": self.model, **self._default_params}

    @property
    def _llm_type(self) -> str:
        """Return the type identifier string of this chat model."""
        return "friendli-chat"

    def _get_invocation_params(
        self, stop: Optional[List[str]] = None, **kwargs: Any
    ) -> Dict[str, Any]:
        """Get the parameters used to invoke the model.

        Args:
            stop (Optional[List[str]]): Per-call stop sequences.
            **kwargs: Extra per-call parameters; they override the defaults
                on key collision.

        Raises:
            ValueError: If stop sequences are given both on the instance
                (``self.stop``) and for this call.

        Returns:
            Dict[str, Any]: Default parameters merged with ``stop`` and
            ``kwargs``.
        """
        params = self._default_params
        if self.stop is not None and stop is not None:
            raise ValueError("`stop` found in both the input and default params.")
        # At most one of the two is set here; prefer the instance-level value.
        params["stop"] = self.stop if self.stop is not None else stop
        return {**params, **kwargs}

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Stream the chat completion synchronously.

        Args:
            messages (List[BaseMessage]): The conversation so far.
            stop (Optional[List[str]]): Per-call stop sequences.
            run_manager (Optional[CallbackManagerForLLMRun]): If given, it is
                notified of every new token.
            **kwargs: Extra parameters forwarded to the completion request.

        Yields:
            ChatGenerationChunk: One chunk per non-empty content delta.
        """
        params = self._get_invocation_params(stop=stop, **kwargs)
        stream = self.client.chat.completions.create(
            **get_chat_request(messages), stream=True, model=self.model, **params
        )
        for chunk in stream:
            delta = chunk.choices[0].delta.content
            # Deltas with empty or missing content are skipped entirely.
            if delta:
                if run_manager:
                    run_manager.on_llm_new_token(delta)
                yield ChatGenerationChunk(message=AIMessageChunk(content=delta))

    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        """Stream the chat completion asynchronously.

        Async counterpart of ``_stream``; it uses ``self.async_client`` and
        awaits the callback manager.

        Args:
            messages (List[BaseMessage]): The conversation so far.
            stop (Optional[List[str]]): Per-call stop sequences.
            run_manager (Optional[AsyncCallbackManagerForLLMRun]): If given,
                it is notified of every new token.
            **kwargs: Extra parameters forwarded to the completion request.

        Yields:
            ChatGenerationChunk: One chunk per non-empty content delta.
        """
        params = self._get_invocation_params(stop=stop, **kwargs)
        stream = await self.async_client.chat.completions.create(
            **get_chat_request(messages), stream=True, model=self.model, **params
        )
        async for chunk in stream:
            delta = chunk.choices[0].delta.content
            # Deltas with empty or missing content are skipped entirely.
            if delta:
                if run_manager:
                    await run_manager.on_llm_new_token(delta)
                yield ChatGenerationChunk(message=AIMessageChunk(content=delta))

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Generate a chat completion synchronously.

        When ``self.streaming`` is truthy the result is assembled from
        ``_stream``; otherwise a single non-streaming request is made.

        Args:
            messages (List[BaseMessage]): The conversation so far.
            stop (Optional[List[str]]): Per-call stop sequences.
            run_manager (Optional[CallbackManagerForLLMRun]): Callback manager;
                only used on the streaming path.
            **kwargs: Extra parameters forwarded to the completion request.

        Returns:
            ChatResult: A result holding one generation with the reply.
        """
        if self.streaming:
            stream_iter = self._stream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return generate_from_stream(stream_iter)

        params = self._get_invocation_params(stop=stop, **kwargs)
        # Build the payload with the same helper as the streaming path so the
        # two cannot drift apart.
        response = self.client.chat.completions.create(
            **get_chat_request(messages), stream=False, model=self.model, **params
        )

        message = AIMessage(content=response.choices[0].message.content)
        return ChatResult(generations=[ChatGeneration(message=message)])

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Generate a chat completion asynchronously.

        Async counterpart of ``_generate``; it uses ``self.async_client`` and
        ``_astream``.

        Args:
            messages (List[BaseMessage]): The conversation so far.
            stop (Optional[List[str]]): Per-call stop sequences.
            run_manager (Optional[AsyncCallbackManagerForLLMRun]): Callback
                manager; only used on the streaming path.
            **kwargs: Extra parameters forwarded to the completion request.

        Returns:
            ChatResult: A result holding one generation with the reply.
        """
        if self.streaming:
            stream_iter = self._astream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return await agenerate_from_stream(stream_iter)

        params = self._get_invocation_params(stop=stop, **kwargs)
        # Build the payload with the same helper as the streaming path so the
        # two cannot drift apart.
        response = await self.async_client.chat.completions.create(
            **get_chat_request(messages), stream=False, model=self.model, **params
        )

        message = AIMessage(content=response.choices[0].message.content)
        return ChatResult(generations=[ChatGeneration(message=message)])
