# mypy: allow-untyped-defs
import logging
import types
import weakref
from dataclasses import dataclass
from typing import Tuple

from torch._guards import CompileId

from . import config


# Logger for this module, keyed by the module's import path.
log = logging.getLogger(name=__name__)
"""
[Note on cache size limit]

Background - TorchDynamo cache is a linked list. Each cache entry is a
(check_fn, out_code, next pointer). These are stored on the f_code's co_extra
scratch space. When a frame is invoked, we walk this linked list and run
check_fn in each cache_entry to decide if the frame needs recompilation. If none
of the check_fns returns True, we recompile and add a new entry. To ensure we
don't end up recompiling infinitely, we put limits on the cache size.

There are two limits
1) cache_size_limit
2) accumulated_cache_size_limit


Earlier we had only one limit - the maximum number of entries in one cache line
(which is now represented by (2) above). So, why do we need two limits? Let's
try to understand that.

In general, we want our cache limit value to be a small number (e.g. 8 or even
lower). This ensures that frames that cause too many recompilations fall back
to eager quickly. However, there is another problem that prevents us from
lowering the value of cache_size_limit. This is due to ID_MATCH'd guards. Today,
we put ID_MATCH guards on nn module instances if there is a graph break. This
means we will have many recompilations for the same code object because the
ID_MATCH guard fails for different instances of the nn module. This is a common
pattern in how models are authored. Therefore, this requires us to keep the
cache_size_limit high.

We resolve this by introducing these two limits. The first limit (1) limits the
number of cache entries that have an ID_MATCH'd guard for an nn module instance.
The second limit (2) becomes a safeguard that caps the total number of
compilations for a code object. One important question is - what is the limit
for a code object that does not have any ID_MATCH guard? For such code objects,
we choose (1) as the cache size limit.

Let's take an example to understand how these limits help. Suppose we have 16
instances of an nn module and we ID_MATCH on the self object. Further, suppose
the inputs to these functions have varying batch sizes, leading to one
recompilation per instance. In total, there will be 32 compilations, and
therefore 32 cache entries on the forward code object. In the older case, when
we had only one limit, our cache size limit had to be >= 32 to capture all these
compilations. Now, suppose there is a separate function in the same program
which is very dynamic and unsuitable for compilation. Such a function will need
to undergo 32 compilations to exhaust the cache and fall back to eager. These 32
recompilations are too many, and we want to fall back sooner for these
compilation-unfriendly functions.

In the new scenario, we can have (1) cache_size_limit = 2, (2)
accumulated_cache_size_limit = 32. This means that each ID_MATCH'd object can
have a maximum of two cache entries, and the maximum number of cache entries
(irrespective of ID_MATCH obj) is 32. This covers the case of the forward code
object, which has 32 compilations. For the other function, the one unsuitable
for compilation, our limit is 2. So, we will exhaust the cache after just 2
compilations. In this manner, these two limits help us resolve the tension
mentioned earlier.
"""


@dataclass
class CacheSizeRelevantForFrame:
    """
    Cache-size counters for one frame, viewed through that frame's ID_MATCH'd
    objects.

    Besides the total length of the Dynamo cache linked list, we record how
    many of its entries were compiled for the very same ID_MATCH'd objects
    that the frame currently holds in its f_locals.

    TODO(janimesh) - Consider adding a map from tuple_of_match_ids to count -
    https://github.com/pytorch/pytorch/pull/107496#discussion_r1304564682 - this
    could be useful for debugging as well.
    """

    # How many CacheEntry objects the Dynamo linked list holds in total.
    num_cache_entries: int = 0

    # How many of those entries share the frame's ID_MATCH'd objects.
    num_cache_entries_with_same_id_matched_objs: int = 0

    def will_compilation_exceed(self, limit: int) -> bool:
        """
        Return True if one more compilation would break either the global
        ``config.accumulated_cache_size_limit`` or the per-ID_MATCH ``limit``.
        """
        if self.will_compilation_exceed_accumulated_limit():
            return True
        return self.will_compilation_exceed_specific_limit(limit)

    def will_compilation_exceed_accumulated_limit(self) -> bool:
        """Return True if the total entry count is already at or above the cap."""
        # count >= cap: the entry a new compilation adds would be one too many.
        cap = config.accumulated_cache_size_limit
        return cap <= self.num_cache_entries

    def will_compilation_exceed_specific_limit(self, limit: int) -> bool:
        """Return True if entries for these ID_MATCH'd objects reached ``limit``."""
        # Same reasoning as above: at the limit already means the next one exceeds it.
        return limit <= self.num_cache_entries_with_same_id_matched_objs


def _get_weakref_from_f_locals(frame: types.FrameType, local_name: str):
    obj = frame.f_locals.get(local_name, None)
    weak_id = None
    try:
        weak_id = weakref.ref(obj)
    except TypeError:
        pass  # cannot weakref bool object
    return weak_id


