from __future__ import annotations

import json
import logging
import os
import shutil
from collections.abc import Iterable
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable

import numpy as np
import torch
import transformers
from packaging import version
from torch import Tensor, nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
from tqdm.autonotebook import trange
from transformers import TrainerCallback, TrainerControl, TrainerState

from sentence_transformers.datasets.NoDuplicatesDataLoader import NoDuplicatesDataLoader
from sentence_transformers.datasets.SentenceLabelDataset import SentenceLabelDataset
from sentence_transformers.training_args import (
    BatchSamplers,
    MultiDatasetBatchSamplers,
    SentenceTransformerTrainingArguments,
)
from sentence_transformers.util import batch_to_device, fullname, is_datasets_available

from .evaluation import SentenceEvaluator
from .model_card_templates import ModelCardTemplate

if is_datasets_available():
    from datasets import Dataset, DatasetDict

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    from sentence_transformers.readers.InputExample import InputExample
    from sentence_transformers.SentenceTransformer import SentenceTransformer


class SaveModelCallback(TrainerCallback):
    """A Callback to save the model to the `output_dir`.

    There are two cases:
    1. save_best_model is True and evaluator is defined:
        We save on evaluate, but only if the new model is better than the currently saved one
        according to the evaluator.
    2. If evaluator is not defined:
        We save after the model has been trained.
    """

    def __init__(self, output_dir: str, evaluator: SentenceEvaluator | None, save_best_model: bool) -> None:
        super().__init__()
        self.output_dir = output_dir
        self.evaluator = evaluator
        self.save_best_model = save_best_model
        self.best_metric = None

    def is_better(self, new_metric: float) -> bool:
        if getattr(self.evaluator, "greater_is_better", True):
            return new_metric > self.best_metric
        return new_metric < self.best_metric

    def on_evaluate(
        self,
        args: SentenceTransformerTrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        metrics: dict[str, Any],
        model: SentenceTransformer,
        **kwargs,
    ) -> None:
        if self.evaluator is not None and self.save_best_model:
            metric_key = getattr(self.evaluator, "primary_metric", "evaluator")
            for key, value in metrics.items():
                if key.endswith(metric_key):
                    if self.best_metric is None or self.is_better(value):
                        self.best_metric = value
                        model.save(self.output_dir)

    def on_train_end(
        self,
        args: SentenceTransformerTrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        model: SentenceTransformer,
        **kwargs,
    ) -> None:
        if self.evaluator is None:
            model.save(self.output_dir)


class EvaluatorCallback(TrainerCallback):
    """The SentenceTransformers.fit method always ran the evaluator on every epoch,
    in addition to every "evaluation_steps". This callback is responsible for that.

    The `.trainer` must be provided after the trainer has been created.
    """

    def __init__(self, evaluator: SentenceEvaluator) -> None:
        super().__init__()
        self.evaluator = evaluator
        self.metric_key_prefix = "eval"
        self.trainer = None

    def on_epoch_end(
        self,
        args: SentenceTransformerTrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        model: SentenceTransformer,
        **kwargs,
    ) -> None:
        evaluator_metrics = self.evaluator(model, epoch=state.epoch)
        if not isinstance(evaluator_metrics, dict):
            evaluator_metrics = {"evaluator": evaluator_metrics}

        # Prefix all keys with metric_key_prefix + '_'
        for key in list(evaluator_metrics.keys()):
            if not key.startswith(f"{self.metric_key_prefix}_"):
                evaluator_metrics[f"{self.metric_key_prefix}_{key}"] = evaluator_metrics.pop(key)

        if self.trainer is not None:
            self.trainer.callback_handler.on_evaluate(args, state, control, metrics=evaluator_metrics)


class OriginalCallback(TrainerCallback):
    """A Callback to invoke the original callback function that was provided to SentenceTransformer.fit()

    This callback has the following signature: `(score: float, epoch: int, steps: int) -> None`
    """

    def __init__(self, callback: Callable[[float, int, int], None], evaluator: SentenceEvaluator) -> None:
        super().__init__()
        self.callback = callback
        self.evaluator = evaluator

    def on_evaluate(
        self,
        args: transformers.TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        metrics: dict[str, Any],
        **kwargs,
    ) -> None:
        metric_key = getattr(self.evaluator, "primary_metric", "evaluator")
        for key, value in metrics.items():
            if key.endswith(metric_key):
                return self.callback(value, state.epoch, state.global_step)


