import itertools
import os
import re
from string import Template
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple

from tokenizers import Encoding, Tokenizer


dirname = os.path.dirname(__file__)
css_filename = os.path.join(dirname, "visualizer-styles.css")
with open(css_filename) as f:
    css = f.read()


class Annotation:
    start: int
    end: int
    label: int

    def __init__(self, start: int, end: int, label: str):
        self.start = start
        self.end = end
        self.label = label


AnnotationList = List[Annotation]
PartialIntList = List[Optional[int]]


class CharStateKey(NamedTuple):
    token_ix: Optional[int]
    anno_ix: Optional[int]


class CharState:
    char_ix: Optional[int]

    def __init__(self, char_ix):
        self.char_ix = char_ix

        self.anno_ix: Optional[int] = None
        self.tokens: List[int] = []

    @property
    def token_ix(self):
        return self.tokens[0] if len(self.tokens) > 0 else None

    @property
    def is_multitoken(self):
        """
        BPE tokenizers can output more than one token for a char
        """
        return len(self.tokens) > 1

    def partition_key(self) -> CharStateKey:
        return CharStateKey(
            token_ix=self.token_ix,
            anno_ix=self.anno_ix,
        )


class Aligned:
    pass


class EncodingVisualizer:
    """
    Build an EncodingVisualizer

    Args:

         tokenizer (:class:`~tokenizers.Tokenizer`):
            A tokenizer instance

         default_to_notebook (:obj:`bool`):
            Whether to render html output in a notebook by default

         annotation_converter (:obj:`Callable`, `optional`):
            An optional (lambda) function that takes an annotation in any format and returns
            an Annotation object
    """

    unk_token_regex = re.compile("(.{1}\b)?(unk|oov)(\b.{1})?", flags=re.IGNORECASE)

    def __init__(
        self,
        tokenizer: Tokenizer,
        default_to_notebook: bool = True,
        annotation_converter: Optional[Callable[[Any], Annotation]] = None,
    ):
        if default_to_notebook:
            try:
                from IPython.core.display import HTML, display
            except ImportError:
                raise Exception(
                    """We couldn't import IPython utils for html display.
                        Are you running in a notebook?
                        You can also pass `default_to_notebook=False` to get back raw HTML
                    """
                )

        self.tokenizer = tokenizer
        self.default_to_notebook = default_to_notebook
        self.annotation_coverter = annotation_converter
        pass

    def __call__(
        self,
        text: str,
        annotations: AnnotationList = [],
        default_to_notebook: Optional[bool] = None,
    ) -> Optional[str]:
        """
        Build a visualization of the given text

        Args:
            text (:obj:`str`):
                The text to tokenize

            annotations (:obj:`List[Annotation]`, `optional`):
                An optional list of annotations of the text. The can either be an annotation class
                or anything else if you instantiated the visualizer with a converter function

            default_to_notebook (:obj:`bool`, `optional`, defaults to `False`):
                If True, will render the html in a notebook. Otherwise returns an html string.

        Returns:
            The HTML string if default_to_notebook is False, otherwise (default) returns None and
            renders the HTML in the notebook

        """
        final_default_to_notebook = self.default_to_notebook
        if default_to_notebook is not None:
            final_default_to_notebook = default_to_notebook
        if final_default_to_notebook:
            try:
                from IPython.core.display import HTML, display
            except ImportError:
                raise Exception(
                    """We couldn't import IPython utils for html display.
                    Are you running in a notebook?"""
                )
        if self.annotation_coverter is not None:
            annotations = list(map(self.annotation_coverter, annotations))
        encoding = self.tokenizer.encode(text)
        html = EncodingVisualizer.__make_html(text, encoding, annotations)
        if final_default_to_notebook:
            display(HTML(html))
        else:
            return html

    @staticmethod
    def calculate_label_colors(annotations: AnnotationList) -> Dict[str, str]:
        """
        Generates a color palette for all the labels in a given set of annotations

        Args:
          annotations (:obj:`Annotation`):
            A list of annotations

        Returns:
            :obj:`dict`: A dictionary mapping labels to colors in HSL format
        """
        if len(annotations) == 0:
            return {}
        labels = set(map(lambda x: x.label, annotations))
        num_labels = len(labels)
        h_step = int(255 / num_labels)
        if h_step < 20:
            h_step = 20
        s = 32
        l = 64  # noqa: E741
        h = 10
        colors = {}

        for label in sorted(labels):  # sort so we always get the same colors for a given set of labels
            colors[label] = f"hsl({h},{s}%,{l}%"
            h += h_step
        return colors

    @staticmethod
    def consecutive_chars_to_html(
        consecutive_chars_list: List[CharState],
        text: str,
        encoding: Encoding,
    ):
        """
        Converts a list of "consecutive chars" into a single HTML element.
        Chars are consecutive if they fall under the same word, token and annotation.
        The CharState class is a named tuple with a "partition_key" method that makes it easy to
        compare if two chars are consecutive.

        Args:
            consecutive_chars_list (:obj:`List[CharState]`):
                A list of CharStates that have been grouped together

            text (:obj:`str`):
                The original text being processed

            encoding (:class:`~tokenizers.Encoding`):
                The encoding returned from the tokenizer

        Returns:
            :obj:`str`: The HTML span for a set of consecutive chars
        """
        first = consecutive_chars_list[0]
        if first.char_ix is None:
            # its a special token
            stoken = encoding.tokens[first.token_ix]
            # special tokens are represented as empty spans. We use the data attribute and css
            # magic to display it
            return f'<span class="special-token" data-stoken={stoken}></span>'
        # We're not in a special token so this group has a start and end.
        last = consecutive_chars_list[-1]
        start = first.char_ix
        end = last.char_ix + 1
        span_text = text[start:end]
        css_classes = []  # What css classes will we apply on the resulting span
        data_items = {}  # What data attributes will we apply on the result span
        if first.token_ix is not None:
            # We can either be in a token or not (e.g. in white space)
            css_classes.append("token")
            if first.is_multitoken:
                css_classes.append("multi-token")
            if first.token_ix % 2:
                # We use this to color alternating tokens.
                # A token might be split by an annotation that ends in the middle of it, so this
                # lets us visually indicate a consecutive token despite its possible splitting in
                # the html markup
                css_classes.append("odd-token")
            else:
                # Like above, but a different color so we can see the tokens alternate
                css_classes.append("even-token")
            if EncodingVisualizer.unk_token_regex.search(encoding.tokens[first.token_ix]) is not None:
                # This is a special token that is in the text. probably UNK
                css_classes.append("special-token")
                # TODO is this the right name for the data attribute ?
                data_items["stok"] = encoding.tokens[first.token_ix]
        else:
            # In this case we are looking at a group/single char that is not tokenized.
            # e.g. white space
            css_classes.append("non-token")
        css = f'''class="{' '.join(css_classes)}"'''
        data = ""
        for key, val in data_items.items():
            data += f' data-{key}="{val}"'
        return f"<span {css} {data} >{span_text}</span>"

    @staticmethod
    def __make_html(text: str, encoding: Encoding, annotations: AnnotationList) -> str:
        char_states = EncodingVisualizer.__make_char_states(text, encoding, annotations)
        current_consecutive_chars = [char_states[0]]
        prev_anno_ix = char_states[0].anno_ix
        spans = []
        label_colors_dict = EncodingVisualizer.calculate_label_colors(annotations)
        cur_anno_ix = char_states[0].anno_ix
        if cur_anno_ix is not None:
            # If we started in an  annotation make a span for it
            anno = annotations[cur_anno_ix]
            label = anno.label
            color = label_colors_dict[label]
            spans.append(f'<span class="annotation" style="color:{color}" data-label="{label}">')

        for cs in char_states[1:]:
            cur_anno_ix = cs.anno_ix
            if cur_anno_ix != prev_anno_ix:
                # If we've transitioned in or out of an annotation
                spans.append(
                    # Create a span from the current consecutive characters
                    EncodingVisualizer.consecutive_chars_to_html(
                        current_consecutive_chars,
                        text=text,
                        encoding=encoding,
                    )
                )
                current_consecutive_chars = [cs]

                if prev_anno_ix is not None:
                    # if we transitioned out of an annotation close it's span
                    spans.append("</span>")
                if cur_anno_ix is not None:
                    # If we entered a new annotation make a span for it
                    anno = annotations[cur_anno_ix]
                    label = anno.label
                    color = label_colors_dict[label]
                    spans.append(f'<span class="annotation" style="color:{color}" data-label="{label}">')
            prev_anno_ix = cur_anno_ix

            if cs.partition_key() == current_consecutive_chars[0].partition_key():
                # If the current charchter is in the same "group" as the previous one
                current_consecutive_chars.append(cs)
            else:
                # Otherwise we make a span for the previous group
                spans.append(
                    EncodingVisualizer.consecutive_chars_to_html(
                        current_consecutive_chars,
                        text=text,
                        encoding=encoding,
                    )
                )
                # An reset the consecutive_char_list to form a new group
                current_consecutive_chars = [cs]
        # All that's left is to fill out the final span
        # TODO I think there is an edge case here where an annotation's span might not close
        spans.append(
            EncodingVisualizer.consecutive_chars_to_html(
                current_consecutive_chars,
                text=text,
                encoding=encoding,
            )
        )
        res = HTMLBody(spans)  # Send the list of spans to the body of our html
        return res

    @staticmethod
    def __make_anno_map(text: str, annotations: AnnotationList) -> PartialIntList:
        """
        Args:
            text (:obj:`str`):
                The raw text we want to align to

            annotations (:obj:`AnnotationList`):
                A (possibly empty) list of annotations

        Returns:
            A list of  length len(text) whose entry at index i is None if there is no annotation on
            charachter i or k, the index of the annotation that covers index i where k is with
            respect to the list of annotations
        """
        annotation_map = [None] * len(text)
        for anno_ix, a in enumerate(annotations):
            for i in range(a.start, a.end):
                annotation_map[i] = anno_ix
        return annotation_map

    @staticmethod
    def __make_char_states(text: str, encoding: Encoding, annotations: AnnotationList) -> List[CharState]:
        """
        For each character in the original text, we emit a tuple representing it's "state":

            * which token_ix it corresponds to
            * which word_ix it corresponds to
            * which annotation_ix it corresponds to

        Args:
            text (:obj:`str`):
                The raw text we want to align to

            annotations (:obj:`List[Annotation]`):
                A (possibly empty) list of annotations

            encoding: (:class:`~tokenizers.Encoding`):
                The encoding returned from the tokenizer

        Returns:
            :obj:`List[CharState]`: A list of CharStates, indicating for each char in the text what
            it's state is
        """
        annotation_map = EncodingVisualizer.__make_anno_map(text, annotations)
        # Todo make this a dataclass or named tuple
        char_states: List[CharState] = [CharState(char_ix) for char_ix in range(len(text))]
        for token_ix, token in enumerate(encoding.tokens):
            offsets = encoding.token_to_chars(token_ix)
            if offsets is not None:
                start, end = offsets
                for i in range(start, end):
                    char_states[i].tokens.append(token_ix)
        for char_ix, anno_ix in enumerate(annotation_map):
            char_states[char_ix].anno_ix = anno_ix

        return char_states


def HTMLBody(children: List[str], css_styles=css) -> str:
    """
    Generates the full html with css from a list of html spans

    Args:
        children (:obj:`List[str]`):
            A list of strings, assumed to be html elements

        css_styles (:obj:`str`, `optional`):
            Optional alternative implementation of the css

    Returns:
        :obj:`str`: An HTML string with style markup
    """
    children_text = "".join(children)
    return f"""
    <html>
        <head>
            <style>
                {css_styles}
            </style>
        </head>
        <body>
            <div class="tokenized-text" dir=auto>
            {children_text}
            </div>
        </body>
    </html>
    """
