# Represents all kernels used by an ExecuTorch model.
# ETKernelIndex maintains a dict[OperatorName, dict[ETKernelKey, BackendMetadata]] structure.

from __future__ import annotations

import itertools
from collections import defaultdict, namedtuple
from dataclasses import dataclass
from enum import IntEnum

from torchgen.model import (
    BackendIndex,
    BackendMetadata,
    DispatchKey,
    NativeFunction,
    NativeFunctionsGroup,
    OperatorName,
)
from torchgen.utils import assert_never


# Version number of the serialized kernel key format (the "v<N>/" prefix of a
# non-default key string); also the default value of ETKernelKey.version.
KERNEL_KEY_VERSION = 1


# TODO: Duplicated subset of the ScalarType enum in codegen.tool.gen_oplist;
# remove the declaration in codegen.
class ScalarType(IntEnum):
    """Subset of scalar dtypes that can appear in a kernel key.

    Members are looked up by name (e.g. ``ScalarType["Float"]``) and serialized
    by their integer value. The values are not contiguous (5 and 8-10 are
    absent); presumably they must keep matching the numbering of the full enum
    this subset duplicates -- confirm before adding or renumbering members.
    """

    Byte = 0
    Char = 1
    Short = 2
    Int = 3
    Long = 4
    Float = 6
    Double = 7
    Bool = 11


# Named pair (native_functions, kernel_index). Judging by the name this is the
# result of parsing an ExecuTorch YAML file; the field types are not
# constrained here -- presumably a list of NativeFunction and an ETKernelIndex.
ETParsedYaml = namedtuple("ETParsedYaml", ["native_functions", "kernel_index"])


@dataclass(frozen=True)
class ETKernelKeyOpArgMeta:
    """Kernel-key metadata (dtype and dim order) for one operator argument.

    Frozen, so instances are hashable and can be shared between keys.
    """

    # Name of the operator argument this entry describes.
    arg_name: str
    # Name of a ScalarType member, e.g. "Float"; resolved by name when serialized.
    dtype: str
    # The order of the dimensions if entry is a Tensor
    dim_order: tuple[int, ...]

    def to_native_string(self) -> str:
        """Serialize as ``"<dtype value>;<comma-separated dim order>"``.

        Example: dtype="Float", dim_order=(0, 1, 2, 3) -> "6;0,1,2,3".

        Raises:
            KeyError: if ``dtype`` is not the name of a ScalarType member.
        """
        dtype_str = ScalarType[self.dtype].value
        # Join explicitly instead of slicing str(tuple): str((0,)) is "(0,)",
        # which would leave a stray trailing comma for a 1-element dim order.
        dim_str = ",".join(str(dim) for dim in self.dim_order)
        return f"{dtype_str};{dim_str}"


