from __future__ import annotations

from dataclasses import dataclass
from typing import Sequence

from torchgen.api import cpp
from torchgen.api.types import Binding, CppSignature, CppSignatureGroup
from torchgen.gen import pythonify_default
from torchgen.model import (
    Argument,
    BaseTy,
    BaseType,
    FunctionSchema,
    ListType,
    NativeFunction,
    OptionalType,
    Return,
    Type,
    Variant,
)


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
#                           Data Models
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
# [Notes] python binding codegen
#
# The Python binding codegen produces code that takes the input list of
# PyObjects, finds the matching ATen C++ function using PythonArgParser,
# converts the PyObjects into C++ types and calls the ATen C++ function:
#
# +--------+  parsing   +------------------------+  binding   +-----------------------+
# | PyObjs | ---------> | PythonArgParser Output | ---------> | Cpp Function Dispatch |
# +--------+            +------------------------+            +-----------------------+
#
# The following examples demonstrate the data models the Python binding
# codegen needs to deal with and the tasks it needs to accomplish. It
# helps understand the purpose of the new data types we introduced below.
#
#  - Function Schema (source of truth)
#
#      aten::empty.names(int[] size, *, Dimname[]? names,
#                        ScalarType? dtype=None, Layout? layout=None,
#                        Device? device=None, bool? pin_memory=None,
#                        MemoryFormat? memory_format=None) -> Tensor
#
#  - Python Signature
#
#    It's used to generate input schema string for PythonArgParser.
#    Note: TensorOptions fields are reordered and the additional
#    'requires_grad' field is added:
#
#      empty(IntArrayRef size, *, DimnameList? names,
#            MemoryFormat? memory_format=None, ScalarType dtype=None,
#            Layout layout=torch.strided, Device device=None,
#            bool pin_memory=False, bool requires_grad=False)
#
#  - C++ Signature
#
#    It's used to generate C++ lambda formals & dispatch call.
#    Note: the scattered TensorOptions fields are packed into 'options'.
#
#      auto dispatch_empty =
#          [](IntArrayRef size, std::optional<DimnameList> names,
#             const TensorOptions & options,
#             std::optional<MemoryFormat> memory_format) -> Tensor {
#          pybind11::gil_scoped_release no_gil;
#          return torch::empty(size, names, options, memory_format);
#      };
#
#  - Binding between Python Arguments and C++ Arguments
#
#    Given a set of Python Arguments in scope, we need to produce the
#    binding expressions that translate the Python API into the C++ API:
#
#            Python Args               Cpp Args       Binding Exprs
#     -----------------------------------------------------------------
#         0: size                      size           '_r.intlist(0)'
#         1: names                     names          'names' [special init]
#         2: memory_format -------+
#         3: dtype         -----+-|--> options        'options' [special packing]
#         4: layout            /  |
#         5: device           /   +--> memory_format  '_r.memoryformatOptional(2)'
#         6: pin_memory      /
#         7: requires_grad -+
#
#    So the full dispatch expression would look like:
#
#      dispatch_empty(_r.intlist(0), names, options,
#                     _r.memoryformatOptional(2))
#
#    Where does 'names' come from? It involves special local init:
#
#      auto __names = _r.toDimnameListOptional(1);
#      std::optional<DimnameList> names =
#          __names ? std::make_optional(DimnameList(__names.value()))
#                  : std::nullopt;
#
#    Where does 'options' come from? It involves special local init
#    for TensorOptions. Note that Python side has the additional
#    'requires_grad' field:
#
#      const auto options = TensorOptions()
#          .dtype(_r.scalartype(3))
#          .device(_r.device(5))
#          .layout(_r.layoutOptional(4))
#          .requires_grad(_r.toBool(7))
#          .pinned_memory(_r.toBool(6));
#
#    In some other cases one Python Argument can map to multiple C++
#    Arguments. For example:
#
#     aten::max.names_dim(Tensor self, Dimname dim, bool keepdim=False)
#       -> (Tensor values, Tensor indices)
#
#            Python Args               Cpp Args          Binding Exprs
#     ---------------------------------------------------------------------
#                               +----> max               'out[0]'
#                              /-----> max_values        'out[1]'
#         0: input            /        self              '_r.tensor(0)'
#         1: dim             /         dim               '_r.dimname(1)'
#         2: keepdim        /          keepdim           '_r.toBool(2)'
#         3: out      -----+           [local init] out  '_r.tensorlist_n<2>(3)'
#
#    As demonstrated above, the binding can involve reordering,
#    packing, unpacking and special local inits.
#
#
#  Let's look at a concrete example:
#
#      static PythonArgParser parser({
#        "abs(Tensor input, *, Tensor out=None)",
#        ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#         ^
#         +--- Python Schema, represented by PythonSignature and PythonArgument
#
#      }, /*traceable=*/true);
#
#      ParsedArgs<2> parsed_args;
#      auto _r = parser.parse(nullptr, args, kwargs, parsed_args);
#
#      ...
#
#      if (_r.isNone(1)) {
#          ~~~~~~~~~~~~  <--- Scattered PythonArgParser output (arg name = 'out')
#                             represented by PythonArgParserOutputExpr
#
#        // aten::abs(Tensor self) -> Tensor
#        ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#         ^
#         +--- NativeFunction schema, base version
#
#        auto dispatch_abs = [](const Tensor & self) -> Tensor {
#                            ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#                             ^
#                             +--- dispatch_lambda_args / dispatch_lambda_return_str
#                                  generated from NativeFunction / CppSignature
#                                  (deprecated PythonSignature is special)
#                                  arguments are represented by DispatchLambdaArgument
#
#          pybind11::gil_scoped_release no_gil;
#          return self.abs();
#                 ~~~~~~~~~~~  <--- cpp_dispatch_target / cpp_dispatch_exprs
#                                   generated from NativeFunction / CppSignature
#        };
#        return wrap(dispatch_abs(_r.tensor(0)));
#                                 ~~~~~~~~~~~~~
#                                  ^
#                                  +--- dispatch_lambda_exprs
#                                       binding PythonArgParserOutputExpr (python args)
#                                       and DispatchLambdaArgument (c++ args)
#
#      } else {
#        // aten::abs.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
#        ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#         ^
#         +--- NativeFunction schema, out-variant
#
#        auto dispatch_abs_out = [](Tensor out, const Tensor & self) -> Tensor {
#          pybind11::gil_scoped_release no_gil;
#          return at::abs_out(out, self);
#        };
#        return wrap(dispatch_abs_out(_r.tensor(1), _r.tensor(0)));
#      }
#
#
# [Notes] python interface codegen
# The python dataclasses below are used to generate both python binding code
# and pyi type hint signatures.
# In theory these two should look very similar, but there are a number of differences
# in how pyi signatures vs. python_arg_parser signatures are generated.
# These differences have been encapsulated in signature_str() vs. signature_str_pyi()
# to display the full signatures, and argument_str() vs argument_str_pyi() to display arguments.
# For example, only pyi signatures include return types.


@dataclass(frozen=True)
class PythonReturns:
    # Thin wrapper around the `Return` entries of a function schema;
    # signature_from_schema() fills it with `func.returns`.
    returns: tuple[Return, ...]


