from __future__ import annotations

import json
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Dict, Iterator, List, Mapping, Optional

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.outputs import GenerationChunk
from langchain_core.utils import pre_init
from pydantic import BaseModel, ConfigDict, Field

from langchain_community.llms.utils import enforce_stop_tokens

CUSTOM_ENDPOINT_PREFIX = "ocid1.generativeaiendpoint"


class Provider(ABC):
    @property
    @abstractmethod
    def stop_sequence_key(self) -> str: ...

    @abstractmethod
    def completion_response_to_text(self, response: Any) -> str: ...


class CohereProvider(Provider):
    stop_sequence_key: str = "stop_sequences"

    def __init__(self) -> None:
        from oci.generative_ai_inference import models

        self.llm_inference_request = models.CohereLlmInferenceRequest

    def completion_response_to_text(self, response: Any) -> str:
        return response.data.inference_response.generated_texts[0].text


class MetaProvider(Provider):
    stop_sequence_key: str = "stop"

    def __init__(self) -> None:
        from oci.generative_ai_inference import models

        self.llm_inference_request = models.LlamaLlmInferenceRequest

    def completion_response_to_text(self, response: Any) -> str:
        return response.data.inference_response.choices[0].text


class OCIAuthType(Enum):
    """OCI authentication types as enumerator."""

    API_KEY = 1
    SECURITY_TOKEN = 2
    INSTANCE_PRINCIPAL = 3
    RESOURCE_PRINCIPAL = 4


