# mypy: allow-untyped-defs
import torch
import torch.utils._pytree as pytree
from collections import namedtuple
import functools


# NOTE [CustomOp autograd kernel indirection]
# We register `inner` as the autograd kernel for this custom_op.
# `inner` either calls the autograd formula registered by the user,
# or goes into an `autograd_not_implemented` kernel.
#
# The reason why this indirection exists is
# so that we can swap out the autograd kernel (the PyTorch dispatcher
# doesn't actually allow us to do this). By default, we want
# the `autograd_not_implemented` behavior, but the user may later come
# and register an actual backward formula.
def autograd_kernel_indirection(custom_op):
    """Return an autograd kernel for ``custom_op`` that decides what to do per call.

    See NOTE [CustomOp autograd kernel indirection]. Each call inspects what is
    registered on ``custom_op`` at that moment:

    * an 'autograd' impl: call it with the given args;
    * a 'backward' and/or 'save_for_backward' impl but no 'autograd' impl:
      raise RuntimeError naming the half that appears to be missing;
    * otherwise: defer to ``autograd_not_implemented(custom_op)``.
    """
    fallback_kernel = autograd_not_implemented(custom_op)

    def inner(*args, **kwargs):
        if custom_op._has_impl('autograd'):
            user_kernel = custom_op._get_impl('autograd').func
            return user_kernel(*args, **kwargs)

        # As explained in NOTE ["backward", "save_for_backward", and "autograd"],
        # the "autograd" impl is generated once the user has given us both
        # "backward" and "save_for_backward". Seeing only one of them here
        # means the user forgot the other.
        partially_registered = (
            custom_op._has_impl('save_for_backward')
            or custom_op._has_impl('backward')
        )
        if not partially_registered:
            return fallback_kernel(*args, **kwargs)

        if custom_op._has_impl('backward'):
            found, missing = 'backward', 'save_for_backward'
        else:
            found, missing = 'save_for_backward', 'backward'
        loc = custom_op._get_impl(found).location
        raise RuntimeError(
            f"We found a '{found}' registration for {custom_op} at "
            f"{loc} but were unable to find a '{missing}' registration. "
            f"To use the CustomOp API to register a backward formula, "
            f"please provide us both a backward function and a "
            f"'save for backward' function via `impl_backward` and "
            f"`impl_save_for_backward` respectively.")

    return inner


# TODO(#101191): Use the actual C++ autograd not implemented fallback,
# or change the default autograd fallback to the autograd not implemented fallback.
def autograd_not_implemented(custom_op):
    def kernel(*args, **kwargs):
        if torch.is_grad_enabled() and pytree.tree_any(
            lambda x: isinstance(x, torch.Tensor) and x.requires_grad, (args, kwargs)
        ):
            raise RuntimeError("Autograd has not been implemented for operator")
        with torch._C._AutoDispatchBelowAutograd():
            return custom_op(*args, **kwargs)
    return kernel


def mark_non_differentiable(ctx, output, output_differentiability):
    """Mark the outputs declared non-differentiable on an autograd ctx.

    Args:
        ctx: the autograd.Function context. ``ctx.mark_non_differentiable`` is
            called at most once, with every Tensor that should be marked.
        output: the op's output, either a single value or a tuple of values.
        output_differentiability: ``None`` (nothing to do) or a sequence with
            one truthy/falsy entry per output.

    Raises:
        RuntimeError: if an output that is neither a Tensor nor a list is
            declared differentiable.
    """
    # Output types are restricted to be:
    # - Tensor
    # - Tensor[]
    # - int, bool, Scalar, float
    # See _check_can_register_backward
    if output_differentiability is None:
        return

    tuple_output = output if isinstance(output, tuple) else (output,)
    # Internal invariant: one differentiability flag per output.
    assert len(output_differentiability) == len(tuple_output), (
        f"output_differentiability has {len(output_differentiability)} "
        f"entries but the op returned {len(tuple_output)} outputs")

    non_differentiable_tensors = []
    for idx, (differentiable, out) in enumerate(zip(output_differentiability, tuple_output)):
        if isinstance(out, torch.Tensor):
            if not differentiable:
                non_differentiable_tensors.append(out)
        elif isinstance(out, list):
            # Tensor[] output: the single flag applies to every element.
            if not differentiable:
                non_differentiable_tensors.extend(out)
        elif differentiable:
            raise RuntimeError(
                f"With output_differentiability={output_differentiability}. "
                f"At idx {idx}, we received an object of type {type(out)} that "
                f"is not a Tensor, so it cannot be marked as differentiable in "
                f"output_differentiability.")

    if non_differentiable_tensors:
        ctx.mark_non_differentiable(*non_differentiable_tensors)