@dataclass(frozen=True)
class ETKernelKey:
    """Key identifying one kernel registered for an operator.

    A key is either a catch-all (``default=True``) or a tuple of per-argument
    metadata (dtype + dim order) that the kernel is specialized for.
    """

    # Per-argument specialization metadata. The catch-all keys built in this
    # module leave it empty.
    arg_meta: tuple[ETKernelKeyOpArgMeta, ...] = ()

    # Indicator for this kernel being used as a catch all
    default: bool = False

    # Kernel key format version, emitted as the "v<version>/" prefix.
    version: int = KERNEL_KEY_VERSION

    @staticmethod
    def gen_from_yaml(
        args: dict[str, tuple[str, str]],
        type_alias_map: dict[str, list[str]],  # TODO: Support unwrapped str val
        dim_order_alias_map: dict[str, list[int]],
    ) -> list[ETKernelKey]:
        """Generate ETKernelKeys from arg kernel specs.

        Multiple ETKernelKeys are returned due to dtype permutations from
        utilizing type_alias_map (actualizing each potential type permutation
        as a KernelKey).

        Args:
            args: Mapping from argument name to kernel specs.
                Kernel specs are a tuple of (dtype, dim_order).
                Currently tuple entries must be aliased via the alias map arguments.
            type_alias_map: Mapping from type alias to potential type enums,
                i.e. { T0 : [Double, Int] } means T0 can be either Double or Int.
                Used for lookup by args.
            dim_order_alias_map: Mapping from alias to a list of dimension orders.
                Used for lookup by args.

        Returns:
            One key per combination of dtype choices for the aliases used in
            ``args``. The order is deterministic: aliases are taken in order of
            first use in ``args``, and the last alias varies fastest.

        Raises:
            AssertionError: if an arg references an undefined type or
                dim_order alias.
        """
        # Cast each dim order entry to int. Entries may arrive as other types
        # (e.g. strings, presumably from YAML).
        dim_order_alias_map = {
            k: [int(alias) for alias in v] for k, v in dim_order_alias_map.items()
        }

        # Collect the dtype aliases referenced by args. Use an insertion-ordered
        # dict rather than a set: str hashing is randomized per process, so a
        # set would make the order of the generated keys vary between runs.
        dtype_alias_used: dict[str, None] = {}
        for type_alias, dim_order_alias in args.values():
            # Enforce usage of alias initially
            # TODO: Support inlined arguments
            assert type_alias in type_alias_map, "Undefined type alias: " + str(
                type_alias
            )
            assert (
                dim_order_alias in dim_order_alias_map
            ), "Undefined dim_order alias: " + str(dim_order_alias)
            dtype_alias_used[type_alias] = None

        # Generate all permutations of dtype alias values
        alias_dtypes = [
            [(alias, dtype) for dtype in type_alias_map[alias]]
            for alias in dtype_alias_used
        ]
        alias_permutations = [
            dict(permutation) for permutation in itertools.product(*alias_dtypes)
        ]

        # Using each alias value permutation, generate kernel keys. Identical
        # per-argument metadata objects are shared between keys via the cache.
        op_arg_cache: dict[tuple, ETKernelKeyOpArgMeta] = {}
        kernel_keys = []
        for permutation in alias_permutations:
            arg_list = []
            for arg_name, (type_alias, dim_order_alias) in args.items():
                cache_key = (
                    arg_name,
                    permutation[type_alias],
                    tuple(dim_order_alias_map[dim_order_alias]),
                )
                if cache_key not in op_arg_cache:
                    op_arg_cache[cache_key] = ETKernelKeyOpArgMeta(*cache_key)
                arg_list.append(op_arg_cache[cache_key])
            kernel_keys.append(ETKernelKey(tuple(arg_list)))

        return kernel_keys

    def to_native_string(self) -> str:
        """Serialize the key.

        Returns "default" for a catch-all key. Otherwise returns
        "v<version>/" followed by each argument's native string, joined by "|".
        """
        if self.default:
            return "default"
        # Use this key's own version rather than the module constant so a key
        # built with a different version is rendered faithfully.
        return f"v{self.version}/" + "|".join(
            arg.to_native_string() for arg in self.arg_meta
        )


