# Copyright 2023 Iguazio
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from inspect import signature
from typing import Any, Dict, List, Optional, Union
import mlrun

try:
    import mlrun.model_monitoring.api
except ModuleNotFoundError:
    raise mlrun.errors.MLRunNotFoundError(
        "Please update your `mlrun` version to >=1.5.0 or use an "
        "older version of the batch inference function."
    )

import numpy as np
import pandas as pd
from mlrun.frameworks.auto_mlrun import AutoMLRun


def _prepare_result_set(
    x: pd.DataFrame, label_columns: List[str], y_pred: np.ndarray
) -> pd.DataFrame:
    """
    Set default label column names and validate given names to prepare the result set - a concatenation of the inputs
    (x) and the model predictions (y_pred).

    :param x:             The inputs.
    :param label_columns: A list of strings representing the target column names to add to the predictions. Default
                          names will be used if the list is empty: `predicted_label` for a single output, or
                          `predicted_label_{i}` for multiple outputs.
    :param y_pred:        The model predictions on the inputs.

    :returns: The result set.

    :raises MLRunInvalidArgumentError: If the number of label columns does not match the number of model outputs, or
                                       if one of the label columns already exists in the dataset.
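
    Example (a minimal sketch; the feature values and predictions below are illustrative only)::

        x = pd.DataFrame({"feature_0": [1.0, 2.0]})
        y_pred = np.array([0, 1])
        result_set = _prepare_result_set(x=x, label_columns=[], y_pred=y_pred)
        # result_set now holds the columns ["feature_0", "predicted_label"]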
    """
    # Prepare default target column names if not provided:
    prediction_columns_amount = 1 if len(y_pred.shape) == 1 else y_pred.shape[1]
    if len(label_columns) == 0:
        # Add default label column names:
        if prediction_columns_amount == 1:
            label_columns = ["predicted_label"]
        else:
            label_columns = [
                f"predicted_label_{i}" for i in range(prediction_columns_amount)
            ]

    # Validate the label columns:
    if prediction_columns_amount != len(label_columns):
        # The number of given label columns does not match the number of outputs:
        raise mlrun.errors.MLRunInvalidArgumentError(
            f"The number of predicted labels ({prediction_columns_amount}) "
            f"does not match the number of given label columns ({len(label_columns)})."
        )
    common_labels = set(label_columns) & set(x.columns.tolist())
    if common_labels:
        # A label column already exists in the original inputs:
        raise mlrun.errors.MLRunInvalidArgumentError(
            f"The labels {common_labels} already exist in the given dataset."
        )

    return pd.concat(
        [x, pd.DataFrame(y_pred, columns=label_columns, index=x.index)], axis=1
    )


def _get_sample_set_statistics_parameters(
    context: mlrun.MLClientCtx,
    model_endpoint_sample_set: Union[
        mlrun.DataItem, list, dict, pd.DataFrame, pd.Series, np.ndarray
    ],
    model_artifact_feature_stats: dict,
    feature_columns: Optional[List],
    drop_columns: Optional[List],
    label_columns: Optional[List],
) -> Dict[str, Any]:
    """
    Filter the parameters passed to `mlrun.model_monitoring.api.get_sample_set_statistics` according to the signature
    of the installed `mlrun` version, so that only parameters the function expects are sent.
    """
    statistics_input_full_dict = dict(
        sample_set=model_endpoint_sample_set,
        model_artifact_feature_stats=model_artifact_feature_stats,
        sample_set_columns=feature_columns,
        sample_set_drop_columns=drop_columns,
        sample_set_label_columns=label_columns,
    )
    statistics_function_parameters = signature(
        mlrun.model_monitoring.api.get_sample_set_statistics
    ).parameters
    # The input parameters of `get_sample_set_statistics` have changed across `mlrun` versions,
    # so send only the parameters the installed version expects:
    statistics_input_filtered = {
        key: value
        for key, value in statistics_input_full_dict.items()
        if key in statistics_function_parameters
    }
    if len(statistics_input_filtered) != len(statistics_input_full_dict):
        context.logger.warning(
            "`get_sample_set_statistics` is from an older `mlrun` version; "
            "some parameters will not be sent to the function. "
            f"Prepared parameters: {list(statistics_input_full_dict.keys())}, "
            f"parameters that will be sent: {list(statistics_input_filtered.keys())}"
        )
    return statistics_input_filtered
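
# A minimal, self-contained sketch of the signature-filtering pattern used above, in case the
# technique is unfamiliar (the `target` function and keyword arguments below are hypothetical,
# for illustration only):
#
#     from inspect import signature
#
#     def target(a, b):
#         return a + b
#
#     kwargs = {"a": 1, "b": 2, "c": 3}
#     accepted = signature(target).parameters
#     filtered = {key: value for key, value in kwargs.items() if key in accepted}
#     target(**filtered)  # "c" is dropped, so this returns 3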


def infer(
        context: mlrun.MLClientCtx,
        dataset: Union[mlrun.DataItem, list, dict, pd.DataFrame, pd.Series, np.ndarray],
        model_path: Union[str, mlrun.DataItem],
        drop_columns: Optional[Union[str, List[str], int, List[int]]] = None,
        label_columns: Optional[Union[str, List[str]]] = None,
        feature_columns: Optional[Union[str, List[str]]] = None,
        log_result_set: bool = True,
        result_set_name: str = "prediction",
        batch_id: Optional[str] = None,
        artifacts_tag: str = "",
        # Drift analysis parameters
        perform_drift_analysis: Optional[bool] = None,
        endpoint_id: str = "",
        # The following model endpoint parameters are relevant only if drift analysis
        # is not disabled and a new model endpoint record is going to be generated:
        model_endpoint_name: str = "batch-infer",
        model_endpoint_sample_set: Optional[
            Union[mlrun.DataItem, list, dict, pd.DataFrame, pd.Series, np.ndarray]
        ] = None,

        # the following parameters are deprecated and will be removed once the versioning mechanism is implemented
        # TODO: Remove the following parameters once FHUB-13 is resolved
        trigger_monitoring_job: Optional[bool] = None,
        batch_image_job: Optional[str] = None,
        model_endpoint_drift_threshold: Optional[float] = None,
        model_endpoint_possible_drift_threshold: Optional[float] = None,

        # prediction kwargs to pass to the model predict function
        **predict_kwargs: Any,

):
    """
    Perform a prediction on the provided dataset using the specified model.
    Ensure that the model has already been logged under the current project.

    If you wish to apply monitoring tools (e.g., drift analysis), set the perform_drift_analysis parameter to True.
    This will create a new model endpoint record under the specified model_endpoint_name.
    Additionally, ensure that model monitoring is enabled at the project level by calling the
    project.enable_model_monitoring() function. You can also apply monitoring to an existing model by providing its
    endpoint id or name, and the monitoring tools will be applied to that endpoint.

    At the moment, this function is supported only for `mlrun>=1.5.0`.

    :param context:                                 MLRun context.
    :param dataset:                                 The dataset to infer through the model. Provided as an input (DataItem)
                                                    that represents Dataset artifact / Feature vector URI.
                                                    If using MLRun SDK, `dataset` can also be provided as a list, dictionary or
                                                    numpy array.
    :param model_path:                              Model store uri (should start with store://). Provided as an input (DataItem).
                                                    If using MLRun SDK, `model_path` can also be provided as a parameter (string).
                                                    To generate a valid model store URI, please log the model before running this function.
                                                    If `endpoint_id` of existing model endpoint is provided, make sure
                                                    that it has a similar model store path, otherwise the drift analysis
                                                    won't be triggered.
    :param drop_columns:                            A string / integer or a list of strings / integers that represent the column names
                                                    / indices to drop. When the dataset is a list or a numpy array this parameter must
                                                    be represented by integers.
    :param label_columns:                           The column name(s) of the target label(s) in the dataset, for
                                                    regression or classification tasks. The label columns can be accessed
                                                    from the model object or from the feature vector if available.
    :param feature_columns:                         List of feature columns that will be used to build the dataframe
                                                    when the dataset is of type list or numpy array.
    :param log_result_set:                          Whether to log the result set - a DataFrame of the given inputs concatenated with
                                                    the predictions. Defaults to True.
    :param result_set_name:                         The db key of the prediction result set and of the filename. Defaults to
                                                    'prediction'.
    :param batch_id:                                The ID of the given batch (inference dataset). If `None`, it will be generated.
                                                    Will be logged as a result of the run.
    :param artifacts_tag:                           Tag to use for prediction set result artifact.
    :param perform_drift_analysis:                  Whether to perform drift analysis between the sample set of the model object and
                                                    the given dataset. Defaults to None, meaning drift analysis will be performed if
                                                    the model already has feature stats that can serve as a reference sample set.
                                                    Performing drift analysis on a new endpoint id will generate a new model endpoint
                                                    record.
    :param endpoint_id:                             Model endpoint unique ID. If `perform_drift_analysis` was set, the endpoint_id
                                                    will be used either to perform the analysis on an existing model endpoint or to
                                                    generate a new model endpoint record.
    :param model_endpoint_name:                     If a new model endpoint is generated, the model name will be presented under this
                                                    endpoint.
    :param model_endpoint_sample_set:               A sample dataset to compare the inputs against during drift analysis.
                                                    Can be provided as an input (DataItem) or as a parameter (e.g. string, list, DataFrame).
                                                    By default, the sample set logged in the model artifact itself is used.
    :param trigger_monitoring_job:                  Whether to trigger the batch drift analysis after the infer job.
    :param batch_image_job:                         The image that will be used to register the monitoring batch job if it does not
                                                    exist. Defaults to mlrun/mlrun.
    :param model_endpoint_drift_threshold:          The threshold at which to mark drift. Defaults to 0.7.
    :param model_endpoint_possible_drift_threshold: The threshold at which to mark possible drift. Defaults to 0.5.

    :raises MLRunInvalidArgumentError: If neither `model_path` nor `endpoint_id` is provided.
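
    Example (a minimal sketch of running this function through the MLRun SDK; the project name and the
    dataset / model store URIs below are placeholders)::

        import mlrun

        project = mlrun.get_or_create_project("my-project")
        fn = mlrun.import_function("hub://batch_inference_v2")
        run = project.run_function(
            fn,
            inputs={"dataset": "store://datasets/my-project/my-dataset"},
            params={
                "model_path": "store://models/my-project/my-model:latest",
                "label_columns": "label",
                "perform_drift_analysis": True,
            },
        )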
    """

    deprecated_params = {
        "trigger_monitoring_job": trigger_monitoring_job,
        "batch_image_job": batch_image_job,
        "model_endpoint_drift_threshold": model_endpoint_drift_threshold,
        "model_endpoint_possible_drift_threshold": model_endpoint_possible_drift_threshold,
    }
    for param_name, param_value in deprecated_params.items():
        if param_value:
            context.logger.warning(
                f"The `{param_name}` parameter is deprecated and will be removed once the versioning "
                "mechanism is implemented. If you are using mlrun<1.7.0, please import the previous "
                "version of this function, for example 'hub://batch_inference_v2:2.5.0'."
            )

    # Loading the model:
    context.logger.info(f"Loading model...")
    if isinstance(model_path, mlrun.DataItem):
        model_path = model_path.artifact_url
    if not mlrun.datastore.is_store_uri(model_path):
        raise mlrun.errors.MLRunInvalidArgumentError(
            f"The provided model path ({model_path}) is invalid - should start with `store://`. "
            f"Please make sure that you have logged the model using `project.log_model()` "
            f"which generates a unique store uri for the logged model."
        )
    model_handler = AutoMLRun.load_model(model_path=model_path, context=context)

    if label_columns is None:
        label_columns = [
            output.name for output in model_handler._model_artifact.spec.outputs
        ]

    if feature_columns is None:
        feature_columns = [
            model_input.name for model_input in model_handler._model_artifact.spec.inputs
        ]

    # Get dataset by object, URL or by FeatureVector:
    context.logger.info(f"Loading data...")
    x, label_columns = mlrun.model_monitoring.api.read_dataset_as_dataframe(
        dataset=dataset,
        feature_columns=feature_columns,
        label_columns=label_columns,
        drop_columns=drop_columns,
    )

    # Predict:
    context.logger.info("Calculating prediction...")
    y_pred = model_handler.model.predict(x, **predict_kwargs)

    # Prepare the result set:
    result_set = _prepare_result_set(x=x, label_columns=label_columns, y_pred=y_pred)

    # Check for logging the result set:
    if log_result_set:
        mlrun.model_monitoring.api.log_result(
            context=context,
            result_set_name=result_set_name,
            result_set=result_set,
            artifacts_tag=artifacts_tag,
            batch_id=batch_id,
        )

    # Check for performing drift analysis
    if (
            perform_drift_analysis is None
            and model_handler._model_artifact.spec.feature_stats is not None
    ):
        perform_drift_analysis = True
    if perform_drift_analysis:
        context.logger.info("Performing drift analysis...")
        # Get the sample set statistics (either from the sample set or from the statistics logged with the model)
        statistics_input_filtered = _get_sample_set_statistics_parameters(
            context=context,
            model_endpoint_sample_set=model_endpoint_sample_set,
            model_artifact_feature_stats=model_handler._model_artifact.spec.feature_stats,
            feature_columns=feature_columns,
            drop_columns=drop_columns,
            label_columns=label_columns,
        )
        sample_set_statistics = mlrun.model_monitoring.api.get_sample_set_statistics(
            **statistics_input_filtered
        )
        mlrun.model_monitoring.api.record_results(
            project=context.project,
            context=context,
            endpoint_id=endpoint_id,
            model_path=model_path,
            model_endpoint_name=model_endpoint_name,
            infer_results_df=result_set.copy(),
            sample_set_statistics=sample_set_statistics,
        )