def construct_autograd_kernel(
        schema,
        output_differentiability,
        custom_op,
        op_overload,
        save_for_backward_fn,
        backward_fn):
    """Build an autograd kernel for ``op_overload`` from user-provided pieces.

    The returned ``apply(*args)`` runs ``op_overload`` inside a freshly
    generated ``torch.autograd.Function`` whose backward delegates to the
    user's ``backward_fn``.

    Args:
        schema: the op's schema. Its arguments name the fields of the
            namedtuples passed to ``save_for_backward_fn`` and used for
            ``args_info`` (via ``namedtuple_args``).
        output_differentiability: forwarded to ``mark_non_differentiable``
            (``None`` or one flag per output).
        custom_op: the CustomOp. ``custom_op._opname`` names the generated
            autograd.Function, and the op is passed to
            ``validate_grad_inputs_dict`` for error messages.
        op_overload: the callable that computes the forward result.
        save_for_backward_fn: called as ``save_for_backward_fn(inputs, output)``,
            where ``inputs`` is a namedtuple of the forward args. Its return
            value is what ``backward_fn`` later receives as ``saved``.
        backward_fn: called as ``backward_fn(ctx, saved, *grads)``. It must
            return a dict mapping argument names to gradients (checked by
            ``validate_grad_inputs_dict``).

    Returns:
        ``apply``, a function that takes the op's positional args and returns
        the op's output with its original (unflattened) structure.
    """

    def apply(*args):
        # Flatten so every leaf (including Tensors nested in list args) is
        # passed to autograd.Function.apply as its own positional argument.
        flat_args, spec = pytree.tree_flatten(args)
        # Set by `forward` (via nonlocal). Both `backward` and the caller
        # below need it to rebuild the output structure.
        out_spec = None

        def forward(ctx, *flat_args):
            ctx.set_materialize_grads(True)
            args = pytree.tree_unflatten(list(flat_args), spec)
            # Dispatch below the Autograd key. Presumably this avoids routing
            # back into this autograd kernel.
            with torch._C._AutoDispatchBelowAutograd():
                output = op_overload(*args)

            # We use the info about args to give better error messages in backward
            args_info = namedtuple_args(
                schema, pytree.tree_map(type, args))

            save_for_backward_fn_inputs = namedtuple_args(schema, args)
            to_save = save_for_backward_fn(save_for_backward_fn_inputs, output)

            # Tensors go through ctx.save_for_backward and everything else
            # onto ctx attributes; see save_pytree_for_backward.
            save_pytree_for_backward(ctx, (to_save, args_info))
            mark_non_differentiable(ctx, output, output_differentiability)

            # autograd.Function returns a flat tuple. Remember the structure so
            # it can be restored after apply() and in backward.
            nonlocal out_spec
            flat_output, out_spec = pytree.tree_flatten(output)
            return tuple(flat_output)

        def backward(ctx, *flat_grad_output):
            assert out_spec is not None
            # Regroup the flat grads to mirror the structure of the op output.
            grads = pytree.tree_unflatten(list(flat_grad_output), out_spec)
            saved, args_info = unpack_saved(ctx)
            # There is nothing on the ctx object for now, it is just there so
            # that we can add additional things in the future.
            inner_ctx = object()
            # A single (non-tuple) output yields a single grad. Wrap it so that
            # backward_fn receives one positional grad per op output.
            if not isinstance(grads, tuple):
                grads = (grads,)
            grad_inputs_dict = backward_fn(inner_ctx, saved, *grads)

            # Massage the grad_inputs_dict to a form acceptable by
            # autograd.Function.
            validate_grad_inputs_dict(grad_inputs_dict, custom_op, args_info)
            return grad_inputs_dict_to_flat_tuple(grad_inputs_dict, args_info)

        # Note: a new autograd.Function subclass is generated on every call.
        generated_cls = gen_autograd_function(
            custom_op._opname + '_customop', forward, backward)

        flat_output = generated_cls.apply(*flat_args)
        # forward has run by now, so out_spec has been populated.
        assert out_spec is not None
        return pytree.tree_unflatten(list(flat_output), out_spec)
    return apply