class OCIGenAIBase(BaseModel, ABC):
    """Base class for OCI GenAI models"""

    client: Any = Field(default=None, exclude=True)  #: :meta private:

    auth_type: Optional[str] = "API_KEY"
    """Authentication type, could be 
    
    API_KEY, 
    SECURITY_TOKEN, 
    INSTANCE_PRINCIPAL, 
    RESOURCE_PRINCIPAL

    If not specified, API_KEY will be used
    """

    auth_profile: Optional[str] = "DEFAULT"
    """The name of the profile in ~/.oci/config
    If not specified , DEFAULT will be used 
    """

    model_id: Optional[str] = None
    """Id of the model to call, e.g., cohere.command"""

    provider: Optional[str] = None
    """Provider name of the model. Default to None, 
    will try to be derived from the model_id
    otherwise, requires user input
    """

    model_kwargs: Optional[Dict] = None
    """Keyword arguments to pass to the model"""

    service_endpoint: Optional[str] = None
    """service endpoint url"""

    compartment_id: Optional[str] = None
    """OCID of compartment"""

    is_stream: bool = False
    """Whether to stream back partial progress"""

    model_config = ConfigDict(
        extra="forbid", arbitrary_types_allowed=True, protected_namespaces=()
    )

    @pre_init
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that OCI config and python package exists in environment."""

        # Skip creating new client if passed in constructor
        if values["client"] is not None:
            return values

        try:
            import oci

            client_kwargs = {
                "config": {},
                "signer": None,
                "service_endpoint": values["service_endpoint"],
                "retry_strategy": oci.retry.DEFAULT_RETRY_STRATEGY,
                "timeout": (10, 240),  # default timeout config for OCI Gen AI service
            }

            if values["auth_type"] == OCIAuthType(1).name:
                client_kwargs["config"] = oci.config.from_file(
                    profile_name=values["auth_profile"]
                )
                client_kwargs.pop("signer", None)
            elif values["auth_type"] == OCIAuthType(2).name:

                def make_security_token_signer(oci_config):  # type: ignore[no-untyped-def]
                    pk = oci.signer.load_private_key_from_file(
                        oci_config.get("key_file"), None
                    )
                    with open(
                        oci_config.get("security_token_file"), encoding="utf-8"
                    ) as f:
                        st_string = f.read()
                    return oci.auth.signers.SecurityTokenSigner(st_string, pk)

                client_kwargs["config"] = oci.config.from_file(
                    profile_name=values["auth_profile"]
                )
                client_kwargs["signer"] = make_security_token_signer(
                    oci_config=client_kwargs["config"]
                )
            elif values["auth_type"] == OCIAuthType(3).name:
                client_kwargs["signer"] = (
                    oci.auth.signers.InstancePrincipalsSecurityTokenSigner()
                )
            elif values["auth_type"] == OCIAuthType(4).name:
                client_kwargs["signer"] = (
                    oci.auth.signers.get_resource_principals_signer()
                )
            else:
                raise ValueError(
                    "Please provide valid value to auth_type, "
                    f"{values['auth_type']} is not valid."
                )

            values["client"] = oci.generative_ai_inference.GenerativeAiInferenceClient(
                **client_kwargs
            )

        except ImportError as ex:
            raise ModuleNotFoundError(
                "Could not import oci python package. "
                "Please make sure you have the oci package installed."
            ) from ex
        except Exception as e:
            raise ValueError(
                """Could not authenticate with OCI client. 
                Please check if ~/.oci/config exists.
                If INSTANCE_PRINCIPAL or RESOURCE_PRINCIPAL is used, 
                please check the specified
                auth_profile and auth_type are valid.""",
                e,
            ) from e

        return values

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        _model_kwargs = self.model_kwargs or {}
        return {
            **{"model_kwargs": _model_kwargs},
        }

    def _get_provider(self, provider_map: Mapping[str, Any]) -> Any:
        if self.provider is not None:
            provider = self.provider
        else:
            if self.model_id is None:
                raise ValueError(
                    "model_id is required to derive the provider, "
                    "please provide the provider explicitly or specify "
                    "the model_id to derive the provider."
                )
            provider = self.model_id.split(".")[0].lower()

        if provider not in provider_map:
            raise ValueError(
                f"Invalid provider derived from model_id: {self.model_id} "
                "Please explicitly pass in the supported provider "
                "when using custom endpoint"
            )
        return provider_map[provider]


class OCIGenAI(LLM, OCIGenAIBase):
    """OCI large language models.

    To authenticate, the OCI client uses the methods described in
    https://docs.oracle.com/en-us/iaas/Content/API/Concepts/sdk_authentication_methods.htm

    The authentifcation method is passed through auth_type and should be one of:
    API_KEY (default), SECURITY_TOKEN, INSTANCE_PRINCIPAL, RESOURCE_PRINCIPAL

    Make sure you have the required policies (profile/roles) to
    access the OCI Generative AI service.
    If a specific config profile is used, you must pass
    the name of the profile (from ~/.oci/config) through auth_profile.

    To use, you must provide the compartment id
    along with the endpoint url, and model id
    as named parameters to the constructor.

    Example:
        .. code-block:: python

            from langchain_community.llms import OCIGenAI

            llm = OCIGenAI(
                    model_id="MY_MODEL_ID",
                    service_endpoint="https://inference.generativeai.us-chicago-1.oci.oraclecloud.com",
                    compartment_id="MY_OCID"
                )
    """

    model_config = ConfigDict(
        extra="forbid",
        arbitrary_types_allowed=True,
    )

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "oci_generative_ai_completion"

    @property
    def _provider_map(self) -> Mapping[str, Any]:
        """Get the provider map"""
        return {
            "cohere": CohereProvider(),
            "meta": MetaProvider(),
        }

    @property
    def _provider(self) -> Any:
        """Get the internal provider object"""
        return self._get_provider(provider_map=self._provider_map)

    def _prepare_invocation_object(
        self, prompt: str, stop: Optional[List[str]], kwargs: Dict[str, Any]
    ) -> Dict[str, Any]:
        from oci.generative_ai_inference import models

        _model_kwargs = self.model_kwargs or {}
        if stop is not None:
            _model_kwargs[self._provider.stop_sequence_key] = stop

        if self.model_id is None:
            raise ValueError(
                "model_id is required to call the model, "
                "please provide the model_id."
            )

        if self.model_id.startswith(CUSTOM_ENDPOINT_PREFIX):
            serving_mode = models.DedicatedServingMode(endpoint_id=self.model_id)
        else:
            serving_mode = models.OnDemandServingMode(model_id=self.model_id)

        inference_params = {**_model_kwargs, **kwargs}
        inference_params["prompt"] = prompt
        inference_params["is_stream"] = self.is_stream

        invocation_obj = models.GenerateTextDetails(
            compartment_id=self.compartment_id,
            serving_mode=serving_mode,
            inference_request=self._provider.llm_inference_request(**inference_params),
        )

        return invocation_obj

    def _process_response(self, response: Any, stop: Optional[List[str]]) -> str:
        text = self._provider.completion_response_to_text(response)

        if stop is not None:
            text = enforce_stop_tokens(text, stop)

        return text

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call out to OCIGenAI generate endpoint.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            The string generated by the model.

        Example:
            .. code-block:: python

               response = llm.invoke("Tell me a joke.")
        """
        if self.is_stream:
            text = ""
            for chunk in self._stream(prompt, stop, run_manager, **kwargs):
                text += chunk.text
            if stop is not None:
                text = enforce_stop_tokens(text, stop)
            return text

        invocation_obj = self._prepare_invocation_object(prompt, stop, kwargs)
        response = self.client.generate_text(invocation_obj)
        return self._process_response(response, stop)

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        """Stream OCIGenAI LLM on given prompt.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            An iterator of GenerationChunks.

        Example:
            .. code-block:: python

            response = llm.stream("Tell me a joke.")
        """

        self.is_stream = True
        invocation_obj = self._prepare_invocation_object(prompt, stop, kwargs)
        response = self.client.generate_text(invocation_obj)

        for event in response.data.events():
            json_load = json.loads(event.data)
            if "text" in json_load:
                event_data_text = json_load["text"]
            else:
                event_data_text = ""
            chunk = GenerationChunk(text=event_data_text)
            if run_manager:
                run_manager.on_llm_new_token(chunk.text, chunk=chunk)
            yield chunk