@dataclass(frozen=True)
class PythonArgument:
    # Argument name.
    name: str
    # Schema type of the argument.
    type: Type
    # Default value expression, or None when the argument has no default.
    default: str | None

    # Used to generate the default init expr for some PythonArgParser outputs, e.g.:
    #
    #   _r.layoutWithDefault(3, layout_from_backend(self.options().backend())))
    #                           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    #                            ^
    #                            +--- default_init str
    default_init: str | None

    # Render this argument as one formal of a PythonArgParser signature string.
    # The format has to stay consistent with torch/csrc/utils/python_arg_parser.h.
    def argument_str(self, *, method: bool = False, symint: bool = True) -> str:
        raw_type = argument_type_str(self.type, symint=symint)
        # The parser string carries neither const qualifiers nor references.
        type_str = raw_type.replace("const ", "").replace(" &", "")

        # s/self/input/ outside method bindings
        # [old codegen] TODO: remove this? doesn't rename in codegen, it's just
        # for the parse string
        if not method and self.name == "self" and type_str in ["Tensor", "Number"]:
            name = "input"
        else:
            name = self.name

        formal = f"{type_str} {name}"
        if self.default is None:
            return formal

        # Every C++ spelling of "no value" is written as None in the parse string.
        if self.default in ("nullptr", "::std::nullopt", "std::nullopt", "{}"):
            default = "None"
        else:
            default = self.default
        return f"{formal}={default}"

    # Render this argument as one formal of a .pyi type-hint signature.
    # `deprecated` selects the rules used for deprecated signatures: no
    # self -> input rename, no Optional[] wrapping of 'out', and no default
    # for an out argument whose default is "None".
    def argument_str_pyi(
        self, *, method: bool = False, deprecated: bool = False
    ) -> str:
        type_str = argument_type_str_pyi(self.type)

        name = self.name
        # s/self/input/ outside method bindings
        # [old codegen] TODO: remove this? doesn't rename in codegen, it's just
        # for the parse string
        if not (method or deprecated) and name == "self" and type_str == "Tensor":
            name = "input"

        # 'from' is a Python keyword, so it cannot be used as a parameter name.
        if name == "from":
            name = "from_"

        # pyi merges the _out and functional variants into the same signature,
        # with an optional out arg
        if not deprecated and name == "out" and type_str == "Tensor":
            type_str = f"Optional[{type_str}]"

        annotated = f"{name}: {type_str}"
        if self.default is None:
            return annotated

        # pyi deprecated signatures don't get defaults for their out arg
        if (
            deprecated
            and isinstance(self, PythonOutArgument)
            and self.default == "None"
        ):
            return annotated

        is_int_list = isinstance(self.type, ListType) and self.type.elem == BaseType(
            BaseTy.int
        )
        if is_int_list and self.default.startswith("{") and self.default.endswith("}"):
            # A C++ brace list such as '{0,1}' becomes the Python tuple '(0, 1)'.
            items = (item.strip() for item in self.default[1:-1].split(","))
            default = f"({', '.join(items)})"
        else:
            pyi_spellings = {
                "nullptr": "None",
                "::std::nullopt": "None",
                "std::nullopt": "None",
                "{}": "None",
                "c10::MemoryFormat::Contiguous": "contiguous_format",
                "QScheme::PER_TENSOR_AFFINE": "per_tensor_affine",
            }
            default = pyi_spellings.get(self.default, self.default)
        return f"{annotated} = {default}"


@dataclass(frozen=True)
class PythonOutArgument(PythonArgument):
    # The Python signature packs all output fields into a single 'out' argument.
    # On the C++ side it is first bound to a local 'out' variable:
    #   'auto out = _r.tensorlist_n<2>(2);',
    # and the scattered C++ output arguments are then bound to 'out[0]',
    # 'out[1]', and so on.
    # TODO: maybe don't need keep scattered out fields for python signature?
    outputs: tuple[PythonArgument, ...]

    # Build the packed out argument for `outputs`.
    # Returns None when there are no outputs. A single output keeps its own
    # name and type; several outputs become a fixed-size Tensor list named
    # 'out'. Raises RuntimeError if one of several outputs is not tensor-like.
    @staticmethod
    def from_outputs(outputs: tuple[PythonArgument, ...]) -> PythonOutArgument | None:
        if not outputs:
            return None

        count = len(outputs)
        out_type: Type
        if count == 1:
            only = outputs[0]
            out_name, out_type = only.name, only.type
        elif count > 1:
            if not all(a.type.is_tensor_like() for a in outputs):
                raise RuntimeError(f"Unsupported output type: {outputs}")
            out_name = "out"
            # TODO: shouldn't this be OptionalType[ListType[...]], since it defaults to None?
            out_type = ListType(BaseType(BaseTy.Tensor), count)
        else:
            raise AssertionError(r"Unexpected PythonOutArgument size")

        return PythonOutArgument(
            name=out_name,
            type=out_type,
            default="None",
            default_init=None,
            outputs=outputs,
        )


@dataclass(frozen=True)
class PythonSignature:
    # Base operator name, without inplace/outplace suffix.
    name: str

    # Positional arguments.
    # TODO: create a dedicated SelfArgument type for 'self'?
    input_args: tuple[PythonArgument, ...]

    # Keyword arguments excluding the 'out' argument and scattered kwargs belonging
    # to TensorOptions (dtype, layout, device, pin_memory, requires_grad, etc).
    input_kwargs: tuple[PythonArgument, ...]

    # The packed 'out' argument, or None when the signature has no outputs.
    output_args: PythonOutArgument | None

    # Return types, which are only used by pyi
    returns: PythonReturns

    # These are scattered kwargs arguments belonging to TensorOptions.
    # When binding to C++, they are packed into a TensorOptions object 'options'.
    # It's possible that the C++ signature doesn't take TensorOptions object (e.g.
    # for out variant), in which case they will be used as scattered fields without
    # being packed into 'options'.
    # TODO: maybe create a PythonTensorOptionsArgument?
    tensor_options_args: tuple[PythonArgument, ...]

    # method or function signature?
    method: bool

    @property
    def deprecated(self) -> bool:
        return False

    # All arguments in signature order: positional inputs, keyword inputs,
    # the packed out argument (unless skipped or absent), then the
    # TensorOptions kwargs (unless skipped).
    def arguments(
        self, *, skip_outputs: bool = False, skip_tensor_options: bool = False
    ) -> tuple[PythonArgument | PythonOutArgument, ...]:
        result: list[PythonArgument | PythonOutArgument] = []
        result.extend(self.input_args)
        result.extend(self.input_kwargs)
        if self.output_args is not None and not skip_outputs:
            result.append(self.output_args)
        if not skip_tensor_options:
            result.extend(self.tensor_options_args)
        return tuple(result)

    # Number of arguments when nothing is skipped.
    def arguments_count(self) -> int:
        return len(self.arguments())

    # Index of the out argument within arguments(): it comes right after the
    # positional and keyword inputs.
    def output_idx(self) -> int:
        return len(self.input_args) + len(self.input_kwargs)

    # [old codegen] Compute the Python function signature for argument parsing,
    # as specified in torch/csrc/utils/python_arg_parser.h.  WARNING:
    # this is NOT the same type signature as specified by PEP 484
    # as understood by mypy; our format was independently developed
    # and has some quirks to make it more suitable specifically
    # for error parsing.
    #
    # For a translation to mypy-valid type signatures, see
    # signature_str_pyi().
    def signature_str(self, *, skip_outputs: bool = False, symint: bool = True) -> str:
        args = self.arguments(skip_outputs=skip_outputs)
        schema_formals: list[str] = [
            a.argument_str(method=self.method, symint=symint) for a in args
        ]
        # Everything after the positional inputs is keyword-only.
        positional_argc = len(self.input_args)
        if len(schema_formals) > positional_argc:
            schema_formals.insert(positional_argc, "*")

        return f'{self.name}({", ".join(schema_formals)})'

    # Render the signature as a .pyi stub line: 'def name(...) -> ret: ...'.
    def signature_str_pyi(self, *, skip_outputs: bool = False) -> str:
        args = self.arguments(skip_outputs=skip_outputs)
        schema_formals: list[str] = [
            a.argument_str_pyi(method=self.method) for a in args
        ]
        positional_argc = len(self.input_args)
        if len(schema_formals) > positional_argc:
            schema_formals.insert(positional_argc, "*")

        # only pyi signatures include returns
        returns_str = returns_str_pyi(self)
        # pyi also includes self (with no typing/defaults) for methods
        if self.method:
            schema_formals.insert(0, "self")
        return f'def {self.name}({", ".join(schema_formals)}) -> {returns_str}: ...'

    # Render the '*args' flavour of the .pyi stub line, or return None when the
    # signature has no such flavour. A vararg version exists only when the sole
    # positional argument is a list of int / SymInt.
    def signature_str_pyi_vararg(self, *, skip_outputs: bool = False) -> str | None:
        # only pyi uses vararg signatures
        args = self.arguments(skip_outputs=skip_outputs)
        schema_formals: list[str] = [
            a.argument_str_pyi(method=self.method) for a in args
        ]

        # vararg variants are not generated for all signatures.
        # Exactly one positional input guarantees that args[0] exists and is
        # that input, whatever `skip_outputs` is. (Guarding on
        # arguments_count() instead would ignore `skip_outputs` and could
        # index into an empty tuple.)
        if len(self.input_args) != 1:
            return None
        vararg_type = args[0].type
        if not isinstance(vararg_type, ListType) or str(vararg_type.elem) not in [
            "int",
            "SymInt",
        ]:
            return None

        # Below are the major changes in vararg vs. regular pyi signatures
        # vararg signatures also omit the asterisk
        schema_formals[0] = (
            "*" + args[0].name + ": " + argument_type_str_pyi(vararg_type.elem)
        )

        returns_str = returns_str_pyi(self)
        # pyi also includes self (with no typing/defaults) for methods
        if self.method:
            schema_formals.insert(0, "self")
        return f'def {self.name}({", ".join(schema_formals)}) -> {returns_str}: ...'