def gen_autograd_function(name, forward, backward):
    """Create a ``torch.autograd.Function`` subclass named ``name``.

    ``forward`` and ``backward`` are installed as static methods, so they are
    called with ``ctx`` as their first argument.
    """
    namespace = dict(
        forward=staticmethod(forward),
        backward=staticmethod(backward),
    )
    return type(name, (torch.autograd.Function,), namespace)


@functools.lru_cache
def namedtuple_args_cls(schema):
    """Return a (cached) namedtuple class with one field per argument of ``schema``.

    The class is named ``"<schema.name>_args"``, and its fields follow the
    order of ``schema.arguments.flat_all``.
    """
    field_names = [argument.name for argument in schema.arguments.flat_all]
    type_name = f"{schema.name!s}_args"
    # mypy doesn't support dynamic namedtuple name
    return namedtuple(type_name, field_names)  # type: ignore[misc]


def namedtuple_args(schema, args):
    """Wrap the positional ``args`` tuple in the namedtuple class for ``schema``.

    ``args`` must be a tuple with one entry per schema argument, in schema order.
    """
    assert isinstance(args, tuple)
    return namedtuple_args_cls(schema)(*args)


def validate_grad_inputs_dict(grad_inputs_dict, forward_op, args_info):
    """Check that a user backward function returned well-formed gradients.

    Args:
        grad_inputs_dict: the value returned by the user's backward function.
            It must be a dict whose keys are exactly the names of the
            Tensor-like arguments of ``forward_op``'s schema.
        forward_op: the CustomOp whose backward produced ``grad_inputs_dict``.
            It is used to read the schema and the backward registration's
            location for error messages.
        args_info: a namedtuple with one field per schema argument, holding
            the *type* of each forward arg, or a list of types for list args.

    Raises:
        RuntimeError: describing the first problem found.
    """
    def error(what):
        # Always raises; callers rely on it never returning.
        backward = forward_op._get_impl('backward')
        raise RuntimeError(
            f"In the backward function defined for {forward_op} at "
            f"{backward.location} using the CustomOp API, {what}")

    if not isinstance(grad_inputs_dict, dict):
        error(f"expected the output of the backward function to be a dict but "
              f"got {type(grad_inputs_dict)}")

    expected_keys = {arg.name for arg in forward_op._schema.arguments.flat_all
                     if arg.type.is_tensor_like()}
    actual_keys = grad_inputs_dict.keys()
    if expected_keys != actual_keys:
        error(f"expected the returned grad_input dict to have keys "
              f"{expected_keys} but got {actual_keys}. The backward "
              f"function must return a gradient (can be None) for each arg "
              f"to the CustomOp that may be a Tensor or Sequence[Tensor]. "
              f"Args declared to be non-Tensor-like types should not appear "
              f"in the grad_input dict")

    for name, grad in grad_inputs_dict.items():
        arg_info = getattr(args_info, name)

        if isinstance(arg_info, list):
            # List argument: expect a same-length sequence of None/Tensor grads.
            if not isinstance(grad, (tuple, list)):
                error(f"for input '{name}' expected the grad_input dict to "
                      f"hold a list of gradients but got object of type "
                      f"{type(grad)}.")
            if len(grad) != len(arg_info):
                error(f"for input '{name}' expected the grad_input dict to "
                      f"hold a list of {len(arg_info)} gradients but got "
                      f"{len(grad)}")
            for idx, (g, info) in enumerate(zip(grad, arg_info)):
                if g is None:
                    continue
                if not isinstance(g, torch.Tensor):
                    error(f"for input '{name}' expected the grad_input dict to "
                          f"hold a list of None or Tensor gradients but got "
                          f"object of {type(g)} at index {idx}")
                if not issubclass(info, torch.Tensor):
                    # Report the type of this element, not the whole list.
                    error(f"for input '{name}', got a Tensor as the gradient "
                          f"for the {idx}-th value but expected None because "
                          f"the {idx}-th value was not a Tensor (it was "
                          f"type {info})")
            continue

        if grad is None:
            continue
        if not isinstance(grad, torch.Tensor):
            error(f"got object of type {type(grad)} as the gradient for input "
                  f"'{name}', "
                  f"but expected the gradient to be either None or a Tensor")
        if not issubclass(arg_info, torch.Tensor):
            error(f"got a Tensor as the gradient for input '{name}' but "
                  f"expected None as the gradient because input '{name}' "
                  f"was not a Tensor (it was type {arg_info}).")


