import json
import logging
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional

import requests
from langchain_core.callbacks import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models.llms import LLM
from langchain_core.outputs import GenerationChunk
from pydantic import Field

logger = logging.getLogger(__name__)


class TextGen(LLM):
    """Text generation models from WebUI.

    To use, you should have the text-generation-webui installed, a model loaded,
    and --api added as a command-line option.

    Suggested installation: use the one-click installer for your OS:
    https://github.com/oobabooga/text-generation-webui#one-click-installers

    The parameters below are taken from the text-generation-webui API example:
    https://github.com/oobabooga/text-generation-webui/blob/main/api-examples/api-example.py

    Example:
        .. code-block:: python

            from langchain_community.llms import TextGen
            llm = TextGen(model_url="http://localhost:8500")
    """

    model_url: str
    """The full URL to the textgen webui including http[s]://host:port """

    preset: Optional[str] = None
    """The preset to use in the textgen webui """

    max_new_tokens: Optional[int] = 250
    """The maximum number of tokens to generate."""

    do_sample: bool = Field(True, alias="do_sample")
    """Do sample"""

    temperature: Optional[float] = 1.3
    """Primary factor to control randomness of outputs. 0 = deterministic
    (only the most likely token is used). Higher value = more randomness."""

    top_p: Optional[float] = 0.1
    """If not set to 1, select tokens with probabilities adding up to less than this
    number. Higher value = higher range of possible random results."""

    typical_p: Optional[float] = 1
    """If not set to 1, select only tokens that are at least this much more likely to
    appear than random tokens, given the prior text."""

    epsilon_cutoff: Optional[float] = 0  # In units of 1e-4
    """Epsilon cutoff: probability floor below which tokens are excluded from
    sampling, in units of 1e-4. 0 disables it."""

    eta_cutoff: Optional[float] = 0  # In units of 1e-4
    """Eta cutoff: main parameter of eta sampling, in units of 1e-4.
    0 disables it."""

    repetition_penalty: Optional[float] = 1.18
    """Exponential penalty factor for repeating prior tokens. 1 means no penalty,
    higher value = less repetition, lower value = more repetition."""

    top_k: Optional[float] = 40
    """Similar to top_p, but select instead only the top_k most likely tokens.
    Higher value = higher range of possible random results."""

    min_length: Optional[int] = 0
    """Minimum generation length in tokens."""

    no_repeat_ngram_size: Optional[int] = 0
    """If not set to 0, specifies the length of token sets that are completely blocked
    from repeating at all. Higher values = blocks larger phrases,
    lower values = blocks words or letters from repeating.
    Only 0 or high values are a good idea in most cases."""

    num_beams: Optional[int] = 1
    """Number of beams for beam search. 1 disables beam search."""

    penalty_alpha: Optional[float] = 0
    """Penalty alpha for contrastive search."""

    length_penalty: Optional[float] = 1
    """Exponential length penalty used with beam-based generation."""

    early_stopping: bool = Field(False, alias="early_stopping")
    """Whether to stop beam search as soon as num_beams complete candidates
    are found."""

    seed: int = Field(-1, alias="seed")
    """Seed (-1 for random)"""

    add_bos_token: bool = Field(True, alias="add_bos_token")
    """Add the bos_token to the beginning of prompts.
    Disabling this can make the replies more creative."""

    truncation_length: Optional[int] = 2048
    """Truncate the prompt up to this length. The leftmost tokens are removed if
    the prompt exceeds this length. Most models require this to be at most 2048."""

    ban_eos_token: bool = Field(False, alias="ban_eos_token")
    """Ban the eos_token. Forces the model to never end the generation prematurely."""

    skip_special_tokens: bool = Field(True, alias="skip_special_tokens")
    """Skip special tokens. Some specific models need this unset."""

    stopping_strings: Optional[List[str]] = []
    """A list of strings to stop generation when encountered."""

    streaming: bool = False
    """Whether to stream the results, token by token."""

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling textgen."""
        return {
            "max_new_tokens": self.max_new_tokens,
            "do_sample": self.do_sample,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "typical_p": self.typical_p,
            "epsilon_cutoff": self.epsilon_cutoff,
            "eta_cutoff": self.eta_cutoff,
            "repetition_penalty": self.repetition_penalty,
            "top_k": self.top_k,
            "min_length": self.min_length,
            "no_repeat_ngram_size": self.no_repeat_ngram_size,
            "num_beams": self.num_beams,
            "penalty_alpha": self.penalty_alpha,
            "length_penalty": self.length_penalty,
            "early_stopping": self.early_stopping,
            "seed": self.seed,
            "add_bos_token": self.add_bos_token,
            "truncation_length": self.truncation_length,
            "ban_eos_token": self.ban_eos_token,
            "skip_special_tokens": self.skip_special_tokens,
            "stopping_strings": self.stopping_strings,
        }

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        return {**{"model_url": self.model_url}, **self._default_params}

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "textgen"

    def _get_parameters(self, stop: Optional[List[str]] = None) -> Dict[str, Any]:
        """
        Perform a sanity check and prepare parameters in the format textgen expects.

        Args:
            stop (Optional[List[str]]): List of stop sequences for textgen.

        Returns:
            Dictionary containing the combined parameters.
        """

        # Raise an error if stop sequences were supplied both on the instance
        # and in the call
        if self.stopping_strings and stop is not None:
            raise ValueError("`stop` found in both the input and default params.")

        if self.preset is None:
            params = self._default_params
        else:
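            # With a preset selected, the webui supplies the sampling parameters
            # itself, so only the preset name is sent.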
            params = {"preset": self.preset}

        # Use the configured stopping strings, or those passed at call time,
        # falling back to an empty list:
        params["stopping_strings"] = self.stopping_strings or stop or []

        return params

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call the textgen web API and return the output.

        Args:
            prompt: The prompt to use for generation.
            stop: A list of strings to stop generation when encountered.

        Returns:
            The generated text.

        Example:
            .. code-block:: python

                from langchain_community.llms import TextGen
                llm = TextGen(model_url="http://localhost:5000")
                llm.invoke("Write a story about llamas.")
        """
        if self.streaming:
            combined_text_output = ""
            for chunk in self._stream(
                prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
            ):
                combined_text_output += chunk.text
            result = combined_text_output

        else:
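            # Non-streaming request to the text-generation-webui blocking API.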
            url = f"{self.model_url}/api/v1/generate"
            params = self._get_parameters(stop)
            request = params.copy()
            request["prompt"] = prompt
            response = requests.post(url, json=request)

            if response.status_code == 200:
                result = response.json()["results"][0]["text"]
            else:
                print(f"ERROR: response: {response}")  # noqa: T201
                result = ""

        return result

    async def _acall(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call the textgen web API and return the output.

        Args:
            prompt: The prompt to use for generation.
            stop: A list of strings to stop generation when encountered.

        Returns:
            The generated text.

        Example:
            .. code-block:: python

                from langchain_community.llms import TextGen
                llm = TextGen(model_url="http://localhost:5000")
                llm.invoke("Write a story about llamas.")
        """
        if self.streaming:
            combined_text_output = ""
            async for chunk in self._astream(
                prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
            ):
                combined_text_output += chunk.text
            result = combined_text_output

        else:
            url = f"{self.model_url}/api/v1/generate"
            params = self._get_parameters(stop)
            request = params.copy()
            request["prompt"] = prompt
            response = requests.post(url, json=request)

            if response.status_code == 200:
                result = response.json()["results"][0]["text"]
            else:
                print(f"ERROR: response: {response}")  # noqa: T201
                result = ""

        return result

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        """Yields results objects as they are generated in real time.

        It also calls the callback manager's on_llm_new_token event with
        similar parameters to the OpenAI LLM class method of the same name.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            A generator representing the stream of tokens being generated.

        Yields:
            GenerationChunk objects containing a string token and metadata.
            See the text-generation-webui docs and the example below for more.

        Example:
            .. code-block:: python

                from langchain_community.llms import TextGen
                llm = TextGen(
                    model_url="ws://localhost:5005",
                    streaming=True,
                )
                for chunk in llm.stream(
                    "Ask 'Hi, how are you?' like a pirate:'",
                    stop=["'", "\n"],
                ):
                    print(chunk, end="", flush=True)  # noqa: T201

        """
        try:
            import websocket
        except ImportError:
            raise ImportError(
                "The `websocket-client` package is required for streaming. "
                "Install it with `pip install websocket-client`."
            )

        params = {**self._get_parameters(stop), **kwargs}

        url = f"{self.model_url}/api/v1/stream"

        request = params.copy()
        request["prompt"] = prompt

        websocket_client = websocket.WebSocket()

        websocket_client.connect(url)

        websocket_client.send(json.dumps(request))

        while True:
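            # The webui sends JSON events over the websocket: "text_stream"
            # events carry an incremental "text" field, and a final
            # "stream_end" event signals that generation is finished.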
            result = websocket_client.recv()
            result = json.loads(result)

            if result["event"] == "text_stream":  # type: ignore[call-overload, index]
                chunk = GenerationChunk(
                    text=result["text"],  # type: ignore[call-overload, index]
                    generation_info=None,
                )
                if run_manager:
                    run_manager.on_llm_new_token(token=chunk.text)
                yield chunk
            elif result["event"] == "stream_end":  # type: ignore[call-overload, index]
                websocket_client.close()
                return

    async def _astream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[GenerationChunk]:
        """Yields results objects as they are generated in real time.

        It also calls the callback manager's on_llm_new_token event with
        similar parameters to the OpenAI LLM class method of the same name.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            A generator representing the stream of tokens being generated.

        Yields:
            GenerationChunk objects containing a string token and metadata.
            See the text-generation-webui docs and the example below for more.

        Example:
            .. code-block:: python

                from langchain_community.llms import TextGen
                llm = TextGen(
                    model_url="ws://localhost:5005",
                    streaming=True,
                )
                async for chunk in llm.astream(
                    "Ask 'Hi, how are you?' like a pirate:'",
                    stop=["'", "\n"],
                ):
                    print(chunk, end="", flush=True)  # noqa: T201

        """
        try:
            import websocket
        except ImportError:
            raise ImportError(
                "The `websocket-client` package is required for streaming. "
                "Install it with `pip install websocket-client`."
            )

        params = {**self._get_parameters(stop), **kwargs}

        url = f"{self.model_url}/api/v1/stream"

        request = params.copy()
        request["prompt"] = prompt

        websocket_client = websocket.WebSocket()

        websocket_client.connect(url)

        websocket_client.send(json.dumps(request))

        while True:
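            # Note: websocket-client is synchronous, so recv() blocks the event
            # loop while waiting for the next event from the webui.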
            result = websocket_client.recv()
            result = json.loads(result)

            if result["event"] == "text_stream":  # type: ignore[call-overload, index]
                chunk = GenerationChunk(
                    text=result["text"],  # type: ignore[call-overload, index]
                    generation_info=None,
                )
                if run_manager:
                    await run_manager.on_llm_new_token(token=chunk.text)
                yield chunk
            elif result["event"] == "stream_end":  # type: ignore[call-overload, index]
                websocket_client.close()
                return
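

if __name__ == "__main__":
    # Illustrative sketch (not part of the library surface): exercises TextGen
    # against a local text-generation-webui instance. Assumes the server was
    # started with --api, exposing the blocking API at http://localhost:5000
    # and the streaming websocket API at ws://localhost:5005; adjust the URLs
    # for your setup.

    # Blocking (non-streaming) generation.
    llm = TextGen(model_url="http://localhost:5000", max_new_tokens=64)
    print(llm.invoke("Write a haiku about llamas."))  # noqa: T201

    # Streaming generation over the websocket endpoint.
    stream_llm = TextGen(model_url="ws://localhost:5005", streaming=True)
    for chunk in stream_llm.stream("Write a haiku about llamas."):
        print(chunk, end="", flush=True)  # noqa: T201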
