# mypy: allow-untyped-defs
from types import TracebackType
from typing import List, Optional
import tempfile
import traceback
import contextlib
import inspect
import os.path

# This file contains utilities for ensuring dynamically compile()'d
# code fragments display their line numbers in backtraces.
#
# The constraints:
#
# - We don't have control over the user exception printer (in particular,
#   we cannot assume the linecache trick will work, c.f.
#   https://stackoverflow.com/q/50515651/23845 )
#
# - We don't want to create temporary files every time we compile()
#   some code; file creation should happen lazily, only at exception
#   time.  Arguably, you *should* be willing to write out your
#   generated Python code to the file system, but in some situations
#   (esp. library code) it would violate user expectations to write
#   to the file system, so we try to avoid it.  When we do write the
#   files, we'd like to keep them around so users can open the files
#   mentioned in the trace; but when no error occurs the user would never
#   see such a file, so we don't write it and avoid cluttering up the
#   filesystem.
#
#   If this is not a constraint for you, there is a substantially simpler
#   way to implement the functionality in this file: instead of using
#   eval/exec directly, just always write a Python file to the filesystem
#   and compile that.
#
# - You have control over a context where the compiled code will get
#   executed, so that we can interpose while the stack is unwinding
#   (otherwise, we have no way to interpose on the exception printing
#   process.)
#
# There are three things you have to do to make use of the utilities here:
#
# - When you compile your source code, use the filename "<string>"; only
#   frames whose co_filename is "<string>" are rewritten.
#
# - Save the string source in the globals dict the compiled code runs with
#   (its f_globals), under the magic name "__compile_source__".
#
# - Before running the compiled code, enter the
#   report_compile_source_on_error() context manager.

@contextlib.contextmanager
def report_compile_source_on_error():
    """
    Make exec/eval'd code show real source lines in tracebacks raised from
    the ``with`` body.

    Any ``Exception`` escaping the body is caught and its traceback rewritten.
    Every frame whose code was compiled with filename ``"<string>"`` and whose
    globals contain ``"__compile_source__"`` is replaced by a stand-in frame.
    The stand-in's code object points at a temporary ``.py`` file holding that
    source. The same exception object is then re-raised with the rewritten
    traceback. All other frames are kept as-is. Exceptions that are not
    ``Exception`` subclasses are not caught and propagate unchanged.

    Side effects:
        * writes one temporary file (never deleted) per rewritten frame;
        * overwrites ``tb_next`` on the original traceback entries it keeps.
    """
    try:
        yield
    except Exception as exc:
        tb = exc.__traceback__

        # Walk the traceback, looking for frames that have
        # source attached
        stack = []
        while tb is not None:
            filename = tb.tb_frame.f_code.co_filename
            source = tb.tb_frame.f_globals.get("__compile_source__")

            if filename == "<string>" and source is not None:
                # What black magic are we doing here?  Intuitively, what
                # we would like to do is overwrite the co_filename on any
                # frames that were generated from exec/eval so that they
                # point to a temporary file that has the actual line
                # information, so Python's default error printer can print
                # useful line information on it.
                #
                # Writing out the temporary file is easy.  But overwriting
                # co_filename is not!  You can't modify the code object
                # associated with a frame.  You can, however, reconstruct
                # a traceback with entirely new frames from scratch, so that's
                # what we do.  But there's another problem, which is how to
                # make the frame?
                #
                # The black magic is we make a frankenstein frame and code
                # object which resembles the original frame/code enough so
                # that it will print properly under traceback and the default
                # error printer, but IT IS NOT THE ORIGINAL FRAME (you
                # couldn't, e.g., execute its code with different variables
                # and expect it to work.)

                # Don't delete the temporary file so the user can inspect it
                # TODO: This creates a temporary file for every frame, but we
                # technically only need one per distinct __compile_source__
                with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix=".py") as f:
                    f.write(source)
                # Create a frame.  Python doesn't let you construct
                # FrameType directly, so just make one with compile
                frame = tb.tb_frame
                # Evaluating this expression returns the frame that executes
                # it. That frame runs `code`, whose co_filename is the temp
                # file just written.
                code = compile('__inspect_currentframe()', f.name, 'eval')
                # Keep the original name so the traceback's "in <name>" matches.
                code = code.replace(co_name=frame.f_code.co_name)
                # Python 3.11 only (guarded: older code objects may lack
                # co_linetable)
                if hasattr(frame.f_code, 'co_linetable'):
                    # We can't copy ALL of the metadata over, because you
                    # can cause Python to segfault this way.  What exactly
                    # do we need?  We need enough information for
                    # traceback to be able to print the exception
                    # correctly.  Code reading Lib/traceback.py reveals
                    # that traceback calls code.co_positions() in order to
                    # get the augmented line/col numbers.  Objects/codeobject.c,
                    # specifically _PyCode_InitAddressRange, reveals that
                    # this iterator is initialized from co_linetable and
                    # co_firstlineno.  So copy these we must!
                    code = code.replace(  # type: ignore[call-arg]
                        co_linetable=frame.f_code.co_linetable,  # type: ignore[attr-defined]
                        co_firstlineno=frame.f_code.co_firstlineno,  # type: ignore[attr-defined]
                    )
                # Run the stub with the original globals dict. The locals are
                # a copy of the original locals plus the currentframe helper,
                # so the fake frame exposes the same variables.
                fake_frame = eval(
                    code,
                    frame.f_globals,
                    {
                        **frame.f_locals,
                        '__inspect_currentframe': inspect.currentframe
                    }
                )
                # Position info comes from the original traceback entry;
                # tb_next is linked up below.
                fake_tb = TracebackType(
                    None, fake_frame, tb.tb_lasti, tb.tb_lineno
                )
                stack.append(fake_tb)
            else:
                stack.append(tb)

            tb = tb.tb_next

        # Reconstruct the linked list.
        # Iterate from the last collected entry backwards, pointing each one
        # at its successor. This also overwrites tb_next on the original
        # entries that were kept unchanged.
        tb_next = None
        for tb in reversed(stack):
            tb.tb_next = tb_next
            tb_next = tb

        # Re-raise the same exception object with the rewritten traceback.
        raise exc.with_traceback(tb_next)  # noqa: B904

