# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-import, redefined-outer-name
"""Runtime Tensor API"""
import ctypes
import warnings
from typing import Optional

import numpy as np

try:
    import ml_dtypes
except ImportError:
    ml_dtypes = None

import tvm_ffi
from tvm_ffi import device, DLDeviceType

import tvm
from tvm.runtime import Device
from . import _ffi_api


def from_dlpack(ext_tensor):
    """
    Convert an external tensor to a Tensor.

    The conversion is delegated to ``tvm_ffi.from_dlpack`` with fixed checks:
    ``require_alignment=64`` and ``require_contiguous=True``.  These checks are
    not configurable through this function.

    Parameters
    ----------
    ext_tensor : object
        The external tensor to convert.

    Returns
    -------
    tensor : object
        The value returned by ``tvm_ffi.from_dlpack`` for ``ext_tensor``.
    """
    # TODO(tvm-team): change to require_alignment=0 and require_contiguous=False
    # once we update the compiler generated code to guard against misaligned access.
    return tvm_ffi.from_dlpack(
        ext_tensor,
        require_alignment=64,
        require_contiguous=True,
    )


@tvm_ffi.register_object("ffi.Tensor")
class Tensor(tvm_ffi.core.Tensor):
    """Lightweight Tensor class of TVM runtime.

    Strictly this is only an Array Container (a buffer object).
    No arithmetic operations are defined.
    All operations are performed by TVM functions.

    The goal is not to re-build yet another array library.
    Instead, this is a minimal data structure to demonstrate
    how we can use TVM in existing projects which might have their own array containers.
    """

    def __setitem__(self, in_slice, value):
        """Overwrite the whole tensor, i.e. ``tensor[:] = value``.

        Parameters
        ----------
        in_slice : slice
            Must be a slice whose ``start`` and ``stop`` are both None.
            The slice ``step`` is not inspected.

        value : Tensor or numpy.ndarray or numpy.generic
            The data to copy into this tensor.

        Raises
        ------
        ValueError
            If ``in_slice`` is not a slice or has ``start``/``stop`` set.

        TypeError
            If ``value`` is not one of the supported types.
        """
        if (
            not isinstance(in_slice, slice)
            or in_slice.start is not None
            or in_slice.stop is not None
        ):
            raise ValueError("Array only support set from numpy array")
        if isinstance(value, Tensor):
            # Assigning a tensor to itself needs no copy.
            if not value.same_as(self):
                value.copyto(self)
        elif isinstance(value, (np.ndarray, np.generic)):
            self.copyfrom(value)
        else:
            raise TypeError(f"type {type(value)} not supported")

    def copyfrom(self, source_array):
        """Perform a synchronous copy from the array.

        Parameters
        ----------
        source_array : array_like
            The data source we should like to copy from.  A Tensor is copied
            with ``copyto``; anything that is not a numpy array is first
            converted with ``np.array(source_array, dtype=self.dtype)``.

        Returns
        -------
        arr : Tensor
            Reference to self.

        Raises
        ------
        TypeError
            If ``source_array`` cannot be converted to a numpy array.

        ValueError
            If the shape of the source does not match the shape of this tensor
            (with the lane count appended as a trailing axis for vector dtypes).
        """
        if isinstance(source_array, Tensor):
            source_array.copyto(self)
            return self

        if not isinstance(source_array, np.ndarray):
            try:
                source_array = np.array(source_array, dtype=self.dtype)
            except Exception as err:
                # Only catch Exception so that KeyboardInterrupt/SystemExit
                # propagate; keep the original failure as the cause.
                raise TypeError(
                    f"array must be an array_like data, type {type(source_array)} is not supported"
                ) from err

        t = tvm_ffi.dtype(self.dtype)
        shape, dtype = self.shape, self.dtype
        if t.lanes > 1:
            # Vector dtypes are matched against a numpy array with the lanes
            # as an extra trailing axis and the scalar element dtype.
            shape = shape + (t.lanes,)
            t = t.with_lanes(1)
            dtype = str(t)

        if source_array.shape != shape:
            raise ValueError(
                f"array shape do not match the shape of Tensor {source_array.shape} vs {shape}"
            )
        numpy_str_map = tvm_ffi.dtype._NUMPY_DTYPE_TO_STR
        np_dtype_str = (
            numpy_str_map[source_array.dtype]
            if source_array.dtype in numpy_str_map
            else str(source_array.dtype)
        )
        if (not source_array.flags["C_CONTIGUOUS"]) or (
            dtype == "bfloat16" or dtype != np_dtype_str
        ):
            if dtype == "bfloat16":
                # NOTE(review): this reinterprets the raw bytes of the source as a
                # flat uint16 array, so it presumably expects the source to already
                # hold 16-bit bfloat16 data -- confirm.
                source_array = np.frombuffer(source_array.tobytes(), "uint16")
            source_array = np.ascontiguousarray(
                source_array, dtype="uint16" if dtype == "bfloat16" else dtype
            )
        if self.dtype.startswith("float4_e2m1fn"):
            # we need to pack the input data when converting to float4_e2m1fn type:
            # two 4-bit values share one byte, the first one in the high nibble.
            data_bits = source_array.view(dtype="uint8").flatten()
            if data_bits.size % 2:
                # Pad with a zero so the values can be grouped in pairs.
                data_bits = np.pad(data_bits, (0, 1), mode="constant", constant_values=0)
            data_bits = data_bits.reshape(-1, 2)
            packed = ((data_bits[:, 0] & 0x0F) << 4) | (data_bits[:, 1] & 0x0F)
            source_array = packed.astype(np.int8)
        assert source_array.flags["C_CONTIGUOUS"]
        data = source_array.ctypes.data_as(ctypes.c_void_p)
        nbytes = source_array.size * source_array.dtype.itemsize
        _ffi_api.TVMTensorCopyFromBytes(self, data, nbytes)
        return self

    def __repr__(self):
        """Return a header line with shape and device followed by the numpy repr."""
        # exception safety handling for chandle=None
        if self.__chandle__() == 0:
            return type(self).__name__ + "(chandle=None)"
        res = f"<tvm.runtime.Tensor shape={self.shape}, {self.device}>\n"
        res += self.numpy().__repr__()
        return res

    def __str__(self):
        """Return the string form of the numpy copy of this tensor."""
        return str(self.numpy())

    def numpy(self):
        """Convert this array to numpy array

        Returns
        -------
        np_arr : numpy.ndarray
            The corresponding numpy array.

        Raises
        ------
        RuntimeError
            If the dtype needs ``ml_dtypes`` and that package is not installed
            or does not provide the dtype.
        """
        t = tvm_ffi.dtype(self.dtype)
        shape, dtype = self.shape, self.dtype
        old_dtype = dtype
        if t.lanes > 1:
            # Vector dtypes become an extra trailing axis of scalar elements.
            shape = shape + (t.lanes,)
            t = t.with_lanes(1)
            dtype = str(t)
        if dtype == "int4":
            # The result holds one int4 value per int8 element (unpacked below).
            dtype = "int8"
        if dtype in [
            "bfloat16",
            "float8_e3m4",
            "float8_e4m3",
            "float8_e4m3b11fnuz",
            "float8_e4m3fn",
            "float8_e4m3fnuz",
            "float8_e5m2",
            "float8_e5m2fnuz",
            "float8_e8m0fnu",
            "float6_e2m3fn",
            "float6_e3m2fn",
            "float4_e2m1fn",
        ]:
            # These dtypes are looked up by name in the optional ml_dtypes package.
            if ml_dtypes is None:
                raise RuntimeError(
                    f"ml_dtypes is not installed, cannot convert {dtype} array to numpy."
                )
            try:
                dtype = getattr(ml_dtypes, dtype)
            except AttributeError as err:
                raise RuntimeError(
                    f"ml_dtypes has no attribute '{dtype}', cannot convert array."
                ) from err
        np_arr = np.empty(shape, dtype=dtype)
        assert np_arr.flags["C_CONTIGUOUS"]
        data = np_arr.ctypes.data_as(ctypes.c_void_p)
        # TODO(kathy): revisit and get a mirrored function of ffi::GetDataSize
        # in Python to replace line below
        # bool is counted as one byte per element; every other dtype as its total
        # number of bits rounded up to whole bytes.
        nbytes = np_arr.size if dtype == "bool" else (np_arr.size * old_dtype.bits + 7) // 8
        _ffi_api.TVMTensorCopyToBytes(self, data, nbytes)

        if old_dtype == "int4" or old_dtype.startswith("float4_e2m1fn"):
            # The copied bytes hold two 4-bit values each: expand them so that
            # every output element carries one value.  Even output positions take
            # the high nibble, odd positions the low nibble.
            length = np_arr.size
            np_arr = np_arr.view("int8")
            np_arr_ret = np.empty((length,), dtype="int8")
            np_arr = np_arr.reshape((length,))
            odd_index = np.bitwise_and(np_arr, 0x0F)
            even_index = np.bitwise_and(np_arr >> 4, 0x0F)
            np_arr_ret[1::2] = odd_index[0 : length // 2]
            np_arr_ret[0::2] = even_index[0 : (length + 1) // 2]
            return np_arr_ret.reshape(shape).view(dtype)

        return np_arr

    def copyto(self, target, mem_scope=None):
        """Copy array to target

        Parameters
        ----------
        target : Tensor or Device
            The target array to be copied, must have same shape as this array.
            When a device is given instead, a new array with the shape and dtype
            of this array is allocated on that device and used as the target.

        mem_scope : Optional[str]
            The memory scope of the array.  Only used when ``target`` is a
            device and a new array is allocated.

        Returns
        -------
        target : Tensor
            The array that received the data.

        Raises
        ------
        ValueError
            If ``target`` is neither a Tensor nor a device.
        """
        if isinstance(target, Tensor):
            return self._copyto(target)
        if isinstance(target, tvm_ffi.core.Device):
            res = empty(self.shape, self.dtype, target, mem_scope)
            return self._copyto(res)
        raise ValueError(f"Unsupported target type {type(target)}")

    def _copyto(self, target_nd):
        """Internal function that implements copy to target ndarray."""
        _ffi_api.TVMTensorCopyFromTo(self, target_nd)
        return target_nd

    def _create_view(self, shape, dtype: Optional[str] = None, relative_byte_offset: int = 0):
        """Create a view into an existing array.

        The view shares the same allocation and datatype as the
        existing array, but can have a different array shape.  This is
        useful for runtimes that support non-flat memory, where both
        the physical shape of an allocation and the logical shape of
        the tensor it represents may need to be independently
        specified.

        Warning: This function should not be used outside of low-level
        manipulations, as it breaks non-aliasing assumptions made by
        TVM.  This function may also be removed/replaced in the
        future.

        Parameters
        ----------
        shape: Union[tvm.runtime.ShapeTuple, Sequence[typing.SupportsInt]]

            The shape of the view.

        dtype: Optional[str]

            The datatype of the view.  If None (default), the view
            will be the same data type as the current array.

        relative_byte_offset: int

            The location of the view, relative to the location of the current
            array.

            Note: While the `DLTensor.byte_offset` field of the returned view
            is usually the same as `relative_byte_offset`, this is not
            guaranteed.  The `DLTensor.byte_offset` field is relative to the
            start of the backing allocation, while the `relative_byte_offset`
            is relative to the start of `self`.

        Returns
        -------
        view
            The result of ``_ffi_api.TVMTensorCreateView`` for the given
            shape, dtype and offset.

        """

        if not isinstance(shape, tvm.runtime.ShapeTuple):
            shape = tvm.runtime.ShapeTuple([int(dim) for dim in shape])

        if dtype is None:
            dtype = self.dtype

        return _ffi_api.TVMTensorCreateView(self, shape, dtype, relative_byte_offset)


def empty(shape, dtype="float32", device=None, mem_scope=None):
    """Allocate an uninitialized array with the given shape on a device.

    Parameters
    ----------
    shape : Union[tvm.runtime.ShapeTuple, Sequence[typing.SupportsInt]]
        Extents of the array.  Anything that is not already a ShapeTuple is
        converted element by element with ``int``.

    dtype : type or str
        Element type of the array; passed through ``tvm_ffi.dtype``.

    device : Device
        Where to allocate the array.  ``cpu()`` is used when this is not given.

    mem_scope : Optional[str]
        The memory scope of the array.

    Returns
    -------
    arr : tvm.runtime.Tensor
        The newly allocated array.
    """
    if not device:
        device = cpu()
    if isinstance(shape, tvm.runtime.ShapeTuple):
        shape_tuple = shape
    else:
        shape_tuple = tvm.runtime.ShapeTuple([int(extent) for extent in shape])
    return _ffi_api.TVMTensorAllocWithScope(
        shape_tuple, tvm_ffi.dtype(dtype), device, mem_scope
    )


def tensor(arr, device=None, mem_scope=None):
    """Allocate a new tensor and fill it with the contents of ``arr``.

    Parameters
    ----------
    arr : numpy.ndarray or Tensor or array_like
        The data to copy from.  Inputs that are neither a numpy array nor a
        Tensor are converted with ``np.array`` first.

    device : Device, optional
        The device to create the array on.  ``cpu()`` is used when not given.

    mem_scope : Optional[str]
        The memory scope of the array

    Returns
    -------
    ret : Tensor
        The created array
    """
    if not device:
        device = cpu()

    source = arr if isinstance(arr, (np.ndarray, Tensor)) else np.array(arr)
    target = empty(source.shape, source.dtype, device, mem_scope)
    return target.copyfrom(source)


def cpu(dev_id=0):
    """Build a device object of type ``DLDeviceType.kDLCPU``.

    Parameters
    ----------
    dev_id : int, optional
        The integer device id; defaults to 0.

    Returns
    -------
    dev : Device
        The created device
    """
    dev_type = DLDeviceType.kDLCPU
    return device(dev_type, dev_id)


def cuda(dev_id=0):
    """Build a device object of type ``DLDeviceType.kDLCUDA``.

    Parameters
    ----------
    dev_id : int, optional
        The integer device id; defaults to 0.

    Returns
    -------
    dev : Device
        The created device
    """
    dev_type = DLDeviceType.kDLCUDA
    return device(dev_type, dev_id)


def rocm(dev_id=0):
    """Build a device object of type ``DLDeviceType.kDLROCM``.

    Parameters
    ----------
    dev_id : int, optional
        The integer device id; defaults to 0.

    Returns
    -------
    dev : Device
        The created device
    """
    dev_type = DLDeviceType.kDLROCM
    return device(dev_type, dev_id)


def opencl(dev_id=0):
    """Build a device object of type ``DLDeviceType.kDLOpenCL``.

    Parameters
    ----------
    dev_id : int, optional
        The integer device id; defaults to 0.

    Returns
    -------
    dev : Device
        The created device
    """
    dev_type = DLDeviceType.kDLOpenCL
    return device(dev_type, dev_id)


def metal(dev_id=0):
    """Build a device object of type ``DLDeviceType.kDLMetal``.

    Parameters
    ----------
    dev_id : int, optional
        The integer device id; defaults to 0.

    Returns
    -------
    dev : Device
        The created device
    """
    dev_type = DLDeviceType.kDLMetal
    return device(dev_type, dev_id)


def vpi(dev_id=0):
    """Build a device object of type ``DLDeviceType.kDLVPI`` (VPI simulated device).

    Parameters
    ----------
    dev_id : int, optional
        The integer device id; defaults to 0.

    Returns
    -------
    dev : Device
        The created device
    """
    dev_type = DLDeviceType.kDLVPI
    return device(dev_type, dev_id)


def vulkan(dev_id=0):
    """Build a device object of type ``DLDeviceType.kDLVulkan``.

    Parameters
    ----------
    dev_id : int, optional
        The integer device id; defaults to 0.

    Returns
    -------
    dev : Device
        The created device
    """
    dev_type = DLDeviceType.kDLVulkan
    return device(dev_type, dev_id)


def ext_dev(dev_id=0):
    """Build a device object of type ``DLDeviceType.kDLExtDev`` (extension device).

    Parameters
    ----------
    dev_id : int, optional
        The integer device id; defaults to 0.

    Returns
    -------
    dev : Device
        The created device

    Note
    ----
    This API is reserved for quick testing of new
    device by plugin device API as ext_dev.
    """
    dev_type = DLDeviceType.kDLExtDev
    return device(dev_type, dev_id)


def hexagon(dev_id=0):
    """Build a device object of type ``DLDeviceType.kDLHexagon``.

    Parameters
    ----------
    dev_id : int, optional
        The integer device id; defaults to 0.

    Returns
    -------
    dev : Device
        The created device
    """
    dev_type = DLDeviceType.kDLHexagon
    return device(dev_type, dev_id)


def webgpu(dev_id=0):
    """Build a device object of type ``DLDeviceType.kDLWebGPU``.

    Parameters
    ----------
    dev_id : int, optional
        The integer device id; defaults to 0.

    Returns
    -------
    dev : Device
        The created device
    """
    dev_type = DLDeviceType.kDLWebGPU
    return device(dev_type, dev_id)


# Register back to FFI: hand the Tensor class defined in this module to
# tvm_ffi.core._set_class_tensor at import time.
# NOTE(review): presumably this makes the FFI layer construct this subclass for
# tensor values it returns -- confirm against tvm_ffi.
tvm_ffi.core._set_class_tensor(Tensor)