# The deprecated python signature involves some special logic, so create a
# dedicated data model to store these extra properties.
@dataclass(frozen=True)
class PythonSignatureDeprecated(PythonSignature):
    # Schema for the deprecated function
    deprecated_schema: FunctionSchema

    # The deprecated signature might miss some arguments that the corresponding
    # C++ signature expects. We need store the constant default values to pass in.
    # For example:
    #   [deprecate signature]: addmm(Scalar beta, Tensor self, Tensor mat1, Tensor mat2)
    #   [func schema]: aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
    #   [func call]: self.addmm(mat1, mat2, beta, 1)
    # We store ['self', 'mat1', 'mat2', 'beta', '1'] in this case.
    deprecated_args_exprs: tuple[str, ...]

    @property
    def deprecated(self) -> bool:
        return True

    # Same parse string as the regular signature, tagged with '|deprecated'.
    def signature_str(self, *, skip_outputs: bool = False, symint: bool = True) -> str:
        regular = PythonSignature.signature_str(
            self, skip_outputs=skip_outputs, symint=symint
        )
        return regular + "|deprecated"

    # .pyi stub line for the deprecated signature. Arguments are rendered with
    # deprecated=True, and no leading 'self' is inserted here.
    def signature_str_pyi(self, *, skip_outputs: bool = False) -> str:
        formals: list[str] = []
        for a in self.arguments(skip_outputs=skip_outputs):
            formals.append(a.argument_str_pyi(method=self.method, deprecated=True))

        # Everything after the positional inputs is keyword-only.
        kwonly_start = len(self.input_args)
        if kwonly_start < len(formals):
            formals.insert(kwonly_start, "*")

        returns_str = returns_str_pyi(self)
        joined = ", ".join(formals)
        return f"def {self.name}({joined}) -> {returns_str}: ..."

    def signature_str_pyi_vararg(self, *, skip_outputs: bool = False) -> str | None:
        # the codegen doesn't include vararg variants for deprecated signatures
        return None


# This struct is used to hold the PythonSignature and its corresponding
# NativeFunction BEFORE grouping base and out-variant functions.
# Why not store NativeFunction in PythonSignature or construct PythonSignature
# from NativeFunction? Because they are not 1-1 mapped.
# One native function could have both deprecated and non-deprecated python
# signatures - NativeFunction doesn't contain information to construct the
# deprecated python signature.
# One python signature is used to handle both the base and the out-variant
# function - see 'PythonSignatureGroup'.
@dataclass(frozen=True)
class PythonSignatureNativeFunctionPair:
    # The Python signature of the pair.
    signature: PythonSignature
    # The NativeFunction this signature corresponds to.
    function: NativeFunction


# We merge pairs of functions with signatures that are equivalent mod
# output arguments, and use a single entry in the python_arg_parser sig
# list for both (output arguments become optional).
@dataclass(frozen=True)
class PythonSignatureGroup:
    # The signature used for Python argument parsing. The outplace signature
    # is preferred if exists, because it can be used to parse inputs for both
    # the out-place variant and the base version (with output omitted).
    signature: PythonSignature

    # The regular ATen declaration (e.g. conv2d)
    base: NativeFunction

    # The out variant (e.g. conv2d_out)
    outplace: NativeFunction | None

    # Group a functional pair with its (optional) out-variant pair.
    # Without an out variant the functional signature is used unchanged.
    # Otherwise the out variant's signature is rebuilt with the functional
    # variant's tensor_options_args.
    @classmethod
    def from_pairs(
        cls,
        functional: PythonSignatureNativeFunctionPair,
        out: PythonSignatureNativeFunctionPair | None,
    ) -> PythonSignatureGroup:
        if out is not None:
            # prefer the signature with optional out=... arguments because it's the
            # superset that can be used to parse input for both base and outplace.
            fields = dict(out.signature.__dict__)

            # Out overloads in C++ don't have TensorOptions arguments,
            # so take these from the functional variant
            fields["tensor_options_args"] = functional.signature.tensor_options_args

            # Rebuild with the same concrete class as the out signature.
            merged = type(out.signature)(**fields)
            return PythonSignatureGroup(
                signature=merged,
                base=functional.function,
                outplace=out.function,
            )

        return PythonSignatureGroup(
            signature=functional.signature,
            base=functional.function,
            outplace=None,
        )


# C++ function dispatch is wrapped in a lambda function. The lambda function
# has almost the same signature as the C++ function, only with some small
# variants - see details below.
# This data model is used to represent arguments of the lambda function
# signature.
@dataclass(frozen=True)
class DispatchLambdaArgument:
    # Name of the lambda argument.
    name: str
    # presumably the C++ type of the argument spelled as a string (the notes
    # at the top of this file show lambda formals such as 'const Tensor & self')
    # -- TODO confirm against the code that builds these objects
    type_str: str
    # NOTE(review): looks like this marks arguments that are out arguments of
    # the dispatched function; confirm against the producer of this class.
    is_out_arg: bool


# To pass PyObjects arguments to C++ function (via the lambda wrapper),
# we need first convert PyObjects into simple C++ objects. This work
# is done by PythonArgParser.
# This data model is used to represent the output of PythonArgParser.
# It has 1-1 mapping with PythonArgument in PythonSignature.
@dataclass(frozen=True)
class PythonArgParserOutputExpr:
    # argument name
    name: str

    # RHS expression to reference PythonArgParser output.
    expr: str

    # Position of the argument in the parser output. Kept separately because
    # some special cases need a different expr, e.g.:
    # '_r.isNone(1)' instead of '_r.tensor(1)'.
    index: int

    # The python argument it maps to.
    argument: PythonArgument

    # Expression that tests whether the parsed argument at `index` is None.
    @property
    def is_none_expr(self) -> str:
        return "_r.isNone({})".format(self.index)


