import json
import logging
from operator import itemgetter
from typing import (
    Any,
    Callable,
    Dict,
    List,
    Literal,
    Optional,
    Sequence,
    Type,
    Union,
    cast,
)
from uuid import uuid4

import requests
from langchain.schema import AIMessage, ChatGeneration, ChatResult, HumanMessage
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import (
    AIMessageChunk,
    BaseMessage,
    SystemMessage,
    ToolCall,
    ToolMessage,
)
from langchain_core.messages.tool import tool_call
from langchain_core.output_parsers import (
    JsonOutputParser,
    PydanticOutputParser,
)
from langchain_core.output_parsers.base import OutputParserLike
from langchain_core.output_parsers.openai_tools import (
    JsonOutputKeyToolsParser,
    PydanticToolsParser,
)
from langchain_core.runnables import Runnable, RunnablePassthrough
from langchain_core.runnables.base import RunnableMap
from langchain_core.tools import BaseTool
from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain_core.utils.pydantic import is_basemodel_subclass
from pydantic import BaseModel, Field

# Logging setup, executed once when this module is imported.
# NOTE(review): logging.basicConfig configures the *root* logger (INFO level,
# timestamped format) for the whole process and does nothing if the root
# logger already has handlers. Library modules usually leave this to the host
# application; kept here to preserve existing behavior.
logging.basicConfig(
    datefmt="%Y-%m-%d %H:%M:%S",
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
_logger = logging.getLogger(__name__)  # module-scoped logger used throughout this file


def _is_pydantic_class(obj: Any) -> bool:
    """Tell whether *obj* is a class (not an instance) derived from a pydantic model.

    Non-class objects short-circuit to ``False`` without consulting
    ``is_basemodel_subclass``.
    """
    if not isinstance(obj, type):
        return False
    return is_basemodel_subclass(obj)


def _content_to_text(content: Any) -> str:
    """Flatten LangChain message content into a plain string.

    ``str`` content is returned unchanged. ``list`` content (multi-part
    messages) is reduced to its text parts: bare strings and
    ``{"type": "text", "text": ...}`` dicts are concatenated in order. All
    other parts (e.g. image blocks) are dropped because the payload built in
    this module is text-only. Any other value yields ``""``.
    """
    if isinstance(content, str):
        return content
    if not isinstance(content, list):
        return ""
    texts: List[str] = []
    for part in content:
        if isinstance(part, str):
            texts.append(part)
        elif isinstance(part, dict) and part.get("type") == "text":
            texts.append(str(part.get("text", "")))
    return "".join(texts)


def _convert_messages_to_cloudflare_messages(
    messages: List[BaseMessage],
) -> List[Dict[str, Any]]:
    """Convert LangChain messages to Cloudflare Workers AI format.

    Every message becomes a dict with ``role`` and ``content`` keys, where
    ``content`` is the message's text (text parts of multi-part content are
    concatenated). Roles are mapped as follows:

    * ``HumanMessage`` -> ``"user"``
    * ``AIMessage`` -> ``"assistant"``, plus a ``tool_calls`` list of
      ``{"name", "arguments"}`` dicts when the message has tool calls
    * ``SystemMessage`` -> ``"system"``
    * ``ToolMessage`` -> ``"tool"``, plus its ``tool_call_id``
    * any other message -> its string ``role`` attribute if it has one,
      otherwise ``""``

    Args:
        messages: Conversation to convert, in order.

    Returns:
        One dict per input message, in the same order.
    """
    cloudflare_messages = []
    msg: Dict[str, Any]
    for message in messages:
        # Previously list content was replaced by "", silently losing the text.
        msg = {"role": "", "content": _content_to_text(message.content)}

        if isinstance(message, HumanMessage):
            msg["role"] = "user"
        elif isinstance(message, AIMessage):
            msg["role"] = "assistant"
            if message.tool_calls:
                msg["tool_calls"] = [
                    {"name": call["name"], "arguments": call["args"]}
                    for call in message.tool_calls
                ]
        elif isinstance(message, SystemMessage):
            msg["role"] = "system"
        elif isinstance(message, ToolMessage):
            msg["role"] = "tool"
            msg["tool_call_id"] = message.tool_call_id
        else:
            # Presumably a ChatMessage, which carries an explicit custom role;
            # message types without one keep the empty role as before.
            custom_role = getattr(message, "role", "")
            if isinstance(custom_role, str):
                msg["role"] = custom_role

        cloudflare_messages.append(msg)

    return cloudflare_messages


def _get_tool_calls_from_response(response: requests.Response) -> List[ToolCall]:
    """Get tool calls from a Cloudflare Workers AI response.

    Reads ``result.tool_calls`` from the JSON body. Each entry's ``name`` and
    ``arguments`` become a LangChain ``ToolCall`` with a freshly generated
    UUID id (no id is read from the payload).

    Args:
        response: HTTP response returned by the Workers AI ``run`` endpoint.

    Returns:
        The tool calls in payload order. The list is empty when the body has
        no ``result`` object (presumably an error payload), or when
        ``tool_calls`` is absent or null.
    """
    # Parse the body once rather than on every access.
    payload = response.json()
    result = payload.get("result") if isinstance(payload, dict) else None
    if not isinstance(result, dict):
        return []
    return [
        tool_call(id=str(uuid4()), name=tc["name"], args=tc["arguments"])
        for tc in result.get("tool_calls") or []
    ]


class ChatCloudflareWorkersAI(BaseChatModel):
    """Custom chat model for Cloudflare Workers AI.

    Requests go to the Workers AI ``run`` endpoint for ``model``. They are
    sent directly under ``base_url`` or, when ``ai_gateway`` is set, through
    the AI Gateway under ``gateway_url``. The endpoint is resolved once in
    ``__init__`` and stored in ``url``.
    """

    account_id: str = Field(...)
    api_token: str = Field(...)
    model: str = Field(...)
    # AI Gateway name; an empty string means the API is called directly.
    ai_gateway: str = ""
    # Resolved endpoint; always overwritten in __init__.
    url: str = ""
    base_url: str = "https://api.cloudflare.com/client/v4/accounts"
    gateway_url: str = "https://gateway.ai.cloudflare.com/v1"

    def __init__(self, **kwargs: Any) -> None:
        """Initialize with necessary credentials and resolve the request URL.

        Args:
            **kwargs: Field values (``account_id``, ``api_token``, ``model``,
                ...) forwarded to the ``BaseChatModel`` initializer.
        """
        super().__init__(**kwargs)
        if self.ai_gateway:
            self.url = (
                f"{self.gateway_url}/{self.account_id}/"
                f"{self.ai_gateway}/workers-ai/run/{self.model}"
            )
        else:
            self.url = f"{self.base_url}/{self.account_id}/ai/run/{self.model}"

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Generate a response based on the messages provided.

        The messages are flattened into a single text ``prompt`` (one
        ``role: ..., content: ...`` line per message). That prompt is POSTed to
        ``self.url`` together with ``tools`` (always present, ``None`` when
        unbound) and every other keyword argument.

        NOTE: ``stop`` and ``run_manager`` are accepted for interface
        compatibility but are not used.

        Raises:
            requests.HTTPError: if the API answers with an error status
                (>= 400). The message includes the response body.
        """
        formatted_messages = _convert_messages_to_cloudflare_messages(messages)

        headers = {"Authorization": f"Bearer {self.api_token}"}
        prompt = "\n".join(
            f"role: {msg['role']}, content: {msg['content']}"
            + (f", tools: {msg['tool_calls']}" if "tool_calls" in msg else "")
            + (
                f", tool_call_id: {msg['tool_call_id']}"
                if "tool_call_id" in msg
                else ""
            )
            for msg in formatted_messages
        )

        data = {
            "prompt": prompt,
            "tools": kwargs.get("tools"),
            **{key: value for key, value in kwargs.items() if key != "tools"},
        }

        # Ensure `tools` is a list if it was provided.
        if data["tools"] is not None and not isinstance(data["tools"], list):
            data["tools"] = [data["tools"]]

        # The payload contains the full user prompt: keep it out of INFO logs
        # and let the logging framework format it only when DEBUG is enabled.
        _logger.debug("Sending prompt to Cloudflare Workers AI: %s", data)

        response = requests.post(self.url, headers=headers, json=data)
        if not response.ok:
            # Without this check an error body (no usable `result`) surfaced
            # later as an unrelated TypeError/KeyError.
            raise requests.HTTPError(
                f"Cloudflare Workers AI request failed with status "
                f"{response.status_code}: {response.text}",
                response=response,
            )

        tool_calls = _get_tool_calls_from_response(response)
        ai_message = AIMessage(
            content=self._extract_content(response.json()), tool_calls=tool_calls
        )
        return ChatResult(generations=[ChatGeneration(message=ai_message)])

    @staticmethod
    def _extract_content(payload: Any) -> str:
        """Return the model's reply text from a Workers AI JSON body.

        Reads ``result.response``. It previously returned ``str(payload)``, a
        Python repr of the whole envelope, which is not valid JSON and hid
        the actual reply.

        * A string reply is returned as-is.
        * A non-string reply (e.g. a structured JSON-mode object) is
          serialized with ``json.dumps`` so downstream JSON parsers can read it.
        * No reply alongside tool calls yields ``""``.
        * Any other shape falls back to ``str(payload)`` so nothing is lost.
        """
        result = payload.get("result") if isinstance(payload, dict) else None
        reply = result.get("response") if isinstance(result, dict) else None
        if isinstance(reply, str):
            return reply
        if reply is not None:
            return json.dumps(reply)
        if isinstance(result, dict) and result.get("tool_calls"):
            return ""
        return str(payload)

    def bind_tools(
        self,
        tools: Sequence[Union[Dict[str, Any], Type, Callable[..., Any], BaseTool]],
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, BaseMessage]:
        """Bind tools for use in model generation.

        Each tool is converted to the OpenAI tool schema with
        ``convert_to_openai_tool``. The list is bound as the ``tools`` keyword,
        which ``_generate`` forwards to the API.

        Args:
            tools: Tool definitions (dicts, pydantic classes, callables or
                ``BaseTool`` instances).
            **kwargs: Extra keyword arguments bound alongside ``tools``
                (e.g. ``tool_choice``).
        """
        formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
        return super().bind(tools=formatted_tools, **kwargs)

    def with_structured_output(
        self,
        schema: Union[Dict, Type[BaseModel]],
        *,
        include_raw: bool = False,
        method: Optional[Literal["json_mode", "function_calling"]] = "function_calling",
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
        """Model wrapper that returns outputs formatted to match the given schema.

        Args:
            schema: Pydantic class or dict schema describing the output.
            include_raw: If True, return ``{"raw", "parsed", "parsing_error"}``.
                A parsing failure then sets ``parsed`` to None and records the
                exception under ``parsing_error`` instead of raising.
            method: ``"function_calling"`` binds ``schema`` as a forced tool.
                ``"json_mode"`` sends ``response_format={"type": "json_object"}``
                and parses the message content as JSON.

        Raises:
            ValueError: on unsupported keyword arguments, a missing schema for
                ``"function_calling"``, or an unknown ``method``.
        """
        if kwargs:
            raise ValueError(f"Received unsupported arguments {kwargs}")
        is_pydantic_schema = _is_pydantic_class(schema)
        if method == "function_calling":
            if schema is None:
                raise ValueError(
                    "schema must be specified when method is 'function_calling'. "
                    "Received None."
                )
            tool_name = convert_to_openai_tool(schema)["function"]["name"]
            llm = self.bind_tools([schema], tool_choice=tool_name)
            if is_pydantic_schema:
                output_parser: OutputParserLike = PydanticToolsParser(
                    tools=[schema],  # type: ignore[list-item]
                    first_tool_only=True,
                )
            else:
                output_parser = JsonOutputKeyToolsParser(
                    key_name=tool_name, first_tool_only=True
                )
        elif method == "json_mode":
            llm = self.bind(response_format={"type": "json_object"})
            output_parser = (
                PydanticOutputParser(pydantic_object=schema)  # type: ignore[type-var, arg-type]
                if is_pydantic_schema
                else JsonOutputParser()
            )
        else:
            raise ValueError(
                f"Unrecognized method argument. Expected one of 'function_calling' or "
                f"'json_mode'. Received: '{method}'"
            )

        if include_raw:
            parser_assign = RunnablePassthrough.assign(
                parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None
            )
            parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
            parser_with_fallback = parser_assign.with_fallbacks(
                [parser_none], exception_key="parsing_error"
            )
            return RunnableMap(raw=llm) | parser_with_fallback
        else:
            return llm | output_parser

    @property
    def _llm_type(self) -> str:
        """Return the type of the LLM (for LangChain compatibility)."""
        return "cloudflare-workers-ai"
