"""LU decomposition functions."""

from warnings import warn

from numpy import asarray, asarray_chkfinite
import numpy as np
from itertools import product

# Local imports
from ._misc import _datacopied, LinAlgWarning
from .lapack import get_lapack_funcs
from ._decomp_lu_cython import lu_dispatcher

lapack_cast_dict = {x: ''.join([y for y in 'fdFD' if np.can_cast(x, y)])
                    for x in np.typecodes['All']}

__all__ = ['lu', 'lu_solve', 'lu_factor']


def lu_factor(a, overwrite_a=False, check_finite=True):
    """
    Compute pivoted LU decomposition of a matrix.

    The decomposition is::

        A = P L U

    where P is a permutation matrix, L lower triangular with unit
    diagonal elements, and U upper triangular.

    Parameters
    ----------
    a : (M, N) array_like
        Matrix to decompose
    overwrite_a : bool, optional
        Whether to overwrite data in A (may increase performance)
    check_finite : bool, optional
        Whether to check that the input matrix contains only finite numbers.
        Disabling may give a performance gain, but may result in problems
        (crashes, non-termination) if the inputs do contain infinities or NaNs.

    Returns
    -------
    lu : (M, N) ndarray
        Matrix containing U in its upper triangle, and L in its lower triangle.
        The unit diagonal elements of L are not stored.
    piv : (K,) ndarray
        Pivot indices representing the permutation matrix P:
        row i of matrix was interchanged with row piv[i].
        Of shape ``(K,)``, with ``K = min(M, N)``.

    See Also
    --------
    lu : gives lu factorization in more user-friendly format
    lu_solve : solve an equation system using the LU factorization of a matrix

    Notes
    -----
    This is a wrapper to the ``*GETRF`` routines from LAPACK. Unlike
    :func:`lu`, it outputs the L and U factors into a single array
    and returns pivot indices instead of a permutation matrix.

    While the underlying ``*GETRF`` routines return 1-based pivot indices, the
    ``piv`` array returned by ``lu_factor`` contains 0-based indices.

    Examples
    --------
    >>> import numpy as np
    >>> from scipy.linalg import lu_factor
    >>> A = np.array([[2, 5, 8, 7], [5, 2, 2, 8], [7, 5, 6, 6], [5, 4, 4, 8]])
    >>> lu, piv = lu_factor(A)
    >>> piv
    array([2, 2, 3, 3], dtype=int32)

    Convert LAPACK's ``piv`` array to NumPy index and test the permutation

    >>> def pivot_to_permutation(piv):
    ...     perm = np.arange(len(piv))
    ...     for i in range(len(piv)):
    ...         perm[i], perm[piv[i]] = perm[piv[i]], perm[i]
    ...     return perm
    ...
    >>> p_inv = pivot_to_permutation(piv)
    >>> p_inv
    array([2, 0, 3, 1])
    >>> L, U = np.tril(lu, k=-1) + np.eye(4), np.triu(lu)
    >>> np.allclose(A[p_inv] - L @ U, np.zeros((4, 4)))
    True

    The P matrix in P L U is defined by the inverse permutation and
    can be recovered using argsort:

    >>> p = np.argsort(p_inv)
    >>> p
    array([1, 3, 0, 2])
    >>> np.allclose(A - L[p] @ U, np.zeros((4, 4)))
    True

    or alternatively:

    >>> P = np.eye(4)[p]
    >>> np.allclose(A - P @ L @ U, np.zeros((4, 4)))
    True
    """
    if check_finite:
        a1 = asarray_chkfinite(a)
    else:
        a1 = asarray(a)

    # accommodate empty arrays
    if a1.size == 0:
        lu = np.empty_like(a1)
        piv = np.arange(0, dtype=np.int32)
        return lu, piv

    overwrite_a = overwrite_a or (_datacopied(a1, a))

    getrf, = get_lapack_funcs(('getrf',), (a1,))
    lu, piv, info = getrf(a1, overwrite_a=overwrite_a)
    if info < 0:
        raise ValueError('illegal value in %dth argument of '
                         'internal getrf (lu_factor)' % -info)
    if info > 0:
        warn("Diagonal number %d is exactly zero. Singular matrix." % info,
             LinAlgWarning, stacklevel=2)
    return lu, piv


