"""Wrapper around Anyscale Endpoint"""

from typing import (
    Any,
    Dict,
    List,
    Mapping,
    Optional,
    Set,
)

from langchain_core.callbacks import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.outputs import Generation, GenerationChunk, LLMResult
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init
from pydantic import Field, SecretStr

from langchain_community.llms.openai import (
    BaseOpenAI,
    acompletion_with_retry,
    completion_with_retry,
)
from langchain_community.utils.openai import is_openai_v1

# Base URL used for ``anyscale_api_base`` when no value is supplied.
DEFAULT_BASE_URL = "https://api.endpoints.anyscale.com/v1"
# Model identifier used for ``model_name`` when no value is supplied.
DEFAULT_MODEL = "mistralai/Mixtral-8x7B-Instruct-v0.1"


def update_token_usage(
    keys: Set[str], response: Dict[str, Any], token_usage: Dict[str, Any]
) -> None:
    """Accumulate token counts from a response into ``token_usage`` in place.

    Only entries of ``response["usage"]`` whose name is listed in ``keys`` are
    considered. A counter seen for the first time is stored as-is; a counter
    that already exists in ``token_usage`` is incremented.

    Args:
        keys: Names of the usage counters to track.
        response: Response dictionary, expected to carry a ``"usage"`` mapping.
        token_usage: Running totals, mutated in place.
    """
    usage = response.get("usage")
    if not usage:
        # No usage section (missing, ``None`` or empty): nothing to accumulate.
        return
    for counter in keys.intersection(usage):
        if counter in token_usage:
            token_usage[counter] += usage[counter]
        else:
            token_usage[counter] = usage[counter]


def create_llm_result(
    choices: Any, prompts: List[str], token_usage: Dict[str, int], model_name: str
) -> LLMResult:
    """Build an ``LLMResult`` with exactly one generation per prompt.

    The choice at position ``i`` of ``choices`` is paired with the prompt at
    position ``i``. The generation text is read from
    ``choice["message"]["content"]``; ``finish_reason`` and ``logprobs`` are
    copied into the generation info and are ``None`` when the choice lacks them.

    Args:
        choices: Indexable collection of choice mappings.
        prompts: Prompts that were sent; only their count is used.
        token_usage: Token counters to report in ``llm_output``.
        model_name: Model name to report in ``llm_output``.

    Returns:
        The assembled ``LLMResult``.
    """

    def _to_generation(entry: Any) -> Generation:
        # Read the text first so a malformed choice fails on the same lookup.
        content = entry["message"]["content"]
        details = {
            "finish_reason": entry.get("finish_reason"),
            "logprobs": entry.get("logprobs"),
        }
        return Generation(text=content, generation_info=details)

    # Look choices up by position (not zip) so that having fewer choices than
    # prompts still raises instead of silently truncating.
    per_prompt = [
        [_to_generation(choices[position])] for position in range(len(prompts))
    ]
    return LLMResult(
        generations=per_prompt,
        llm_output={"token_usage": token_usage, "model_name": model_name},
    )


