# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# @nolint

# Not linting this file because of the `from faiss.loader import *` below
# (which brings in the swigfaiss bindings); it causes a ton of useless warnings.

import numpy as np
import array
import warnings

from faiss.loader import *

###########################################
# Utility to add a deprecation warning to
# classes from the SWIG interface
###########################################

def _make_deprecated_swig_class(deprecated_name, base_name):
    """
    Dynamically construct a deprecated class as a thin subclass of the class
    that replaces it.

    The generated subclass overrides ``__new__`` so that a
    ``DeprecationWarning`` is issued every time an instance is constructed;
    whether (and how often) that warning is actually displayed is decided by
    the active ``warnings`` filters.

    We do this here, in Python, because the base classes are defined in the
    SWIG interface, making it cumbersome to add the deprecation there.

    Parameters
    ----------
    deprecated_name : string
        Name of the class to be deprecated; _not_ present in SWIG interface.
    base_name : string
        Name of the class that is replacing deprecated_name; must already be
        present in this module's global namespace.

    Returns
    -------
    None
        The new class is stored in this module's globals under
        ``deprecated_name``, so it is re-exported by ``import *``.
    """
    base_class = globals()[base_name]

    def new_meth(cls, *args, **kwargs):
        msg = f"The class faiss.{deprecated_name} is deprecated in favour of faiss.{base_name}!"
        # stacklevel=2 attributes the warning to the code doing the construction
        warnings.warn(msg, DeprecationWarning, stacklevel=2)
        # Start the MRO lookup right *after* the deprecated class itself, so a
        # __new__ defined on base_class (or one of its ancestors) is honoured.
        parent_new = super(klazz, cls).__new__
        if parent_new is object.__new__:
            # object.__new__ raises TypeError on extra arguments once __new__
            # is overridden; the constructor arguments are meant for __init__.
            return parent_new(cls)
        return parent_new(cls, *args, **kwargs)

    # three-argument version of "type" uses (name, tuple-of-bases, dict-of-attributes)
    klazz = type(deprecated_name, (base_class,), {"__new__": new_meth})

    # this ends up adding the class to the "faiss" namespace, in a way that it
    # is available both through "import faiss" and "from faiss import *"
    globals()[deprecated_name] = klazz


###########################################
# numpy array / std::vector conversions
###########################################

# Size in bytes of a C ``long`` on this platform (typically 4 on Windows and
# 8 on 64-bit Linux/macOS); it decides what the legacy "Long" name maps to.
sizeof_long = array.array('l').itemsize

# Legacy class-name prefix -> prefix of the class that now replaces it.
deprecated_name_map = {
    # deprecated: replacement
    'Float': 'Float32',
    'Double': 'Float64',
    'Char': 'Int8',
    'Int': 'Int32',
    'Long': 'Int64' if sizeof_long != 4 else 'Int32',
    'LongLong': 'Int64',
    'Byte': 'UInt8',
    # previously misspelled variant
    'Uint64': 'UInt64',
}

# Register a deprecated alias for every <prefix>Vector class; only three of the
# legacy prefixes also had a <prefix>VectorVector class that needs an alias.
for depr_prefix, base_prefix in deprecated_name_map.items():
    _make_deprecated_swig_class(f"{depr_prefix}Vector", f"{base_prefix}Vector")
    if depr_prefix in ('Float', 'Long', 'Byte'):
        _make_deprecated_swig_class(
            f"{depr_prefix}VectorVector",
            f"{base_prefix}VectorVector",
        )

# Vector class-name prefix (as used in swigfaiss.swig) -> numpy dtype name.
# For the canonical prefixes the dtype name is just the lower-cased prefix;
# each deprecated prefix resolves to the dtype of its replacement class.
# TODO: once deprecated classes are removed, remove the dict and just use .lower() below
vector_name_map = {
    prefix: prefix.lower()
    for prefix in (
        'Float32', 'Float64',
        'Int8', 'Int16', 'Int32', 'Int64',
        'UInt8', 'UInt16', 'UInt32', 'UInt64',
    )
}
vector_name_map.update(
    (old_prefix, new_prefix.lower())
    for old_prefix, new_prefix in deprecated_name_map.items()
)


def vector_to_array(v):
    """ Return a new numpy array holding a copy of the C++ vector ``v``.

    ``AlignedTable*`` objects are delegated to ``AlignedTable_to_array``.
    For ``*Vector`` classes the element dtype is looked up in
    ``vector_name_map`` using the class name without its ``Vector`` suffix.
    """
    name = v.__class__.__name__
    if name.startswith('AlignedTable'):
        return AlignedTable_to_array(v)

    suffix = 'Vector'
    assert name.endswith(suffix)
    elem_dtype = np.dtype(vector_name_map[name[:-len(suffix)]])
    out = np.empty(v.size(), dtype=elem_dtype)
    if v.size() == 0:
        # nothing to copy
        return out
    memcpy(swig_ptr(out), v.data(), out.nbytes)
    return out