# To pass PythonArgParser output to the lambda wrapper, we need bind
# PythonArgParserOutputExpr to DispatchLambdaArgument.
# They are not always 1-1 mapped, e.g. scattered TensorOptions fields
# need be packed into a TensorOptions object, which is the argument
# that the lambda function wrapper takes.
@dataclass(frozen=True)
class DispatchLambdaArgumentExprs:
    # Holds the binding expressions for the lambda arguments together with
    # the local init statements those expressions may refer to.

    # The exprs that provide the binding for lambda arguments, e.g.:
    #
    #   'self' -> '_r.tensor(0)'
    #   'min' -> 'out[0]' / 'min_indices' -> 'out[1]'
    #   'options' -> 'options'
    #
    # It has 1-1 mapping with DispatchLambdaArgument.
    exprs: Sequence[str]

    # Special local inits, which might introduce new variables that
    # the 'exprs' above reference, e.g.:
    #
    #   'auto out = _r.tensorlist_n<2>(2);'
    #
    inits: Sequence[str]


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
#                          Helper Functions
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


def _cpp_signature(f: NativeFunction, *, method: bool = False) -> CppSignature:
    """Return the `signature` of the CppSignatureGroup built from ``f``."""
    group = CppSignatureGroup.from_native_function(f, method=method)
    return group.signature


def has_tensor_options(f: NativeFunction) -> bool:
    """Return True when the schema of ``f`` has a ``tensor_options`` argument group."""
    tensor_options = f.func.arguments.tensor_options
    return tensor_options is not None


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
#                          Python Signature
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


# 'simple_type' was introduced by the old codegen, which is slightly
# different from the python schema type, e.g.: doesn't have '?' suffix
# for optional Tensor/TensorList; doesn't have '[size]' suffix for list type.
def argument_type_str(
    t: Type, *, simple_type: bool = False, symint: bool = True
) -> str:
    """Return the python schema type string for schema type ``t``.

    With ``simple_type=True`` the '[size]' suffix of list types is dropped and
    the optional-Tensor list is spelled without 'const ... &'.  With
    ``symint=False`` lists of SymInt are spelled as plain 'IntArrayRef'.

    Raises RuntimeError for a type that is not recognized.
    """
    if isinstance(t, BaseType):
        # Base types whose parser spelling differs from the schema name.
        renamed = {
            BaseTy.Tensor: "Tensor",
            BaseTy.int: "int64_t",
            BaseTy.float: "double",
            BaseTy.str: "c10::string_view",
        }
        spelling = renamed.get(t.name)
        if spelling is not None:
            return spelling
        if t.name in [
            BaseTy.bool,
            BaseTy.QScheme,
            BaseTy.Scalar,
            BaseTy.ScalarType,
            BaseTy.Generator,
            BaseTy.Storage,
            BaseTy.Layout,
            BaseTy.Device,
            BaseTy.DeviceIndex,
            BaseTy.MemoryFormat,
            BaseTy.Dimname,
            BaseTy.Stream,
            BaseTy.ConstQuantizerPtr,
            BaseTy.SymInt,
        ]:
            # These python schema type names line up with their function schema names
            return t.name.name
        # Any other base type falls through to the error at the bottom.

    elif isinstance(t, OptionalType):
        if str(t.elem) == "Tensor":
            # Is it desired to keep '?' for simple_type with new style dispatcher?
            inner = "Tensor"
        else:
            inner = argument_type_str(t.elem, simple_type=simple_type, symint=symint)
        return inner + "?"

    elif isinstance(t, ListType):
        elem_name = str(t.elem)
        size = None if simple_type else t.size

        if elem_name == "bool":
            # bool lists are always fixed-size arrays, even for simple_type.
            assert t.size is not None
            return f"::std::array<bool,{t.size}>"
        if elem_name == "Tensor?":
            if simple_type:
                return "c10::List<::std::optional<Tensor>>"
            return "const c10::List<::std::optional<Tensor>> &"

        # Element types that have a dedicated list spelling with an optional
        # '[size]' suffix.
        list_spellings = {
            "int": "IntArrayRef",
            "SymInt": "SymIntArrayRef" if symint else "IntArrayRef",
            "Tensor": "TensorList",
            "Scalar": "ScalarList",
            "Dimname": "DimnameList",
        }
        list_name = list_spellings.get(elem_name)
        if list_name is not None:
            return list_name if size is None else f"{list_name}[{size}]"

        elem = argument_type_str(t.elem, simple_type=simple_type, symint=symint)
        return f"ArrayRef<{elem}>"

    raise RuntimeError(f"unrecognized type {repr(t)}")


def argument_type_size(t: Type) -> int | None:
    """Return the declared size of a list-like type, or None.

    None is returned when ``t`` is not list-like and also when its elements
    are ``bool``; a list-like type without a declared size yields its
    ``size`` attribute unchanged.
    """
    list_type = t.is_list_like()
    if list_type is None or str(list_type.elem) == "bool":
        return None
    return list_type.size


def argument(a: Argument) -> PythonArgument:
    """Convert a schema ``Argument`` into a ``PythonArgument``.

    The name and type are copied over.  A schema default is first rendered
    with ``cpp.default_expr`` (``symint=False``) and then passed through
    ``pythonify_default``; an argument without a default stays without one.
    ``default_init`` is always None.
    """
    # TODO: directly translate a.default to python default
    if a.default is None:
        default = None
    else:
        cpp_default = cpp.default_expr(a.default, a.type, symint=False)
        default = str(pythonify_default(cpp_default))

    return PythonArgument(
        name=a.name,
        type=a.type,
        default=default,
        default_init=None,
    )


# Generates a PythonSignature that can be used for either .pyi or PythonArgParser codegen
def signature(
    f: NativeFunction, *, method: bool = False, pyi: bool = False
) -> PythonSignature:
    """Build the PythonSignature of ``f`` from its schema and category override."""
    schema = f.func
    override = f.category_override
    return signature_from_schema(
        schema, category_override=override, method=method, pyi=pyi
    )