def lu_solve(lu_and_piv, b, trans=0, overwrite_b=False, check_finite=True):
    """Solve an equation system, a x = b, given the LU factorization of a

    Parameters
    ----------
    (lu, piv)
        Factorization of the coefficient matrix a, as given by lu_factor.
        In particular piv are 0-indexed pivot indices.
    b : array
        Right-hand side
    trans : {0, 1, 2}, optional
        Type of system to solve:

        =====  =========
        trans  system
        =====  =========
        0      a x   = b
        1      a^T x = b
        2      a^H x = b
        =====  =========
    overwrite_b : bool, optional
        Whether to overwrite data in b (may increase performance)
    check_finite : bool, optional
        Whether to check that the input matrices contain only finite numbers.
        Disabling may give a performance gain, but may result in problems
        (crashes, non-termination) if the inputs do contain infinities or NaNs.

    Returns
    -------
    x : array
        Solution to the system

    See Also
    --------
    lu_factor : LU factorize a matrix

    Examples
    --------
    >>> import numpy as np
    >>> from scipy.linalg import lu_factor, lu_solve
    >>> A = np.array([[2, 5, 8, 7], [5, 2, 2, 8], [7, 5, 6, 6], [5, 4, 4, 8]])
    >>> b = np.array([1, 1, 1, 1])
    >>> lu, piv = lu_factor(A)
    >>> x = lu_solve((lu, piv), b)
    >>> np.allclose(A @ x - b, np.zeros((4,)))
    True

    """
    (lu, piv) = lu_and_piv
    if check_finite:
        b1 = asarray_chkfinite(b)
    else:
        b1 = asarray(b)

    overwrite_b = overwrite_b or _datacopied(b1, b)

    if lu.shape[0] != b1.shape[0]:
        raise ValueError(f"Shapes of lu {lu.shape} and b {b1.shape} are incompatible")

    # accommodate empty arrays
    if b1.size == 0:
        m = lu_solve((np.eye(2, dtype=lu.dtype), [0, 1]), np.ones(2, dtype=b.dtype))
        return np.empty_like(b1, dtype=m.dtype)

    getrs, = get_lapack_funcs(('getrs',), (lu, b1))
    x, info = getrs(lu, piv, b1, trans=trans, overwrite_b=overwrite_b)
    if info == 0:
        return x
    raise ValueError('illegal value in %dth argument of internal gesv|posv'
                     % -info)