class Anyscale(BaseOpenAI):  # type: ignore[override]
    """Anyscale large language models.

    To use, you should have the environment variable ``ANYSCALE_API_KEY`` set with
    your Anyscale Endpoint API key, or pass it as a named parameter to the
    constructor. To use with Anyscale Private Endpoint, please also set
    ``ANYSCALE_API_BASE``.

    Example:
        .. code-block:: python

            from langchain_community.llms import Anyscale

            anyscalellm = Anyscale(anyscale_api_key="ANYSCALE_API_KEY")

            # To leverage Ray for parallel processing
            @ray.remote(num_cpus=1)
            def send_query(llm, text):
                resp = llm.invoke(text)
                return resp

            futures = [send_query.remote(anyscalellm, text) for text in texts]
            results = ray.get(futures)
    """

    # Base URL of the Anyscale endpoint the OpenAI client talks to.
    anyscale_api_base: str = Field(default=DEFAULT_BASE_URL)
    # API key for the Anyscale endpoint; kept as a secret string.
    anyscale_api_key: SecretStr = Field(default=SecretStr(""))
    # Name of the model to query.
    model_name: str = Field(default=DEFAULT_MODEL)

    prefix_messages: List = Field(default_factory=list)

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Return ``False``: this class opts out of LangChain serialization."""
        return False

    @pre_init
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment.

        Resolves ``anyscale_api_base``, ``anyscale_api_key`` and ``model_name``
        from ``values`` or from the ``ANYSCALE_API_BASE``, ``ANYSCALE_API_KEY``
        and ``MODEL_NAME`` environment variables, then sets up the OpenAI
        client(s) pointing at the Anyscale endpoint.

        Args:
            values: Field values being validated.

        Returns:
            The updated ``values``.

        Raises:
            ImportError: If the ``openai`` package cannot be imported.
            ValueError: If streaming is combined with ``n > 1`` or
                ``best_of > 1``.
        """
        values["anyscale_api_base"] = get_from_dict_or_env(
            values,
            "anyscale_api_base",
            "ANYSCALE_API_BASE",
            default=DEFAULT_BASE_URL,
        )
        values["anyscale_api_key"] = convert_to_secret_str(
            get_from_dict_or_env(values, "anyscale_api_key", "ANYSCALE_API_KEY")
        )
        values["model_name"] = get_from_dict_or_env(
            values,
            "model_name",
            "MODEL_NAME",
            default=DEFAULT_MODEL,
        )

        # Only the import itself is guarded, so that an unrelated ImportError
        # raised while building the clients is not reported as "not installed".
        try:
            import openai
        except ImportError as exc:
            raise ImportError(
                "Could not import openai python package. "
                "Please install it with `pip install openai`."
            ) from exc

        if is_openai_v1():
            client_params = {
                "api_key": values["anyscale_api_key"].get_secret_value(),
                "base_url": values["anyscale_api_base"],
                # To do: future support
                # "organization": values["openai_organization"],
                # "timeout": values["request_timeout"],
                # "max_retries": values["max_retries"],
                # "default_headers": values["default_headers"],
                # "default_query": values["default_query"],
                # "http_client": values["http_client"],
            }
            # Respect clients that were passed in explicitly.
            if not values.get("client"):
                values["client"] = openai.OpenAI(**client_params).completions
            if not values.get("async_client"):
                values["async_client"] = openai.AsyncOpenAI(
                    **client_params
                ).completions
        else:
            values["openai_api_base"] = values["anyscale_api_base"]
            values["openai_api_key"] = values["anyscale_api_key"].get_secret_value()
            values["client"] = openai.Completion  # type: ignore[attr-defined]

        if values["streaming"] and values["n"] > 1:
            raise ValueError("Cannot stream results when n > 1.")
        if values["streaming"] and values["best_of"] > 1:
            raise ValueError("Cannot stream results when best_of > 1.")

        return values

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        # Parent entries come last, so they win on a key collision.
        return {
            "model_name": self.model_name,
            **super()._identifying_params,
        }

    @property
    def _invocation_params(self) -> Dict[str, Any]:
        """Get the parameters used to invoke the model."""
        openai_creds: Dict[str, Any] = {
            "model": self.model_name,
        }
        if not is_openai_v1():
            # Outside the v1 code path the credentials are passed per call.
            openai_creds.update(
                {
                    "api_key": self.anyscale_api_key.get_secret_value(),
                    "api_base": self.anyscale_api_base,
                }
            )
        return {**openai_creds, **super()._invocation_params}

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "Anyscale LLM"

    def _generate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Call out to the Anyscale endpoint with k unique prompts.

        Args:
            prompts: The prompts to pass into the model.
            stop: Optional list of stop words to use when generating.
            run_manager: Optional callback manager for this run.
            **kwargs: Extra invocation parameters; they override the defaults.

        Returns:
            The full LLM output.

        Raises:
            ValueError: If streaming is enabled and a batch holds more than one
                prompt.

        Example:
            .. code-block:: python

                response = anyscalellm.generate(["Tell me a joke."])
        """
        # TODO: write a unit test for this
        params = self._invocation_params
        params = {**params, **kwargs}
        sub_prompts = self.get_sub_prompts(params, prompts, stop)
        choices = []
        token_usage: Dict[str, int] = {}
        # Get the token usage from the response.
        # Includes prompt, completion, and total tokens used.
        _keys = {"completion_tokens", "prompt_tokens", "total_tokens"}
        system_fingerprint: Optional[str] = None
        for _prompts in sub_prompts:
            if self.streaming:
                if len(_prompts) > 1:
                    raise ValueError("Cannot stream results with multiple prompts.")

                generation: Optional[GenerationChunk] = None
                for chunk in self._stream(_prompts[0], stop, run_manager, **kwargs):
                    if generation is None:
                        generation = chunk
                    else:
                        generation += chunk
                assert generation is not None
                info = generation.generation_info or {}
                choices.append(
                    {
                        "text": generation.text,
                        "finish_reason": info.get("finish_reason"),
                        "logprobs": info.get("logprobs"),
                    }
                )
            else:
                # Unlike BaseOpenAI._generate, each request carries a single
                # prompt string rather than the whole batch, so every prompt of
                # the batch is sent on its own. Sending only the first one would
                # leave the remaining prompts without any choices.
                for _prompt in _prompts:
                    response = completion_with_retry(
                        self,
                        prompt=_prompt,
                        run_manager=run_manager,
                        **params,
                    )
                    if not isinstance(response, dict):
                        # V1 client returns the response in an PyDantic object
                        # instead of dict. For the transition period, we deep
                        # convert it to dict.
                        response = response.dict()

                    choices.extend(response["choices"])
                    update_token_usage(_keys, response, token_usage)
                    if not system_fingerprint:
                        system_fingerprint = response.get("system_fingerprint")
        return self.create_llm_result(
            choices,
            prompts,
            params,
            token_usage,
            system_fingerprint=system_fingerprint,
        )

    async def _agenerate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Call out to the Anyscale endpoint async with k unique prompts.

        Args:
            prompts: The prompts to pass into the model.
            stop: Optional list of stop words to use when generating.
            run_manager: Optional async callback manager for this run.
            **kwargs: Extra invocation parameters; they override the defaults.

        Returns:
            The full LLM output.

        Raises:
            ValueError: If streaming is enabled and a batch holds more than one
                prompt.
        """
        params = self._invocation_params
        params = {**params, **kwargs}
        sub_prompts = self.get_sub_prompts(params, prompts, stop)
        choices = []
        token_usage: Dict[str, int] = {}
        # Get the token usage from the response.
        # Includes prompt, completion, and total tokens used.
        _keys = {"completion_tokens", "prompt_tokens", "total_tokens"}
        system_fingerprint: Optional[str] = None
        for _prompts in sub_prompts:
            if self.streaming:
                if len(_prompts) > 1:
                    raise ValueError("Cannot stream results with multiple prompts.")

                generation: Optional[GenerationChunk] = None
                async for chunk in self._astream(
                    _prompts[0], stop, run_manager, **kwargs
                ):
                    if generation is None:
                        generation = chunk
                    else:
                        generation += chunk
                assert generation is not None
                info = generation.generation_info or {}
                choices.append(
                    {
                        "text": generation.text,
                        "finish_reason": info.get("finish_reason"),
                        "logprobs": info.get("logprobs"),
                    }
                )
            else:
                # Unlike BaseOpenAI._agenerate, each request carries a single
                # prompt string rather than the whole batch, so every prompt of
                # the batch is sent on its own. Sending only the first one would
                # leave the remaining prompts without any choices.
                for _prompt in _prompts:
                    response = await acompletion_with_retry(
                        self,
                        prompt=_prompt,
                        run_manager=run_manager,
                        **params,
                    )
                    if not isinstance(response, dict):
                        # V1 client returns a pydantic object; convert to dict.
                        response = response.dict()
                    choices.extend(response["choices"])
                    update_token_usage(_keys, response, token_usage)
                    # Mirror the sync path: keep the first fingerprint reported.
                    if not system_fingerprint:
                        system_fingerprint = response.get("system_fingerprint")
        return self.create_llm_result(
            choices,
            prompts,
            params,
            token_usage,
            system_fingerprint=system_fingerprint,
        )