def _has_same_id_matched_objs(frame: types.FrameType, cache_entry) -> bool:
    """
    Checks if the ID_MATCH'd objects saved on cache_entry are the same objects
    (by identity) as the ones bound to the same local names in frame.f_locals.

    Args:
        frame: the frame being considered for (re)compilation.
        cache_entry: a Dynamo cache entry whose ``check_fn.id_matched_objs``
            maps local names to weak references of the ID_MATCH'd objects.

    Returns:
        False if ``cache_entry`` is falsy, or if any still-alive ID_MATCH'd
        object is not the very same object as the frame's local of that name.
        True otherwise. Saved objects that have been freed are skipped.
    """
    if not cache_entry:
        return False

    for (
        local_name,
        weakref_from_cache_entry,
    ) in cache_entry.check_fn.id_matched_objs.items():
        # Keep a strong reference while comparing so the referent cannot be
        # collected between the liveness check and the identity check.
        obj_from_cache_entry = weakref_from_cache_entry()
        if obj_from_cache_entry is None:
            # The saved object is gone; it places no constraint on this frame.
            continue
        weakref_from_frame = _get_weakref_from_f_locals(frame, local_name)
        # ID_MATCH is an identity guard, so compare referents with `is`.
        # Comparing the weakref objects with `!=` would, while both referents
        # are alive, delegate to the referents' own __ne__/__eq__, which a
        # user class may override (or make return a non-bool).
        if (
            weakref_from_frame is None
            or weakref_from_frame() is not obj_from_cache_entry
        ):
            return False

    # Also covers the case where no ID_MATCH objects are saved in frame.f_locals
    return True


def compute_cache_size(
    frame: types.FrameType, cache_entry
) -> CacheSizeRelevantForFrame:
    """
    Walk Dynamo's cache-entry linked list starting at ``cache_entry`` and
    tally the entries relevant to ``frame``.

    Returns a CacheSizeRelevantForFrame carrying the total number of entries
    and the number of entries whose ID_MATCH'd objects are the same as the
    ones in frame.f_locals.
    """
    total_entries = 0
    same_id_entries = 0

    entry = cache_entry
    while entry:
        total_entries += 1
        # Entries that share the frame's ID_MATCH'd objects are the ones later
        # compared against cache_size_limit.
        if _has_same_id_matched_objs(frame, entry):
            same_id_entries += 1
        entry = entry.next

    return CacheSizeRelevantForFrame(
        num_cache_entries=total_entries,
        num_cache_entries_with_same_id_matched_objs=same_id_entries,
    )


def is_recompilation(cache_size: CacheSizeRelevantForFrame) -> bool:
    """
    Decide whether compiling the frame (earlier parsed by compute_cache_size)
    once more counts as a recompilation.

    It does when at least one existing cache entry already has the same
    ID_MATCH'd objects as the frame (the new compilation would then be the
    second one for those objects). It also does when the total number of
    cache entries has already reached config.accumulated_cache_size_limit,
    because will_compilation_exceed checks that limit as well.
    """
    # Note that you can have multiple entries in the cache but still not a
    # recompile, e.g., you can have 64 nn module instances, each one having an
    # ID_MATCH guard, and each one having just 1 cache entry in the cache. In
    # this case there are 64 entries in the cache, but no recompilation because
    # there is only one entry for each id_matched_obj -- unless 64 entries
    # already reach config.accumulated_cache_size_limit, in which case this
    # still returns True.
    return cache_size.will_compilation_exceed(1)


def exceeds_cache_size_limit(
    cache_size: CacheSizeRelevantForFrame, compile_id: CompileId
) -> Tuple[bool, str]:
    """
    Decide whether compiling the frame once more would break a cache limit.

    Returns an ``(exceeded, reason)`` pair. ``reason`` names the limit that was
    hit ("accumulated_cache_size_limit" or "cache_size_limit"), or is the empty
    string when no limit is exceeded. The limits are checked in a fixed order
    and the first one hit is reported.
    """
    if cache_size.will_compilation_exceed_accumulated_limit():
        reason = "accumulated_cache_size_limit"
    elif cache_size.will_compilation_exceed_specific_limit(config.cache_size_limit):
        reason = "cache_size_limit"
    elif compile_id.frame_compile_id >= config.accumulated_cache_size_limit:
        # NOTE this check is needed when the frame's cache does not grow but we
        # keep recompiling, which can happen if the guard check_fn becomes
        # invalidated (e.g. because guarded objects were freed). It makes the
        # will_compilation_exceed_accumulated_limit check above technically
        # redundant; that check is kept in case a better fix arrives later.
        reason = "accumulated_cache_size_limit"
    else:
        reason = ""
    return bool(reason), reason