def signature_from_schema(
    func: FunctionSchema,
    *,
    category_override: str | None,
    method: bool = False,
    pyi: bool = False,
) -> PythonSignature:
    """Build a PythonSignature from a function schema.

    ``category_override`` can force the function to be treated as a
    'factory', 'new' / 'like' or 'dummy' function, which decides whether the
    scattered TensorOptions kwargs are added.  ``method`` drops the self
    argument from the inputs.  ``pyi`` is accepted but not referenced in the
    body.

    Raises ValueError if the schema itself declares a 'requires_grad' argument.
    """
    schema_args = func.arguments

    # Skip SelfArgument if this is method.
    self_args = (
        [schema_args.self_arg.argument]
        if not method and schema_args.self_arg is not None
        else []
    )
    # Skip TensorOptionsArguments. Python side TensorOptions
    # arguments are created based on different rules - see below.
    args: list[Argument] = [
        *schema_args.pre_self_positional,
        *self_args,
        *schema_args.post_self_positional,
        *schema_args.pre_tensor_options_kwarg_only,
        *schema_args.post_tensor_options_kwarg_only,
        *schema_args.out,
    ]

    positional_names = {a.name for a in schema_args.flat_positional}
    kwarg_only_names = {a.name for a in schema_args.flat_kwarg_only}
    out_names = {a.name for a in schema_args.out}

    input_args = tuple(argument(a) for a in args if a.name in positional_names)
    input_kwargs = tuple(argument(a) for a in args if a.name in kwarg_only_names)
    outputs = tuple(argument(a) for a in args if a.name in out_names)

    # Reintroduce the scattered fields of TensorOptions for Python.
    # Compared to the cpp counterpart, the python arguments have new property
    # (default_init) and a new argument 'requires_grad', which require some
    # special handlings.
    # [old codegen] TODO: because these aren't guaranteed to be 100% faithful
    # to the original versions in the yaml, this recreation is a potential
    # source of drift between eager and JIT. Pull this logic out to a shared place.

    has_tensor_input_arg = any(
        a.type.is_tensor_like() for a in schema_args.flat_non_out
    )
    for schema_arg in func.schema_order_arguments():
        if schema_arg.name == "requires_grad":
            raise ValueError(
                "argument named requires_grad is reserved, should not explicitly add it in the schema"
            )

    # [old codegen] this probably won't work if one of the returns is not a tensor,
    # but it will produce a compile-time error that is obvious.
    has_tensor_return = any(r.type.is_tensor_like() for r in func.returns)

    cpp_name: str = cpp.name(func)
    is_factory_function = category_override == "factory" or (
        has_tensor_return and not has_tensor_input_arg
    )
    is_like_or_new_function = (
        category_override in ("new", "like")
        or cpp_name.startswith("new_")
        or cpp_name.endswith("_like")
    )
    is_dummy_function = category_override == "dummy"

    tensor_options_args: list[PythonArgument] = []
    if (is_factory_function or is_like_or_new_function) and not is_dummy_function:

        def topt_default_init(field: str) -> str | None:
            # C++ default expr of one TensorOptions field, or None when the
            # schema has no TensorOptions or the field has no real default.
            topt_args = schema_args.tensor_options
            if topt_args is None:
                return None
            topt_arg = getattr(topt_args, field)
            if topt_arg.default is None or topt_arg.default == "None":
                return None
            return cpp.default_expr(topt_arg.default, topt_arg.type, symint=False)

        # new_* / *_like functions never get a default_init; for the others
        # it is read from the schema, and device additionally falls back to
        # the global default device.
        dtype_init: str | None = None
        layout_init: str | None = None
        device_init: str | None = None
        if not is_like_or_new_function:
            dtype_init = topt_default_init("dtype")
            layout_init = topt_default_init("layout")
            device_init = (
                topt_default_init("device") or "torch::tensors::get_default_device()"
            )

        tensor_options_args = [
            PythonArgument(
                name="dtype",
                type=OptionalType(BaseType(BaseTy.ScalarType)),
                default="None",
                default_init=dtype_init,
            ),
            PythonArgument(
                name="layout",
                type=OptionalType(BaseType(BaseTy.Layout)),
                default="None",
                default_init=layout_init,
            ),
            PythonArgument(
                name="device",
                type=OptionalType(BaseType(BaseTy.Device)),
                default="None",
                default_init=device_init,
            ),
            PythonArgument(
                name="pin_memory",
                type=OptionalType(BaseType(BaseTy.bool)),
                default="False",
                default_init=None,
            ),
            PythonArgument(
                name="requires_grad",
                type=OptionalType(BaseType(BaseTy.bool)),
                default="False",
                default_init=None,
            ),
        ]

    return PythonSignature(
        name=str(func.name.name),
        input_args=input_args,
        input_kwargs=input_kwargs,
        output_args=PythonOutArgument.from_outputs(outputs),
        tensor_options_args=tuple(tensor_options_args),
        returns=PythonReturns(returns=func.returns),
        method=method,
    )


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
#                          Python Interface
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


def structseq_fieldnames(returns: tuple[Return, ...]) -> list[str]:
    """Return the structseq field names for a group of returns.

    An empty list is returned when there is at most one return or when none
    of the returns is named; otherwise every return must carry a name.

    Raises:
        ValueError: if only some of the returns are named.
    """
    if len(returns) <= 1:
        return []

    names = [ret.name for ret in returns]
    is_unnamed = [field is None for field in names]
    if all(is_unnamed):
        return []

    if any(is_unnamed):
        # Mixing named and unnamed fields would require
        # `PyStructSequence_UnnamedField`, which the linker fails to resolve
        # when building on Windows:
        #
        # python_nn_functions.cpp.obj : error LNK2001: unresolved external symbol
        # PyStructSequence_UnnamedField
        #
        # So unnamed fields are not supported in structseq for now: name all
        # of the fields, or none of them.
        raise ValueError("Unnamed field is not supported by codegen")

    return [str(field) for field in names]


def argument_type_str_pyi(t: Type) -> str:
    """Return the `.pyi` type annotation string for an argument of type `t`.

    An `OptionalType` is unwrapped first and the result for its element type
    is wrapped in `Optional[...]`.

    Raises:
        RuntimeError: if `t` (or its unwrapped element) is not a type this
            function knows how to render, including base types that have no
            entry in the mapping below.
    """
    add_optional = False
    if isinstance(t, OptionalType):
        t = t.elem
        add_optional = True

    if isinstance(t, BaseType):
        if t.name in [BaseTy.int, BaseTy.DeviceIndex]:
            ret = "_int"
        elif t.name == BaseTy.SymInt:
            ret = "Union[_int, SymInt]"
        elif t.name == BaseTy.float:
            ret = "_float"
        elif t.name == BaseTy.str:
            ret = "str"
        elif t.name == BaseTy.Scalar:
            ret = "Union[Number, _complex]"
        elif t.name == BaseTy.ScalarType:
            ret = "_dtype"
        elif t.name == BaseTy.bool:
            ret = "_bool"
        elif t.name == BaseTy.QScheme:
            ret = "_qscheme"
        elif t.name == BaseTy.Layout:
            ret = "_layout"
        elif t.name == BaseTy.Device:
            ret = "Optional[DeviceLikeType]"
        elif t.name == BaseTy.MemoryFormat:
            ret = "memory_format"
        elif t.name == BaseTy.Dimname:
            ret = "Union[str, ellipsis, None]"
        elif t.name == BaseTy.Storage:
            ret = "Union[Storage, UntypedStorage]"
        elif t.name in [BaseTy.Tensor, BaseTy.Generator, BaseTy.Stream]:
            # These python schema type names line up with their function schema names
            ret = t.name.name
        else:
            # Fail with the same explicit error as the fallthrough below
            # instead of leaving `ret` unbound.
            raise RuntimeError(f"unrecognized type {repr(t)}")

    elif isinstance(t, ListType):
        if str(t.elem) == "int":
            # A sized int list also accepts a single int.
            ret = "Union[_int, _size]" if t.size is not None else "_size"
        elif t.is_tensor_like():
            # TODO: this doesn't seem right...
            # Tensor?[] currently translates to Optional[Union[Tuple[Tensor, ...], List[Tensor]]]
            # It should probably translate to   Union[Tuple[Optional[Tensor], ...], List[Optional[Tensor]]]
            if isinstance(t.elem, OptionalType):
                add_optional = True
            ret = (
                "Union[Tensor, Tuple[Tensor, ...], List[Tensor]]"
                if t.size is not None
                else "Union[Tuple[Tensor, ...], List[Tensor]]"
            )
        elif str(t.elem) == "float":
            ret = "Sequence[_float]"
        elif str(t.elem) == "SymInt" and t.size is not None:
            # A sized SymInt list also accepts a single element.
            elem = argument_type_str_pyi(t.elem)
            ret = f"Union[{elem}, Sequence[{elem}]]"
        else:
            elem = argument_type_str_pyi(t.elem)
            ret = f"Sequence[{elem}]"

    else:
        raise RuntimeError(f"unrecognized type {repr(t)}")

    if add_optional:
        ret = "Optional[" + ret + "]"

    return ret


def return_type_str_pyi(t: Type) -> str:
    """Return the `.pyi` type annotation string for a return value of type `t`.

    Where arguments are open to accepting a Union, return types should be
    concrete. Optional and list types are rendered recursively (lists become
    `Tuple[<elem>, ...]`); `Device` and `Dimname` get return-specific
    spellings; everything else is delegated to `argument_type_str_pyi`.
    """
    if isinstance(t, OptionalType):
        inner = return_type_str_pyi(t.elem)
        return f"Optional[{inner}]"

    if isinstance(t, BaseType):
        if t.name == BaseTy.Device:
            return "_device"
        if t.name == BaseTy.Dimname:
            # Must be returned here: falling through would render the
            # argument-style union from `argument_type_str_pyi` instead.
            return "Optional[str]"
        return argument_type_str_pyi(t)

    if isinstance(t, ListType):
        inner = return_type_str_pyi(t.elem)
        return f"Tuple[{inner}, ...]"

    return argument_type_str_pyi(t)