@dataclass(frozen=True)
class ETKernelIndex:
    """Index of the kernels registered for each operator.

    Maps an OperatorName to a dict from ETKernelKey to the BackendMetadata of
    the kernel registered under that key. The dataclass is frozen, so ``index``
    cannot be rebound, but the dict it refers to is mutated in place by
    ``grow``.
    """

    index: dict[OperatorName, dict[ETKernelKey, BackendMetadata]]

    def has_kernels(self, g: NativeFunction | NativeFunctionsGroup) -> bool:
        """Return True if at least one kernel is registered for ``g``."""
        # get_kernels returns {} (never None) for unknown operators, so test
        # for emptiness; an `is not None` check would always be True.
        return bool(self.get_kernels(g))

    def get_kernels(
        self, g: NativeFunction | NativeFunctionsGroup
    ) -> dict[ETKernelKey, BackendMetadata]:
        """Return the kernels registered for ``g``.

        For a NativeFunctionsGroup the functional variant's name is looked up.
        Returns the stored dict itself (not a copy), or a new empty dict if the
        operator has no entry.
        """
        if isinstance(g, NativeFunction):
            f = g
        elif isinstance(g, NativeFunctionsGroup):
            f = g.functional
        else:
            assert_never(g)
        # .get does not trigger a defaultdict factory, so lookups of unknown
        # operators never insert entries.
        return self.index.get(f.func.name, {})

    @staticmethod
    def grow_from_backend_indices(
        kernel_index: dict[OperatorName, dict[ETKernelKey, BackendMetadata]],
        backend_indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]],
    ) -> None:
        """Register each backend kernel as its operator's catch-all kernel.

        Mutates ``kernel_index`` in place. Existing specialized keys of an
        operator are kept; an existing catch-all entry is overwritten. If
        several dispatch keys provide a kernel for the same operator, the one
        iterated last wins.
        """
        default_key = ETKernelKey(default=True)
        for op_index in backend_indices.values():
            for op, backend_metadata in op_index.items():
                kernel_index.setdefault(op, {})[default_key] = backend_metadata

    @staticmethod
    def from_backend_indices(
        backend_indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]]
    ) -> ETKernelIndex:
        """Build a new index holding every backend kernel as a catch-all entry."""
        kernel_index: dict[
            OperatorName, dict[ETKernelKey, BackendMetadata]
        ] = defaultdict(dict)
        ETKernelIndex.grow_from_backend_indices(kernel_index, backend_indices)
        return ETKernelIndex(kernel_index)

    def grow(
        self, backend_indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]]
    ) -> ETKernelIndex:
        """Add ``backend_indices`` to this index (in place) as catch-all kernels.

        Returns ``self`` so calls can be chained.
        """
        ETKernelIndex.grow_from_backend_indices(self.index, backend_indices)
        return self

    def _to_backend_index(self) -> BackendIndex:
        """Convert to a BackendIndex holding one kernel per operator.

        WARNING: this will be deprecated once all the codegen places know how to handle ETKernelIndex.

        Raises:
            AssertionError: if an operator does not have exactly one kernel.
        """
        index: dict[OperatorName, BackendMetadata] = {}
        for op, kernel_dict in self.index.items():
            # NOTE: despite the message wording, this also fires for an
            # operator with zero kernels.
            assert (
                len(kernel_dict) == 1
            ), f"Can't convert ETKernelIndex to BackendIndex because {op} has more than one kernels. Got {kernel_dict}"
            # If the single kernel is specialized (not under the default key),
            # a placeholder BackendMetadata with an empty kernel name is used.
            index[op] = kernel_dict.get(
                ETKernelKey(default=True),
                BackendMetadata(kernel="", structured=False, cpp_namespace=""),
            )
        return BackendIndex(
            dispatch_key=DispatchKey.CPU,
            use_out_as_primary=False,
            device_guard=False,
            external=False,
            index=index,
        )

    # Note duplicate ETKernelKey from index_b will clobber the metadata from index_a
    @staticmethod
    def merge_indices(index_a: ETKernelIndex, index_b: ETKernelIndex) -> ETKernelIndex:
        """Return a new index containing the kernels of both inputs.

        For a duplicate (operator, ETKernelKey) pair the metadata from
        ``index_b`` wins. Neither input index is modified.
        """
        # Copy the per-operator dicts too: a shallow copy of index_a.index would
        # share them, and writing index_b's entries would then mutate index_a.
        combined: dict[
            OperatorName, dict[ETKernelKey, BackendMetadata]
        ] = defaultdict(dict)
        for op, entry in index_a.index.items():
            combined[op] = dict(entry)

        for op, entry in index_b.index.items():
            combined[op].update(entry)

        return ETKernelIndex(combined)