def grad_inputs_dict_to_flat_tuple(grad_inputs_dict, args_info):
    result = []
    for name, arg_info in args_info._asdict().items():
        if name not in grad_inputs_dict:
            result.append(pytree.tree_map(lambda x: None, arg_info))
            continue
        result.append(grad_inputs_dict[name])
    return tuple(pytree.tree_leaves(result))

# Saves "stuff" (a pytree) onto the ctx object. Use unpack_saved to unpack it.
# autograd.Function prefers that users use ctx.save_for_backward to
# save Tensors (to avoid reference cycles) and for non-Tensors to go onto the
# ctx object.
def save_pytree_for_backward(ctx, stuff):
    """Stash the pytree ``stuff`` on ``ctx`` so it can be restored in backward.

    Tensor leaves are saved with ``ctx.save_for_backward``. Every other leaf is
    kept in ``ctx.saved_non_tensors``. The flat position of each leaf and the
    tree spec are recorded so ``unpack_saved`` can rebuild ``stuff``.
    """
    leaves, treespec = pytree.tree_flatten(stuff)

    tensor_positions, tensor_leaves = [], []
    other_positions, other_leaves = [], []
    for position, leaf in enumerate(leaves):
        if isinstance(leaf, torch.Tensor):
            tensor_positions.append(position)
            tensor_leaves.append(leaf)
        else:
            other_positions.append(position)
            other_leaves.append(leaf)

    ctx.spec = treespec
    ctx.num_elts = len(leaves)
    ctx.save_for_backward(*tensor_leaves)
    ctx.tensor_idxs = tensor_positions
    ctx.saved_non_tensors = other_leaves
    ctx.non_tensor_idxs = other_positions


# Inverse operation to save_pytree_for_backward
def unpack_saved(ctx):
    """Rebuild the pytree stored on ``ctx`` by ``save_pytree_for_backward``.

    Tensor leaves come back from ``ctx.saved_tensors`` and non-Tensor leaves
    from ``ctx.saved_non_tensors``. Each is placed at its recorded flat
    position, then the result is unflattened with ``ctx.spec``.
    """
    leaves = [None for _ in range(ctx.num_elts)]
    placements = list(zip(ctx.tensor_idxs, ctx.saved_tensors))
    placements += zip(ctx.non_tensor_idxs, ctx.saved_non_tensors)
    for position, leaf in placements:
        leaves[position] = leaf
    return pytree.tree_unflatten(leaves, ctx.spec)