def returns_structseq_pyi(signature: PythonSignature) -> tuple[str, str] | None:
    """Build the `.pyi` class definition for a signature's structseq return.

    Returns:
        A `(class_name, class_definition_source)` pair when the signature's
        returns produce structseq field names, otherwise `None`. The class is
        named after `signature.name`.
    """
    python_returns = [return_type_str_pyi(r.type) for r in signature.returns.returns]
    structseq_name = signature.name
    field_names = structseq_fieldnames(signature.returns.returns)
    if not field_names:
        return None

    # These types are structseq objects which act like NamedTuples, but
    # the constructor acts like the constructor of tuple. Using typing.NamedTuple
    # does not allow us to override __init__.
    seq_type = f"Tuple[{', '.join(python_returns)}]"
    structseq_def_lines = [
        f"class {structseq_name}({seq_type}):",
    ]
    # One read-only property per named field.
    for name, typ in zip(field_names, python_returns):
        structseq_def_lines.extend(
            [
                "    @property",
                f"    def {name}(self) -> {typ}: ...",
            ]
        )
    structseq_def_lines.extend(
        [
            f"    def __new__(cls, sequence: {seq_type}): ...",
            f"    n_fields: _int = {len(field_names)}",
            # The structseq attribute is spelled `n_sequence_fields`.
            f"    n_sequence_fields: _int = {len(field_names)}",
            "    n_unnamed_fields: _int = 0",
            "    def __init_subclass__(cls) -> NoReturn: ...  # prohibit subclassing",
            "",  # add an extra newline
        ]
    )
    structseq_def = "\n".join(structseq_def_lines)
    # Example:
    # structseq_def = (
    #     "class max(Tuple[Tensor, Tensor]):\n"
    #     "    @property\n"
    #     "    def values(self) -> Tensor: ...\n"
    #     "    @property\n"
    #     "    def indices(self) -> Tensor: ...\n"
    #     "    def __new__(cls, sequence: Tuple[Tensor, Tensor]): ...\n"
    #     "    n_fields: _int = 2\n"
    #     "    n_sequence_fields: _int = 2\n"
    #     "    n_unnamed_fields: _int = 0\n"
    #     "    def __init_subclass__(cls) -> NoReturn: ...  # prohibit subclassing\n"
    # )
    return structseq_name, structseq_def


def returns_str_pyi(signature: PythonSignature) -> str:
    """Return the `.pyi` return annotation for `signature`.

    Signatures with structseq field names refer to the matching class under
    `torch.return_types`. Otherwise the result is `None` for no returns, the
    bare type for a single return, and a `Tuple[...]` for several.
    """
    if structseq_fieldnames(signature.returns.returns):
        return f"torch.return_types.{signature.name}"

    rendered = [return_type_str_pyi(ret.type) for ret in signature.returns.returns]
    if not rendered:
        return "None"
    if len(rendered) == 1:
        return rendered[0]
    return "Tuple[" + ", ".join(rendered) + "]"


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
#                        C++ Function Dispatch
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
# This section provides APIs to generate the code that does C++ function
# dispatch. The C++ function call is wrapped by a lambda function.
# For example:
#
#    // aten::selu_(Tensor(a!) self) -> Tensor(a!)
#    auto dispatch_selu_ = [](Tensor self) -> Tensor {
#      pybind11::gil_scoped_release no_gil;
#      return at::selu_(self);
#    };
#
# The lambda function's signature follows the C++ signature in common
# cases, e.g.:
#
#   // aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
#   [](const Tensor & self, const Tensor & other, Scalar alpha) -> Tensor
#
# For out variant the 'out' argument's type is changed from 'Tensor &'
# to 'Tensor'. It's because when calling the lambda it passes in the
# PythonArgParser output '_r.tensor(3)', which is stack allocated object
# and needs to pass by value. Also see comments in 'dispatch_lambda_return_str()'.
#
#   // aten::add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
#   [](Tensor out, const Tensor & self, const Tensor & other, Scalar alpha) -> Tensor
#
# For multi-output case it can keep using reference type because the
# PythonArgParser output has been unpacked to local variables, e.g.:
#
#   // aten::max.names_dim_max(Tensor self, Dimname dim, bool keepdim=False, *,
#   //     Tensor(a!) max, Tensor(b!) max_values) -> (Tensor(a!) values, Tensor(b!) indices)
#   [](Tensor & max, Tensor & max_values, const Tensor & self, Dimname dim, bool keepdim) -> std::tuple<Tensor,Tensor>
#
# For deprecated python signature, it should follow deprecated python arg order.
# TODO: This is to keep same byte-for-byte result as the old codegen - maybe unnecessary?


def dispatch_lambda_args(
    ps: PythonSignature, f: NativeFunction, symint: bool = True
) -> tuple[DispatchLambdaArgument, ...]:
    """Compute the argument list of the C++ dispatch lambda for `f`.

    The arguments follow the (non-faithful, non-method) C++ signature of the
    schema, so 'self' is always included. A deprecated python signature uses
    its own deprecated schema instead of `f.func`.
    """
    schema = (
        ps.deprecated_schema if isinstance(ps, PythonSignatureDeprecated) else f.func
    )

    bindings = cpp.arguments(
        arguments=schema.arguments,
        faithful=False,
        symint=symint,
        method=False,
        cpp_no_default_args=f.cpp_no_default_args,
    )
    out_names: set[str] = {out.name for out in schema.arguments.out}
    # With several outputs, the out arguments are unpacked into locals by the
    # caller and may therefore stay references.
    has_scattered_outputs = len(out_names) > 1

    def to_lambda_arg(binding: Binding) -> DispatchLambdaArgument:
        cpp_type = binding.type
        is_out = binding.name in out_names
        if ps.method and binding.name == "self":
            # For method's 'self', we can use 'const Tensor &' and simply
            # ignore mutability!
            lambda_type = "const at::Tensor &"
        elif has_scattered_outputs and is_out:
            lambda_type = cpp_type
        elif cpp_type == "at::Tensor &":
            # Take mutable tensors by value to prevent dangling refs to temps;
            # see the notes above and in 'dispatch_lambda_return_str()'.
            # TODO: avoid this special handling?
            lambda_type = "at::Tensor"
        else:
            lambda_type = cpp_type
        return DispatchLambdaArgument(
            name=binding.name,
            type_str=lambda_type,
            is_out_arg=is_out,
        )

    return tuple(to_lambda_arg(binding) for binding in bindings)


# Whitelist of C++ return type strings that a dispatch lambda may return;
# 'dispatch_lambda_return_str()' raises a RuntimeError for any return type
# string that is not a member of this set.
#
# [old codegen] XXX: if you got here because of an assertion failure, it doesn't mean
# it's enough to just extend the list here. Before you do this, make sure
# to add an appropriate wrap() overload in torch/csrc/autograd/utils/wrap_outputs.h.
SUPPORTED_RETURN_TYPES = {
    "at::Tensor",
    "::std::tuple<at::Tensor,at::Tensor>",
    "::std::tuple<at::Tensor,at::Tensor,at::Tensor>",
    "::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor>",
    "::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor>",
    "::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor>",
    "::std::tuple<at::Tensor,at::Tensor,at::Tensor,int64_t>",
    "::std::tuple<at::Tensor,at::Tensor,double,int64_t>",
    "::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,int64_t>",
    "::std::tuple<at::Tensor,at::Tensor,double,at::Tensor,int64_t>",
    "::std::tuple<double,int64_t>",
    "::std::tuple<at::Tensor,::std::vector<at::Tensor>>",
    "::std::vector<at::Tensor>",
    # Needed for flash attention forw/backward
    "::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,c10::SymInt,c10::SymInt,at::Tensor,at::Tensor,at::Tensor>",
    "at::Scalar",
    "bool",
    "int64_t",
    "void*",
    "void",
    "at::QScheme",
    "double",
    "at::IntArrayRef",
    "at::ScalarType",
    "at::Stream",
}


