# mypy: allow-untyped-defs
import logging
import warnings
from typing import Any, Dict, Iterable, Optional, Tuple

import torch
import torch.export
import torch.export._trace
from torch._utils_internal import log_export_usage


# Module-level logger used by the exportability report below.
log = logging.getLogger(name=__name__)

# Public API of this module.
__all__ = [
    "report_exportability",
]


def _generate_inputs_for_submodules(
    model: torch.nn.Module,
    target_submodules: Iterable[str],
    args: Tuple[Any, ...],
    kwargs: Optional[Dict[str, Any]] = None,
) -> Dict[str, Tuple[Any, Any]]:
    """
    Generate inputs for targeting submdoules in the given model. Note that if two submodules refer to the same obj, this
    function doesn't work.

    Args:
        model: root model.
        inputs: inputs to the root model.
        target_submodules: submodules that we want to generate inputs for.

    Returns:
        A dict that maps from submodule name to its inputs.
    """
    kwargs = kwargs or {}

    handles = []
    results = {}
    submodule_to_names = {mod: name for name, mod in model.named_modules()}

    def pre_forward(module, module_args, module_kwargs):
        results[submodule_to_names[module]] = (module_args, module_kwargs)

    try:
        for name, mod in model.named_modules():
            if name in target_submodules:
                handles.append(
                    mod.register_forward_pre_hook(pre_forward, with_kwargs=True)
                )
        model(*args, **kwargs)
    except Exception as e:
        warnings.warn(
            f"Failed to generate submodule inputs because of the following error:\n{e}"
        )
    finally:
        for h in handles:
            h.remove()
    return results


def _first_line_of_repr(exc: BaseException) -> str:
    """
    Return ``repr(exc)`` truncated at its first line break.

    The default exception ``repr`` escapes newlines in the message as the
    two-character sequence backslash + ``n``, while a custom ``__repr__`` may
    emit real newlines, so truncate at whichever form appears first.
    """
    return repr(exc).split("\\n", 1)[0].split("\n", 1)[0]


def report_exportability(
    mod: torch.nn.Module,
    args: Tuple[Any, ...],
    kwargs: Optional[Dict[str, Any]] = None,
    *,
    strict: bool = True,
    pre_dispatch: bool = False,
) -> Dict[str, Optional[Exception]]:
    """
    Report exportability issues for a module in one-shot.

    The root module is exported with ``args``/``kwargs``. Submodules are exported
    with the inputs captured for them during one forward pass of ``mod``. If a
    module fails to export, or was never invoked (so has no captured inputs),
    its direct children are visited recursively. A module that exports
    successfully is not descended into. Each module type is exported at most
    once: once an export of a type has been attempted, later instances of that
    type are skipped together with their children.

    Args:
        mod: root module.
        args: positional args to the root module.
        kwargs: keyword args to the root module.
        strict: forwarded to ``torch.export._trace._export``.
        pre_dispatch: forwarded to ``torch.export._trace._export``.

    Returns:
        A dict that maps from submodule name to the exception that was raised when trying to export it.
        `None` means the module is exportable without issue. Modules whose export
        was never attempted have no entry.

    Sample output:
        {
            '': UnsupportedOperatorException(func=<OpOverload(op='testlib.op_missing_meta', overload='default')>),
            'submod_1': UnsupportedOperatorException(func=<OpOverload(op='testlib.op_missing_meta', overload='default')>),
            'submod_2': None
        }
    """

    log_export_usage(event="export.report_exportability")

    kwargs = kwargs or {}

    all_submod_names = [name for name, _ in mod.named_modules() if name != ""]
    submod_inputs = _generate_inputs_for_submodules(mod, all_submod_names, args, kwargs)

    tried_module_types = set()
    report: Dict[str, Optional[Exception]] = {}

    def try_export(module, module_name, module_args, module_kwargs):
        if type(module) in tried_module_types:
            return

        # Only modules reached during the forward pass have captured inputs.
        if module_args is not None or module_kwargs is not None:
            # Record the type only once an export is really attempted. Otherwise
            # an un-invoked instance would hide a later, invoked instance of the
            # same type from the report.
            tried_module_types.add(type(module))
            try:
                torch.export._trace._export(
                    module,
                    module_args,
                    module_kwargs,
                    strict=strict,
                    pre_dispatch=pre_dispatch,
                )
            except Exception as e:
                log.warning(
                    "Failed exporting `%s` with exception: %s",
                    module_name,
                    _first_line_of_repr(e),
                )
                report[module_name] = e
            else:
                report[module_name] = None
                log.info("Successfully exported `%s`", module_name)
                # The whole subtree exported; no need to look further down.
                return

        for name, submod in module.named_children():
            sub_module_name = name if module_name == "" else f"{module_name}.{name}"
            submod_args, submod_kwargs = submod_inputs.get(
                sub_module_name, (None, None)
            )
            try_export(submod, sub_module_name, submod_args, submod_kwargs)

    try_export(mod, "", args, kwargs)

    # dict.fromkeys dedupes while keeping traversal order, so the summary is
    # logged deterministically (a set of strings iterates in an order that
    # varies with hash randomization).
    unique_issues = dict.fromkeys(
        _first_line_of_repr(exception)
        for exception in report.values()
        if exception is not None
    )

    log.warning("Found %d export issues:", len(unique_issues))
    for issue in unique_issues:
        log.warning(issue)

    return report