def shorten_filename(fn, *, base=None):
    """Shorten a source filepath, with the assumption that torch/ subdirectories don't need to be shown to user.

    The longest common path of ``fn`` and ``base`` is removed from the front
    of ``fn``, together with any path separators that follow it.

    Args:
        fn: the path to shorten.
        base: the directory whose common path with ``fn`` is stripped.
            Defaults to the parent of the directory containing this module.

    Returns:
        ``fn`` without the common leading path. ``fn`` is returned unchanged
        when no common path can be computed, e.g. when mixing absolute and
        relative paths, since ``os.path.commonpath`` raises ``ValueError``.
    """
    if base is None:
        base = os.path.dirname(os.path.dirname(__file__))
    # Truncate torch/foo.py to foo.py
    try:
        prefix = os.path.commonpath([fn, base])
    except ValueError:
        return fn
    # Strip the prefix, then any separators after it. Blindly skipping
    # exactly one extra character (len(prefix) + 1) eats the first letter of
    # the remainder when the common path is "" or already ends in a separator,
    # e.g. "/" or "C:\\".
    separators = os.sep + (os.altsep or "")
    return fn[len(prefix):].lstrip(separators)

def format_frame(frame, *, base=None, line=False):
    """
    Render a FrameSummary as a compact, single-line string.

    The output has the form ``<short path>:<lineno> in <name>``, where the path
    is shortened with shorten_filename() relative to ``base``. If ``line`` is
    truthy, the frame's source text and ``"  # "`` are prepended.
    """
    location = f"{shorten_filename(frame.filename, base=base)}:{frame.lineno} in {frame.name}"
    if not line:
        return location
    return f"{frame.line}  # {location}"

def format_traceback_short(tb):
    """Describe only the innermost frame of the TracebackType ``tb``, on one line."""
    frames = traceback.extract_tb(tb)
    innermost = frames[-1]
    return format_frame(innermost)