def lu(a, permute_l=False, overwrite_a=False, check_finite=True,
       p_indices=False):
    """
    Compute LU decomposition of a matrix with partial pivoting.

    The decomposition satisfies::

        A = P @ L @ U

    where ``P`` is a permutation matrix, ``L`` lower triangular with unit
    diagonal elements, and ``U`` upper triangular. If `permute_l` is set to
    ``True`` then ``L`` is returned already permuted and hence satisfying
    ``A = L @ U``.

    Parameters
    ----------
    a : (M, N) array_like
        Array to decompose
    permute_l : bool, optional
        Perform the multiplication P*L (Default: do not permute)
    overwrite_a : bool, optional
        Whether to overwrite data in a (may improve performance)
    check_finite : bool, optional
        Whether to check that the input matrix contains only finite numbers.
        Disabling may give a performance gain, but may result in problems
        (crashes, non-termination) if the inputs do contain infinities or NaNs.
    p_indices : bool, optional
        If ``True`` the permutation information is returned as row indices.
        The default is ``False`` for backwards-compatibility reasons.

    Returns
    -------
    **(If `permute_l` is ``False``)**

    p : (..., M, M) ndarray
        Permutation arrays or vectors depending on `p_indices`
    l : (..., M, K) ndarray
        Lower triangular or trapezoidal array with unit diagonal.
        ``K = min(M, N)``
    u : (..., K, N) ndarray
        Upper triangular or trapezoidal array

    **(If `permute_l` is ``True``)**

    pl : (..., M, K) ndarray
        Permuted L matrix.
        ``K = min(M, N)``
    u : (..., K, N) ndarray
        Upper triangular or trapezoidal array

    Notes
    -----
    Permutation matrices are costly since they are nothing but row reorder of
    ``L`` and hence indices are strongly recommended to be used instead if the
    permutation is required. The relation in the 2D case then becomes simply
    ``A = L[P, :] @ U``. In higher dimensions, it is better to use `permute_l`
    to avoid complicated indexing tricks.

    In 2D case, if one has the indices however, for some reason, the
    permutation matrix is still needed then it can be constructed by
    ``np.eye(M)[P, :]``.

    Examples
    --------

    >>> import numpy as np
    >>> from scipy.linalg import lu
    >>> A = np.array([[2, 5, 8, 7], [5, 2, 2, 8], [7, 5, 6, 6], [5, 4, 4, 8]])
    >>> p, l, u = lu(A)
    >>> np.allclose(A, p @ l @ u)
    True
    >>> p  # Permutation matrix
    array([[0., 1., 0., 0.],  # Row index 1
           [0., 0., 0., 1.],  # Row index 3
           [1., 0., 0., 0.],  # Row index 0
           [0., 0., 1., 0.]]) # Row index 2
    >>> p, _, _ = lu(A, p_indices=True)
    >>> p
    array([1, 3, 0, 2], dtype=int32)  # as given by row indices above
    >>> np.allclose(A, l[p, :] @ u)
    True

    We can also use nd-arrays, for example, a demonstration with 4D array:

    >>> rng = np.random.default_rng()
    >>> A = rng.uniform(low=-4, high=4, size=[3, 2, 4, 8])
    >>> p, l, u = lu(A)
    >>> p.shape, l.shape, u.shape
    ((3, 2, 4, 4), (3, 2, 4, 4), (3, 2, 4, 8))
    >>> np.allclose(A, p @ l @ u)
    True
    >>> PL, U = lu(A, permute_l=True)
    >>> np.allclose(A, PL @ U)
    True

    """
    a1 = np.asarray_chkfinite(a) if check_finite else np.asarray(a)
    if a1.ndim < 2:
        raise ValueError('The input array must be at least two-dimensional.')

    # Also check if dtype is LAPACK compatible
    if a1.dtype.char not in 'fdFD':
        dtype_char = lapack_cast_dict[a1.dtype.char]
        if not dtype_char:  # No casting possible
            raise TypeError(f'The dtype {a1.dtype} cannot be cast '
                            'to float(32, 64) or complex(64, 128).')

        a1 = a1.astype(dtype_char[0])  # makes a copy, free to scratch
        overwrite_a = True

    *nd, m, n = a1.shape
    k = min(m, n)
    real_dchar = 'f' if a1.dtype.char in 'fF' else 'd'

    # Empty input
    if min(*a1.shape) == 0:
        if permute_l:
            PL = np.empty(shape=[*nd, m, k], dtype=a1.dtype)
            U = np.empty(shape=[*nd, k, n], dtype=a1.dtype)
            return PL, U
        else:
            P = (np.empty([*nd, 0], dtype=np.int32) if p_indices else
                 np.empty([*nd, 0, 0], dtype=real_dchar))
            L = np.empty(shape=[*nd, m, k], dtype=a1.dtype)
            U = np.empty(shape=[*nd, k, n], dtype=a1.dtype)
            return P, L, U

    # Scalar case
    if a1.shape[-2:] == (1, 1):
        if permute_l:
            return np.ones_like(a1), (a1 if overwrite_a else a1.copy())
        else:
            P = (np.zeros(shape=[*nd, m], dtype=int) if p_indices
                 else np.ones_like(a1))
            return P, np.ones_like(a1), (a1 if overwrite_a else a1.copy())

    # Then check overwrite permission
    if not _datacopied(a1, a):  # "a"  still alive through "a1"
        if not overwrite_a:
            # Data belongs to "a" so make a copy
            a1 = a1.copy(order='C')
        #  else: Do nothing we'll use "a" if possible
    # else:  a1 has its own data thus free to scratch

    # Then layout checks, might happen that overwrite is allowed but original
    # array was read-only or non-contiguous.

    if not (a1.flags['C_CONTIGUOUS'] and a1.flags['WRITEABLE']):
        a1 = a1.copy(order='C')

    if not nd:  # 2D array

        p = np.empty(m, dtype=np.int32)
        u = np.zeros([k, k], dtype=a1.dtype)
        lu_dispatcher(a1, u, p, permute_l)
        P, L, U = (p, a1, u) if m > n else (p, u, a1)

    else:  # Stacked array

        # Prepare the contiguous data holders
        P = np.empty([*nd, m], dtype=np.int32)  # perm vecs

        if m > n:  # Tall arrays, U will be created
            U = np.zeros([*nd, k, k], dtype=a1.dtype)
            for ind in product(*[range(x) for x in a1.shape[:-2]]):
                lu_dispatcher(a1[ind], U[ind], P[ind], permute_l)
            L = a1

        else:  # Fat arrays, L will be created
            L = np.zeros([*nd, k, k], dtype=a1.dtype)
            for ind in product(*[range(x) for x in a1.shape[:-2]]):
                lu_dispatcher(a1[ind], L[ind], P[ind], permute_l)
            U = a1

    # Convert permutation vecs to permutation arrays
    # permute_l=False needed to enter here to avoid wasted efforts
    if (not p_indices) and (not permute_l):
        if nd:
            Pa = np.zeros([*nd, m, m], dtype=real_dchar)
            # An unreadable index hack - One-hot encoding for perm matrices
            nd_ix = np.ix_(*([np.arange(x) for x in nd]+[np.arange(m)]))
            Pa[(*nd_ix, P)] = 1
            P = Pa
        else:  # 2D case
            Pa = np.zeros([m, m], dtype=real_dchar)
            Pa[np.arange(m), P] = 1
            P = Pa

    return (L, U) if permute_l else (P, L, U)