class FitMixin:
    """Mixin class for injecting the `fit` method into Sentence Transformers"""

    def fit(
        self,
        train_objectives: Iterable[tuple[DataLoader, nn.Module]],
        evaluator: SentenceEvaluator = None,
        epochs: int = 1,
        steps_per_epoch=None,
        scheduler: str = "WarmupLinear",
        warmup_steps: int = 10000,
        optimizer_class: type[Optimizer] = torch.optim.AdamW,
        optimizer_params: dict[str, object] = {"lr": 2e-5},
        weight_decay: float = 0.01,
        evaluation_steps: int = 0,
        output_path: str = None,
        save_best_model: bool = True,
        max_grad_norm: float = 1,
        use_amp: bool = False,
        callback: Callable[[float, int, int], None] = None,
        show_progress_bar: bool = True,
        checkpoint_path: str = None,
        checkpoint_save_steps: int = 500,
        checkpoint_save_total_limit: int = 0,
    ) -> None:
        """
        Deprecated training method from before Sentence Transformers v3.0, it is recommended to use
        :class:`~sentence_transformers.trainer.SentenceTransformerTrainer` instead. This method uses
        :class:`~sentence_transformers.trainer.SentenceTransformerTrainer` behind the scenes, but does
        not provide as much flexibility as the Trainer itself.

        This training approach uses a list of DataLoaders and Loss functions to train the model. Each DataLoader
        is sampled in turn for one batch. We sample only as many batches from each DataLoader as there are in the
        smallest one to make sure of equal training with each dataset, i.e. round robin sampling.

        This method should produce equivalent results in v3.0+ as before v3.0, but if you encounter any issues
        with your existing training scripts, then you may wish to use
        :meth:`SentenceTransformer.old_fit <sentence_transformers.SentenceTransformer.old_fit>` instead.
        That uses the old training method from before v3.0.

        Args:
            train_objectives: Tuples of (DataLoader, LossFunction). Pass
                more than one for multi-task learning
            evaluator: An evaluator (sentence_transformers.evaluation)
                evaluates the model performance during training on held-
                out dev data. It is used to determine the best model
                that is saved to disc.
            epochs: Number of epochs for training
            steps_per_epoch: Number of training steps per epoch. If set
                to None (default), one epoch is equal the DataLoader
                size from train_objectives.
            scheduler: Learning rate scheduler. Available schedulers:
                constantlr, warmupconstant, warmuplinear, warmupcosine,
                warmupcosinewithhardrestarts
            warmup_steps: Behavior depends on the scheduler. For
                WarmupLinear (default), the learning rate is increased
                from o up to the maximal learning rate. After these many
                training steps, the learning rate is decreased linearly
                back to zero.
            optimizer_class: Optimizer
            optimizer_params: Optimizer parameters
            weight_decay: Weight decay for model parameters
            evaluation_steps: If > 0, evaluate the model using evaluator
                after each number of training steps
            output_path: Storage path for the model and evaluation files
            save_best_model: If true, the best model (according to
                evaluator) is stored at output_path
            max_grad_norm: Used for gradient normalization.
            use_amp: Use Automatic Mixed Precision (AMP). Only for
                Pytorch >= 1.6.0
            callback: Callback function that is invoked after each
                evaluation. It must accept the following three
                parameters in this order: `score`, `epoch`, `steps`
            show_progress_bar: If True, output a tqdm progress bar
            checkpoint_path: Folder to save checkpoints during training
            checkpoint_save_steps: Will save a checkpoint after so many
                steps
            checkpoint_save_total_limit: Total number of checkpoints to
                store
        """
        if not is_datasets_available():
            raise ImportError("Please install `datasets` to use this function: `pip install datasets`.")

        # Delayed import to counter the SentenceTransformers -> FitMixin -> SentenceTransformerTrainer -> SentenceTransformers circular import
        from sentence_transformers.trainer import SentenceTransformerTrainer

        data_loaders, loss_fns = zip(*train_objectives)

        # Clear the dataloaders from collate functions as we just want raw InputExamples
        def identity(batch):
            return batch

        for data_loader in data_loaders:
            data_loader.collate_fn = identity

        batch_size = 8
        batch_sampler = BatchSamplers.BATCH_SAMPLER
        # Convert dataloaders into a DatasetDict
        # TODO: This is rather inefficient, as we load all data into memory. We might benefit from a more efficient solution
        train_dataset_dict = {}
        for loader_idx, data_loader in enumerate(data_loaders, start=1):
            if isinstance(data_loader, NoDuplicatesDataLoader):
                batch_sampler = BatchSamplers.NO_DUPLICATES
            elif hasattr(data_loader, "dataset") and isinstance(data_loader.dataset, SentenceLabelDataset):
                batch_sampler = BatchSamplers.GROUP_BY_LABEL

            batch_size = getattr(data_loader, "batch_size", batch_size)
            texts = []
            labels = []
            for batch in data_loader:
                batch_texts, batch_labels = zip(*[(example.texts, example.label) for example in batch])
                texts += batch_texts
                labels += batch_labels
            dataset = Dataset.from_dict({f"sentence_{idx}": text for idx, text in enumerate(zip(*texts))})
            # Add label column, unless all labels are 0 (the default value for `labels` in InputExample)
            add_label_column = True
            try:
                if set(labels) == {0}:
                    add_label_column = False
            except TypeError:
                pass
            if add_label_column:
                dataset = dataset.add_column("label", labels)
            train_dataset_dict[f"_dataset_{loader_idx}"] = dataset

        train_dataset_dict = DatasetDict(train_dataset_dict)

        def _default_checkpoint_dir() -> str:
            dir_name = "checkpoints/model"
            idx = 1
            while Path(dir_name).exists() and len(list(Path(dir_name).iterdir())) != 0:
                dir_name = f"checkpoints/model_{idx}"
                idx += 1
            return dir_name

        # Convert loss_fns into a dict with `dataset_{idx}` keys
        loss_fn_dict = {f"_dataset_{idx}": loss_fn for idx, loss_fn in enumerate(loss_fns, start=1)}

        # Use steps_per_epoch to perhaps set max_steps
        max_steps = -1
        if steps_per_epoch is not None and steps_per_epoch > 0:
            if epochs == 1:
                max_steps = steps_per_epoch
            else:
                logger.warning(
                    "Setting `steps_per_epoch` alongside `epochs` > 1 no longer works. "
                    "We will train with the full datasets per epoch."
                )
                steps_per_epoch = None

        # Transformers renamed `evaluation_strategy` to `eval_strategy` in v4.41.0
        eval_strategy_key = (
            "eval_strategy"
            if version.parse(transformers.__version__) >= version.parse("4.41.0")
            else "evaluation_strategy"
        )
        args = SentenceTransformerTrainingArguments(
            output_dir=checkpoint_path or _default_checkpoint_dir(),
            batch_sampler=batch_sampler,
            multi_dataset_batch_sampler=MultiDatasetBatchSamplers.ROUND_ROBIN,
            per_device_train_batch_size=batch_size,
            per_device_eval_batch_size=batch_size,
            num_train_epochs=epochs,
            max_steps=max_steps,
            **{
                eval_strategy_key: "steps" if evaluation_steps is not None and evaluation_steps > 0 else "no",
            },
            eval_steps=evaluation_steps,
            # load_best_model_at_end=save_best_model, # <- TODO: Look into a good solution for save_best_model
            max_grad_norm=max_grad_norm,
            fp16=use_amp,
            disable_tqdm=not show_progress_bar,
            save_strategy="steps" if checkpoint_path is not None else "no",
            save_steps=checkpoint_save_steps,
            save_total_limit=checkpoint_save_total_limit,
        )

        if steps_per_epoch is None or steps_per_epoch == 0:
            steps_per_epoch = min([len(train_dataset) // batch_size for train_dataset in train_dataset_dict.values()])
        num_train_steps = int(steps_per_epoch * epochs)

        # Prepare optimizer & scheduler
        param_optimizer = list(self.named_parameters())

        no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
                "weight_decay": weight_decay,
            },
            {"params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
        ]

        optimizer = optimizer_class(optimizer_grouped_parameters, **optimizer_params)
        scheduler_obj = self._get_scheduler(
            optimizer, scheduler=scheduler, warmup_steps=warmup_steps, t_total=num_train_steps
        )

        # Create callbacks
        callbacks = []
        if evaluator is not None:
            callbacks.append(EvaluatorCallback(evaluator))
            if callback is not None:
                callbacks.append(OriginalCallback(callback, evaluator))

        trainer = SentenceTransformerTrainer(
            model=self,
            args=args,
            train_dataset=train_dataset_dict,
            eval_dataset=None,
            loss=loss_fn_dict,
            evaluator=evaluator,
            optimizers=(optimizer, scheduler_obj),
            callbacks=callbacks,
        )
        # Set the trainer on the EvaluatorCallback, required for logging the metrics
        for callback in trainer.callback_handler.callbacks:
            if isinstance(callback, EvaluatorCallback):
                callback.trainer = trainer

        if output_path is not None:
            trainer.add_callback(SaveModelCallback(output_path, evaluator, save_best_model))

        trainer.train()

    @staticmethod
    def _get_scheduler(optimizer, scheduler: str, warmup_steps: int, t_total: int) -> LambdaLR:
        """
        Returns the correct learning rate scheduler. Available scheduler:

        - constantlr,
        - warmupconstant,
        - warmuplinear,
        - warmupcosine,
        - warmupcosinewithhardrestarts
        """
        scheduler = scheduler.lower()
        if scheduler == "constantlr":
            return transformers.get_constant_schedule(optimizer)
        elif scheduler == "warmupconstant":
            return transformers.get_constant_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps)
        elif scheduler == "warmuplinear":
            return transformers.get_linear_schedule_with_warmup(
                optimizer, num_warmup_steps=warmup_steps, num_training_steps=t_total
            )
        elif scheduler == "warmupcosine":
            return transformers.get_cosine_schedule_with_warmup(
                optimizer, num_warmup_steps=warmup_steps, num_training_steps=t_total
            )
        elif scheduler == "warmupcosinewithhardrestarts":
            return transformers.get_cosine_with_hard_restarts_schedule_with_warmup(
                optimizer, num_warmup_steps=warmup_steps, num_training_steps=t_total
            )
        else:
            raise ValueError(f"Unknown scheduler {scheduler}")

    def smart_batching_collate(self, batch: list[InputExample]) -> tuple[list[dict[str, Tensor]], Tensor]:
        """
        Transforms a batch from a SmartBatchingDataset to a batch of tensors for the model
        Here, batch is a list of InputExample instances: [InputExample(...), ...]

        Args:
            batch: a batch from a SmartBatchingDataset

        Returns:
            a batch of tensors for the model
        """
        texts = [example.texts for example in batch]
        sentence_features = [self.tokenize(sentence) for sentence in zip(*texts)]
        labels = [example.label for example in batch]

        # Use torch.from_numpy to convert the numpy array directly to a tensor,
        # which is the recommended approach for converting numpy arrays to tensors
        if labels and isinstance(labels[0], np.ndarray):
            labels_tensor = torch.from_numpy(np.stack(labels))
        else:
            labels_tensor = torch.tensor(labels)

        return sentence_features, labels_tensor

    """
    Temporary methods that will be removed when this refactor is complete:
    """

    def old_fit(
        self,
        train_objectives: Iterable[tuple[DataLoader, nn.Module]],
        evaluator: SentenceEvaluator = None,
        epochs: int = 1,
        steps_per_epoch=None,
        scheduler: str = "WarmupLinear",
        warmup_steps: int = 10000,
        optimizer_class: type[Optimizer] = torch.optim.AdamW,
        optimizer_params: dict[str, object] = {"lr": 2e-5},
        weight_decay: float = 0.01,
        evaluation_steps: int = 0,
        output_path: str = None,
        save_best_model: bool = True,
        max_grad_norm: float = 1,
        use_amp: bool = False,
        callback: Callable[[float, int, int], None] = None,
        show_progress_bar: bool = True,
        checkpoint_path: str = None,
        checkpoint_save_steps: int = 500,
        checkpoint_save_total_limit: int = 0,
    ) -> None:
        """
        Deprecated training method from before Sentence Transformers v3.0, it is recommended to use
        :class:`sentence_transformers.trainer.SentenceTransformerTrainer` instead. This method should
        only be used if you encounter issues with your existing training scripts after upgrading to v3.0+.

        This training approach uses a list of DataLoaders and Loss functions to train the model. Each DataLoader
        is sampled in turn for one batch. We sample only as many batches from each DataLoader as there are in the
        smallest one to make sure of equal training with each dataset, i.e. round robin sampling.

        Args:
            train_objectives: Tuples of (DataLoader, LossFunction). Pass
                more than one for multi-task learning
            evaluator: An evaluator (sentence_transformers.evaluation)
                evaluates the model performance during training on held-
                out dev data. It is used to determine the best model
                that is saved to disc.
            epochs: Number of epochs for training
            steps_per_epoch: Number of training steps per epoch. If set
                to None (default), one epoch is equal the DataLoader
                size from train_objectives.
            scheduler: Learning rate scheduler. Available schedulers:
                constantlr, warmupconstant, warmuplinear, warmupcosine,
                warmupcosinewithhardrestarts
            warmup_steps: Behavior depends on the scheduler. For
                WarmupLinear (default), the learning rate is increased
                from o up to the maximal learning rate. After these many
                training steps, the learning rate is decreased linearly
                back to zero.
            optimizer_class: Optimizer
            optimizer_params: Optimizer parameters
            weight_decay: Weight decay for model parameters
            evaluation_steps: If > 0, evaluate the model using evaluator
                after each number of training steps
            output_path: Storage path for the model and evaluation files
            save_best_model: If true, the best model (according to
                evaluator) is stored at output_path
            max_grad_norm: Used for gradient normalization.
            use_amp: Use Automatic Mixed Precision (AMP). Only for
                Pytorch >= 1.6.0
            callback: Callback function that is invoked after each
                evaluation. It must accept the following three
                parameters in this order: `score`, `epoch`, `steps`
            show_progress_bar: If True, output a tqdm progress bar
            checkpoint_path: Folder to save checkpoints during training
            checkpoint_save_steps: Will save a checkpoint after so many
                steps
            checkpoint_save_total_limit: Total number of checkpoints to
                store
        """

        ##Add info to model card
        # info_loss_functions = "\n".join(["- {} with {} training examples".format(str(loss), len(dataloader)) for dataloader, loss in train_objectives])
        info_loss_functions = []
        for dataloader, loss in train_objectives:
            info_loss_functions.extend(ModelCardTemplate.get_train_objective_info(dataloader, loss))
        info_loss_functions = "\n\n".join([text for text in info_loss_functions])

        info_fit_parameters = json.dumps(
            {
                "evaluator": fullname(evaluator),
                "epochs": epochs,
                "steps_per_epoch": steps_per_epoch,
                "scheduler": scheduler,
                "warmup_steps": warmup_steps,
                "optimizer_class": str(optimizer_class),
                "optimizer_params": optimizer_params,
                "weight_decay": weight_decay,
                "evaluation_steps": evaluation_steps,
                "max_grad_norm": max_grad_norm,
            },
            indent=4,
            sort_keys=True,
        )
        self._model_card_text = None
        self._model_card_vars["{TRAINING_SECTION}"] = ModelCardTemplate.__TRAINING_SECTION__.replace(
            "{LOSS_FUNCTIONS}", info_loss_functions
        ).replace("{FIT_PARAMETERS}", info_fit_parameters)

        if use_amp:
            from torch.cuda.amp import autocast

            scaler = torch.cuda.amp.GradScaler()

        self.to(self.device)

        dataloaders = [dataloader for dataloader, _ in train_objectives]

        # Use smart batching
        for dataloader in dataloaders:
            dataloader.collate_fn = self.smart_batching_collate

        loss_models = [loss for _, loss in train_objectives]
        for loss_model in loss_models:
            loss_model.to(self.device)

        self.best_score = -9999999

        if steps_per_epoch is None or steps_per_epoch == 0:
            steps_per_epoch = min([len(dataloader) for dataloader in dataloaders])

        num_train_steps = int(steps_per_epoch * epochs)

        # Prepare optimizers
        optimizers = []
        schedulers = []
        for loss_model in loss_models:
            param_optimizer = list(loss_model.named_parameters())

            no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
            optimizer_grouped_parameters = [
                {
                    "params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
                    "weight_decay": weight_decay,
                },
                {"params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
            ]

            optimizer = optimizer_class(optimizer_grouped_parameters, **optimizer_params)
            scheduler_obj = self._get_scheduler(
                optimizer, scheduler=scheduler, warmup_steps=warmup_steps, t_total=num_train_steps
            )

            optimizers.append(optimizer)
            schedulers.append(scheduler_obj)

        global_step = 0
        data_iterators = [iter(dataloader) for dataloader in dataloaders]

        num_train_objectives = len(train_objectives)

        skip_scheduler = False
        for epoch in trange(epochs, desc="Epoch", disable=not show_progress_bar):
            training_steps = 0

            for loss_model in loss_models:
                loss_model.zero_grad()
                loss_model.train()

            for _ in trange(steps_per_epoch, desc="Iteration", smoothing=0.05, disable=not show_progress_bar):
                for train_idx in range(num_train_objectives):
                    loss_model = loss_models[train_idx]
                    optimizer = optimizers[train_idx]
                    scheduler = schedulers[train_idx]
                    data_iterator = data_iterators[train_idx]

                    try:
                        data = next(data_iterator)
                    except StopIteration:
                        data_iterator = iter(dataloaders[train_idx])
                        data_iterators[train_idx] = data_iterator
                        data = next(data_iterator)

                    features, labels = data
                    labels = labels.to(self.device)
                    features = list(map(lambda batch: batch_to_device(batch, self.device), features))

                    if use_amp:
                        with autocast():
                            loss_value = loss_model(features, labels)

                        scale_before_step = scaler.get_scale()
                        scaler.scale(loss_value).backward()
                        scaler.unscale_(optimizer)
                        torch.nn.utils.clip_grad_norm_(loss_model.parameters(), max_grad_norm)
                        scaler.step(optimizer)
                        scaler.update()

                        skip_scheduler = scaler.get_scale() != scale_before_step
                    else:
                        loss_value = loss_model(features, labels)
                        loss_value.backward()
                        torch.nn.utils.clip_grad_norm_(loss_model.parameters(), max_grad_norm)
                        optimizer.step()

                    optimizer.zero_grad()

                    if not skip_scheduler:
                        scheduler.step()

                training_steps += 1
                global_step += 1

                if evaluation_steps > 0 and training_steps % evaluation_steps == 0:
                    self._eval_during_training(
                        evaluator, output_path, save_best_model, epoch, training_steps, callback
                    )

                    for loss_model in loss_models:
                        loss_model.zero_grad()
                        loss_model.train()

                if (
                    checkpoint_path is not None
                    and checkpoint_save_steps is not None
                    and checkpoint_save_steps > 0
                    and global_step % checkpoint_save_steps == 0
                ):
                    self._save_checkpoint(checkpoint_path, checkpoint_save_total_limit, global_step)

            self._eval_during_training(evaluator, output_path, save_best_model, epoch, -1, callback)

        if evaluator is None and output_path is not None:  # No evaluator, but output path: save final model version
            self.save(output_path)

        if checkpoint_path is not None:
            self._save_checkpoint(checkpoint_path, checkpoint_save_total_limit, global_step)

    def _eval_during_training(self, evaluator, output_path, save_best_model, epoch, steps, callback) -> None:
        """Runs evaluation during the training"""
        eval_path = output_path
        if output_path is not None:
            os.makedirs(output_path, exist_ok=True)
            eval_path = os.path.join(output_path, "eval")
            os.makedirs(eval_path, exist_ok=True)

        if evaluator is not None:
            score = evaluator(self, output_path=eval_path, epoch=epoch, steps=steps)
            if callback is not None:
                callback(score, epoch, steps)
            if score > self.best_score:
                self.best_score = score
                if save_best_model:
                    self.save(output_path)

    def _save_checkpoint(self, checkpoint_path, checkpoint_save_total_limit, step) -> None:
        # Store new checkpoint
        self.save(os.path.join(checkpoint_path, str(step)))

        # Delete old checkpoints
        if checkpoint_save_total_limit is not None and checkpoint_save_total_limit > 0:
            old_checkpoints = []
            for subdir in os.listdir(checkpoint_path):
                if subdir.isdigit():
                    old_checkpoints.append({"step": int(subdir), "path": os.path.join(checkpoint_path, subdir)})

            if len(old_checkpoints) > checkpoint_save_total_limit:
                old_checkpoints = sorted(old_checkpoints, key=lambda x: x["step"])
                shutil.rmtree(old_checkpoints[0]["path"])