class CapturedTraceback:
    """
    A cheaply captured stack trace whose symbolization is deferred.

    Attributes:
        tb: the raw object returned by
            ``torch._C._profiler.gather_traceback``. It is ``None`` after
            :meth:`cleanup` or after unpickling.
        skip: the number of leading entries of the symbolized trace to drop
            when summarizing. ``extract()`` uses it to elide its own frame.
    """
    __slots__ = ['tb', 'skip']

    def __init__(self, tb, skip=0):
        self.tb = tb
        self.skip = skip

    def cleanup(self):
        """Release the captured traceback; later summaries will be empty."""
        self.tb = None

    def summary(self):
        """
        Symbolize this traceback and return it as a ``traceback.StackSummary``.

        Returns an empty StackSummary if the traceback has been released. In
        that case torch is not imported.
        """
        if self.tb is None:
            # TODO: Maybe indicate that the traceback was elided?
            return traceback.StackSummary()

        import torch._C._profiler

        return _extract_symbolized_tb(
            torch._C._profiler.symbolize_tracebacks([self.tb])[0],
            self.skip
        )

    def __getstate__(self):
        # The class uses __slots__ and has no __dict__, so the state is a
        # (dict_state, slot_state) pair. Pickle's default restore assigns the
        # slots from the second item.
        return (None, {
            'tb': None,  # TB is not pickleable
            'skip': self.skip,
        })

    @staticmethod
    def extract(*, script=False, cpp=False, skip=0):
        """
        Like traceback.extract_stack(), but faster (approximately 20x faster); it
        is fast enough that you can unconditionally log stacks this way as part of
        normal execution.  It returns a CapturedTraceback object, which must be
        formatted with CapturedTraceback.format() (or CapturedTraceback.format_all()
        for many at once).

        By default, this only reports Python backtraces (like extract_stack).  You
        can set the script/cpp kwargs to also turn on TorchScript/C++ trace
        reporting.
        """
        import torch._C._profiler

        if script or cpp:
            assert skip == 0, "skip with script/cpp NYI"

        return CapturedTraceback(
            torch._C._profiler.gather_traceback(python=True, script=script, cpp=cpp),
            # Elide extract() frame if we don't have script/cpp frames.  If
            # we do have those frames, it doesn't work so force zero.
            0 if script or cpp else skip + 1
        )

    def format(self):
        """
        Formats a single CapturedTraceback into a list of strings equivalent to
        the output of traceback.format_list.  Note that if you pass it a
        CapturedTraceback with C++ traces, it is better not to use this function
        and instead use the batch formatting API CapturedTraceback.format_all to
        amortize the cost of symbolization.
        """
        return traceback.format_list(self.summary())

    @staticmethod
    def format_all(tbs):
        """
        Bulk version of CapturedTraceback.format.  Returns a list of list of strings.

        All still-live tracebacks in ``tbs`` are symbolized together in a single
        symbolize_tracebacks call, which amortizes the symbolization cost.
        Entries whose traceback has been released produce an empty list.
        """
        # Released tracebacks have nothing to symbolize: fill them in
        # directly and remember which positions still need work.
        rs: List[Optional[List[str]]] = []
        delayed_idxs = []
        for i, tb in enumerate(tbs):
            if tb.tb is None:
                rs.append([])
            else:
                rs.append(None)
                delayed_idxs.append(i)

        if delayed_idxs:
            import torch._C._profiler

            stbs = torch._C._profiler.symbolize_tracebacks([tbs[i].tb for i in delayed_idxs])
            for i, stb in zip(delayed_idxs, stbs):
                # Use the batch result. Calling tbs[i].summary() here would
                # symbolize every traceback a second time, one at a time.
                rs[i] = traceback.format_list(_extract_symbolized_tb(stb, tbs[i].skip))

        return rs


def _extract_symbolized_tb(tb, skip):
    """
    Convert one symbolized traceback (an element of the list returned by
    symbolize_tracebacks) into a ``traceback.StackSummary``.

    Each entry of ``tb`` is indexed by 'filename', 'line' and 'name'. The
    first ``skip`` entries are dropped, and the rest are emitted in reverse
    order.
    """
    kept = tb[skip:]
    summary = traceback.StackSummary()
    summary.extend(
        traceback.FrameSummary(entry['filename'], entry['line'], entry['name'])
        for entry in reversed(kept)
    )
    return summary