def dispatch_lambda_return_str(f: NativeFunction) -> str:
    """Return the C++ return type string of the dispatch lambda for `f`.

    Raises:
        RuntimeError: if the resulting type is not in SUPPORTED_RETURN_TYPES.
    """
    # [old codegen] The annotations are dropped from the returns (giving e.g.
    # 'Tensor' rather than 'Tensor &') because the dispatch lambdas take
    # mutable arguments *by value*, not by reference. Returning a reference to
    # such an argument would leave a pointer to a dangling stack entry.
    #
    # You want:
    #
    #   auto dispatch_selu_ = [](Tensor self) -> Tensor { ...; return at::selu_(self); };
    #                                            ^^^^^^
    #
    # *not*
    #
    #   auto dispatch_selu_ = [](Tensor self) -> Tensor& { ...; return at::selu_(self); };
    #                                            ^^^^^^^
    #
    # (NB: dispatch_selu_ can't take Tensor& either, because the enclosing
    # codegen looks like dispatch_selu_(_r.tensor(0)), and you can't take a
    # mutable reference to a temporary. Maybe it could be assigned to a
    # variable first.)
    unannotated_returns = tuple(
        Return(ret.name, ret.type, None) for ret in f.func.returns
    )
    cpp_return_type = cpp.returns_type(unannotated_returns, symint=True).cpp_type()
    if cpp_return_type in SUPPORTED_RETURN_TYPES:
        return cpp_return_type
    raise RuntimeError(f"{f.func.name} returns unsupported type {cpp_return_type}")


def cpp_dispatch_target(f: NativeFunction) -> str:
    """Return the C++ callee expression used to dispatch `f`.

    The method variant takes precedence and is called as `self.<name>`. A
    function variant is called as `torch::<name>` when it has tensor options
    or its base name ends in `_like`, and as `at::<name>` otherwise.

    Raises:
        RuntimeError: if `f` has neither a method nor a function variant.
    """
    callee = cpp.name(f.func, symint_overload=f.func.has_symint())

    if Variant.method in f.variants:
        return f"self.{callee}"

    if Variant.function not in f.variants:
        raise RuntimeError(
            f"could not dispatch, neither function nor method: {f.func}"
        )

    use_torch_namespace = has_tensor_options(f) or f.func.name.name.base.endswith(
        "_like"
    )
    namespace = "torch" if use_torch_namespace else "at"
    return f"{namespace}::{callee}"


def cpp_dispatch_exprs(
    f: NativeFunction,
    *,
    python_signature: PythonSignature | None = None,
) -> tuple[str, ...]:
    """Return the argument expressions passed to the C++ dispatch target.

    By default these are the argument names of the C++ signature. A deprecated
    python signature supplies its own expressions (which may include
    constants), with "out" dropped unless `f` is an out function. For method
    variants "self" is removed, as it is the receiver of the call.
    """
    cpp_args: Sequence[Binding] = _cpp_signature(f, method=False).arguments()

    if isinstance(python_signature, PythonSignatureDeprecated):
        # For deprecated python signature we may need fill in some constants.
        candidates = [
            expr
            for expr in python_signature.deprecated_args_exprs
            if expr != "out" or f.func.is_out_fn()
        ]
    else:
        # By default the exprs are consistent with the C++ signature.
        candidates = [binding.name for binding in cpp_args]

    if Variant.method in f.variants:
        candidates = [expr for expr in candidates if expr != "self"]

    return tuple(candidates)


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
#                     Python / C++ Args Binding
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


# We explicitly enumerate the PythonArgParser unpacking methods for all
# supported types. This might be more verbose than necessary, partially
# because of the irregularity of unpacking method naming, partially
# because we want to mimic the old codegen behavior - to reject
# unexpected and/or unsupported cases which the old codegen rejects.
# For certain cases it is intentionally more restrictive than necessary,
# e.g.: it doesn't accept doublelist with definite size.
def arg_parser_unpack_method(
    t: Type, default: str | None, default_init: str | None, *, symint: bool = True
) -> str:
    """Return the name of the PythonArgParser method that unpacks type `t`.

    Args:
        t: schema type of the python argument.
        default: the argument's default value string, if any.
        default_init: C++ expression used as the default; when given, the
            "...WithDefault" flavor of the unpack method is selected.
        symint: whether SymInt types unpack as SymInt rather than as int64.

    Raises:
        RuntimeError: if `default_init` is given for a type that cannot be
            unpacked with a default, or if `t` is not supported at all.
    """
    with_default = default_init is not None
    types_allowing_default_init = (
        "ScalarType?",
        "ScalarType",
        "Device",
        "Device?",
        "Layout",
        "Layout?",
        "bool",
        "bool?",
    )
    if with_default and str(t) not in types_allowing_default_init:
        raise RuntimeError(f"type '{t}' does not supported unpacking with default")

    if isinstance(t, BaseType):
        base = t.name
        if base in (
            BaseTy.Tensor,
            BaseTy.Stream,
            BaseTy.Storage,
            BaseTy.Scalar,
            BaseTy.Dimname,
        ):
            # These unpack methods line up with their schema names
            return base.name.lower()

        # Types whose unpack method has a "WithDefault" counterpart.
        defaultable_methods = {
            BaseTy.ScalarType: "scalartype",
            BaseTy.Device: "device",
            BaseTy.bool: "toBool",
            BaseTy.Layout: "layout",
        }
        if base in defaultable_methods:
            stem = defaultable_methods[base]
            return (stem + "WithDefault") if with_default else stem

        if base == BaseTy.SymInt:
            return "toSymInt" if symint else "toInt64"

        fixed_methods = {
            BaseTy.DeviceIndex: "toInt64",
            BaseTy.int: "toInt64",
            BaseTy.float: "toDouble",
            BaseTy.str: "stringView",
            BaseTy.MemoryFormat: "memoryformat",
        }
        if base in fixed_methods:
            return fixed_methods[base]

    elif isinstance(t, OptionalType):
        elem_name = str(t.elem)
        dedicated_methods = {
            "Tensor": "optionalTensor",
            "Generator": "generator",
            "Dimname[]": "toDimnameListOptional",
        }
        if elem_name in dedicated_methods:
            return dedicated_methods[elem_name]

        defaults_to_none = default in (
            None,
            "None",
            "::std::nullopt",
            "std::nullopt",
        )
        if not with_default and defaults_to_none:
            # If default is None: append 'Optional' to elem's unpacking method
            return (
                arg_parser_unpack_method(t.elem, None, None, symint=symint) + "Optional"
            )
        # Otherwise, load as underlying type with default
        return arg_parser_unpack_method(t.elem, default, default_init, symint=symint)

    elif isinstance(t, ListType):
        elem_name = str(t.elem)
        if elem_name == "Tensor":
            # accept and use definite size
            return "tensorlist" if t.size is None else f"tensorlist_n<{t.size}>"
        if elem_name == "SymInt":
            # accept definite size
            return "symintlist" if symint else "intlist"

        list_methods = {
            "Tensor?": "list_of_optional_tensors",
            # accept definite size
            "Dimname": "dimnamelist",
            # accept definite size
            "int": "intlist",
            "float": "doublelist",
            "Scalar": "scalarlist",
        }
        if elem_name in list_methods:
            return list_methods[elem_name]

    raise RuntimeError(f"type '{t}' is not supported by PythonArgParser")