def vector_float_to_array(v):
    """ Convert a C++ vector to a numpy array.

    Thin wrapper (presumably kept for backward compatibility) that simply
    forwards to ``vector_to_array``, which dispatches on the vector's class.
    """
    converted = vector_to_array(v)
    return converted


def copy_array_to_vector(a, v):
    """ copy a numpy array to a vector

    ``v`` is resized to ``len(a)`` and its previous contents are overwritten.

    Parameters
    ----------
    a : numpy.ndarray
        One-dimensional source array. Its dtype must equal the dtype that
        ``vector_name_map`` associates with ``v``'s class (e.g. a float32
        array for a Float32Vector). Non-contiguous arrays (e.g. strided
        slices) are first made contiguous, since the copy is a raw memcpy.
    v : *Vector object
        Destination vector; must provide ``resize`` and ``data``.

    Raises
    ------
    ValueError
        If ``a`` is not one-dimensional.
    KeyError
        If the class-name prefix of ``v`` is not in ``vector_name_map``.
    AssertionError
        If ``v`` is not a ``*Vector`` class or the dtypes differ.
    """
    n, = a.shape
    classname = v.__class__.__name__
    # The checks below are explicit raises instead of `assert` statements so
    # they survive `python -O`: a dtype mismatch would otherwise make the
    # memcpy copy a.nbytes into a vector presumably sized for n elements of a
    # different type. AssertionError is kept for backward compatibility.
    if not classname.endswith('Vector'):
        raise AssertionError(
            'cannot copy an array to a %s (not a *Vector class)' % classname)
    dtype = np.dtype(vector_name_map[classname[:-6]])
    if dtype != a.dtype:
        raise AssertionError(
            'cannot copy a %s array to a %s (should be %s)' % (
                a.dtype, classname, dtype))
    # memcpy reads a.nbytes bytes starting at the first element, so the data
    # must be laid out contiguously (no-op if it already is).
    a = np.ascontiguousarray(a)
    v.resize(n)
    if n > 0:
        memcpy(v.data(), swig_ptr(a), a.nbytes)

# Same conversions for AlignedTable objects


def copy_array_to_AlignedTable(a, v):
    """ Copy a 1-D numpy array into an AlignedTable, resizing it to len(a).

    Only the element size is checked against ``v.itemsize()``; the dtype
    itself is not compared.

    Raises
    ------
    ValueError
        If ``a`` is not one-dimensional.
    AssertionError
        If the array's element size differs from the table's.
    """
    n, = a.shape
    # TODO check class name
    if v.itemsize() != a.itemsize:
        # Explicit raise (not `assert`) so the check survives `python -O`:
        # with a mismatch, the a.nbytes copied below would not match the
        # table's allocation after resize(n). AssertionError is kept for
        # backward compatibility.
        raise AssertionError(
            'itemsize mismatch: array elements are %d bytes, table elements are %d' % (
                a.itemsize, v.itemsize()))
    # memcpy reads a.nbytes contiguous bytes from the first element
    a = np.ascontiguousarray(a)
    v.resize(n)
    if n > 0:
        memcpy(v.get(), swig_ptr(a), a.nbytes)


def array_to_AlignedTable(a):
    """ Build a new AlignedTable holding a copy of the numpy array ``a``.

    Supported dtypes: uint16 -> AlignedTableUint16, uint8 -> AlignedTableUint8.

    Raises
    ------
    AssertionError
        For any other dtype.
    """
    if a.dtype == 'uint16':
        v = AlignedTableUint16(a.size)
    elif a.dtype == 'uint8':
        v = AlignedTableUint8(a.size)
    else:
        # Explicit raise instead of `assert False`: under `python -O` the
        # assert is stripped and `v` would be unbound below.
        # AssertionError is kept for backward compatibility.
        raise AssertionError(
            'cannot convert a %s array to an AlignedTable '
            '(supported dtypes: uint8, uint16)' % a.dtype)
    copy_array_to_AlignedTable(a, v)
    return v


def AlignedTable_to_array(v):
    """ Return a new numpy array holding a copy of the AlignedTable ``v``.

    The numpy dtype comes from the class-name suffix, lower-cased
    (e.g. ``AlignedTableUint16`` -> ``uint16``).
    """
    prefix = 'AlignedTable'
    name = v.__class__.__name__
    assert name.startswith(prefix)
    out = np.empty(v.size(), dtype=name[len(prefix):].lower())
    if out.size == 0:
        # nothing to copy
        return out
    memcpy(swig_ptr(out), v.data(), out.nbytes)
    return out