# Build the RHS expression that reads a python argument out of the
# PythonArgParser output.
# e.g. for arg name 'foo', arg type 'bool', arg_index = 2, the expression is
# '_r.toBool(2)'. When the argument has a 'default_init', it is passed as a
# second parameter to the unpack method.
def arg_parser_output_expr(
    arg_index: int, a: PythonArgument, *, symint: bool = True
) -> PythonArgParserOutputExpr:
    method = arg_parser_unpack_method(
        t=a.type, default=a.default, default_init=a.default_init, symint=symint
    )
    if a.default_init is None:
        call_args = f"{arg_index}"
    else:
        call_args = f"{arg_index}, {a.default_init}"

    return PythonArgParserOutputExpr(
        name=a.name,
        expr=f"_r.{method}({call_args})",
        index=arg_index,
        argument=a,
    )


# Returns a map from arg_name to its PythonArgParserOutputExpr, where each
# argument's parser index is its position in 'ps.arguments()'.
def arg_parser_output_exprs(
    ps: PythonSignature, f: NativeFunction, *, symint: bool = True
) -> dict[str, PythonArgParserOutputExpr]:
    outputs: dict[str, PythonArgParserOutputExpr] = {}
    for position, py_arg in enumerate(ps.arguments()):
        parsed = arg_parser_output_expr(position, py_arg, symint=symint)
        outputs[parsed.name] = parsed
    return outputs


# argument name to type for scattered tensor options fields
# 'dispatch_lambda_exprs()' checks the python signature's tensor options args
# against this table: every arg must be listed here with exactly this type
# string, and every field listed here must be present.
TENSOR_OPTIONS_FIELDS = {
    "dtype": "ScalarType?",
    "device": "Device?",
    "layout": "Layout?",
    "pin_memory": "bool?",
    "requires_grad": "bool?",
}


# bind arg parser outputs (python args) with dispatch lambda arguments (c++ args).
def dispatch_lambda_exprs(
    ps: PythonSignature, f: NativeFunction, *, symint: bool = True
) -> DispatchLambdaArgumentExprs:
    """Produce the C++ init statements and call expressions for the dispatch lambda.

    Args:
        ps: python signature whose arguments are read from the PythonArgParser.
        f: native function being bound.
        symint: forwarded to the arg parser / lambda argument helpers.

    Returns:
        A DispatchLambdaArgumentExprs whose `exprs` holds one expression per
        dispatch lambda argument (in lambda argument order) and whose `inits`
        holds the C++ statements that must be emitted before the call.

    Raises:
        RuntimeError: if the tensor options arguments are inconsistent with
            the function (see the checks in steps 2 and 3 below).
    """
    # This method is to bind 'arg_parser_outputs' and 'lambda_args' by producing
    # 'inits' and 'lambda_args_exprs' for each lambda argument using arg parser
    # outputs.
    arg_parser_outputs = arg_parser_output_exprs(ps, f, symint=symint)
    lambda_args = dispatch_lambda_args(ps, f, symint=symint)
    inits: list[str] = []
    lambda_args_exprs: dict[str, str] = {}

    has_toptions = has_tensor_options(f)

    # 1. special inits/unpacking to provide binding exprs for lambda arguments.
    for a in ps.arguments(skip_tensor_options=True):
        name = a.name
        arg_parser_expr = arg_parser_outputs[a.name].expr

        if has_toptions and name == "self":
            # TODO: why this needs to be special case?
            inits.extend(
                [
                    f"auto self = {arg_parser_expr};",
                ]
            )
            lambda_args_exprs[name] = name
        elif (
            isinstance(a, PythonOutArgument)
            and len(a.outputs) > 1
            and f.func.is_out_fn()
        ):
            # Several outputs are parsed as one 'out' value; bind it to a local
            # and index into it for each individual output argument.
            inits.extend(
                [
                    f"auto out = {arg_parser_expr};",
                ]
            )
            for i, out_arg in enumerate(a.outputs):
                lambda_args_exprs[out_arg.name] = f"out[{i}]"
        elif str(a.type) == "Dimname[]?":
            # [old codegen]
            # TODO: make this part of something more general, or get rid of it.
            # optional<ArrayRef<T>> are special. The PythonArgParser returns an
            # optional<vector<T>>, which cannot be implicitly converted to
            # optional<ArrayRef<T>>. One needs to unwrap the optional and rewrap.
            inits.extend(
                [
                    f"auto __{name} = {arg_parser_expr};",
                    f"::std::optional<DimnameList> {name} = __{name} ? ::std::make_optional(DimnameList(__{name}.value())) : ::std::nullopt;",  # noqa: B950
                ]
            )
            lambda_args_exprs[name] = name
        else:
            # default case - directly using PythonArgParser output expr
            lambda_args_exprs[name] = arg_parser_expr

    # method's self is passed directly to python binding, rather than parsed
    if ps.method:
        lambda_args_exprs["self"] = "self"

    # 2. special packing/checking for TensorOptions.
    tensor_options_args_names = [a.name for a in ps.tensor_options_args]
    if has_toptions:
        if f.func.is_out_fn():
            raise RuntimeError(f"{f.func}: tensor options with output arg")
        # Every tensor options arg must be a known field with the expected type...
        for a in ps.tensor_options_args:
            if a.name not in TENSOR_OPTIONS_FIELDS:
                raise RuntimeError(
                    f"{f.func}: unrecognized tensor options field '{a.name}' in python binding arguments"
                )
            if str(a.type) != TENSOR_OPTIONS_FIELDS.get(a.name):
                raise RuntimeError(
                    f"{f.func}: unrecognized type '{str(a.type)}' for tensor options field '{a.name}'"
                )
        # ...and every known field must be present.
        if not all(a in tensor_options_args_names for a in TENSOR_OPTIONS_FIELDS):
            raise RuntimeError(
                f"{f.func}: incomplete tensor options args: {tensor_options_args_names}"
            )

        # Pack the scattered fields into a single 'options' local that is
        # passed to the lambda.
        inits.append(
            f"""\
const auto options = TensorOptions()
    .dtype({arg_parser_outputs['dtype'].expr})
    .device({arg_parser_outputs['device'].expr})
    .layout({arg_parser_outputs['layout'].expr})
    .requires_grad({arg_parser_outputs['requires_grad'].expr})
    .pinned_memory({arg_parser_outputs['pin_memory'].expr});
torch::utils::maybe_initialize_device(options);
"""
        )
        lambda_args_exprs["options"] = "options"

    # 3. special case - access scattered TensorOptions fields without packing
    # TODO: maybe move to the generator side as it's not related to binding.
    if not has_toptions and tensor_options_args_names:
        if "dtype" in tensor_options_args_names:
            # we're an output-arg variant, check these args against output tensor
            if not f.func.is_out_fn():
                # NOTE(review): 'ps.arguments' is interpolated without being
                # called (it is called as 'ps.arguments(...)' above), so the
                # message likely shows the bound method - confirm intent.
                raise RuntimeError(
                    f"{f.func}: dtype in tensor_options_args without output arg, {ps} {ps.arguments}"
                )
            if not all(a in tensor_options_args_names for a in ("layout", "device")):
                raise RuntimeError(
                    f"{f.func}: incomplete tensor options for output check"
                )

            inits.append(
                f"""\
check_out_type_matches({arg_parser_outputs['out'].expr}, {arg_parser_outputs['dtype'].expr},
                       {arg_parser_outputs['dtype'].is_none_expr}, {arg_parser_outputs['layout'].expr},
                       {arg_parser_outputs['device'].expr}, {arg_parser_outputs['device'].is_none_expr});
"""
            )
        # we'll set requires_grad on outgoing tensor
        if "requires_grad" not in tensor_options_args_names:
            raise RuntimeError(
                f'{f.func}: expected "requires_grad" in tensor_options_args absent, but found [{tensor_options_args_names}]'
            )

    # Order the collected expressions by the lambda's own argument order.
    return DispatchLambdaArgumentExprs(
        exprs=tuple(lambda_args_exprs[a.name] for a in lambda_args),
        inits=inits,
    )
