# Copyright 2019 Iguazio
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Generated by nuclio.export.NuclioExporter

import warnings
from typing import Union

import mlrun
import numpy as np

warnings.simplefilter(action="ignore", category=FutureWarning)

import mlrun.feature_store as fstore
import pandas as pd
import plotly.express as px
import plotly.figure_factory as ff
import plotly.graph_objects as go
from mlrun.artifacts import (
    Artifact,
    DatasetArtifact,
    PlotlyArtifact,
    TableArtifact,
    update_dataset_meta,
)
from mlrun.datastore import DataItem
from mlrun.execution import MLClientCtx
from mlrun.feature_store import FeatureSet, FeatureVector
from plotly.subplots import make_subplots

# Render floats with two decimal places whenever pandas formats them for display.
pd.set_option("display.float_format", lambda x: "%.2f" % x)
# Upper bound on DataFrame.size (rows * columns); analyze() down-samples the rows
# of any frame that exceeds it.
MAX_SIZE_OF_DF = 500000


def analyze(
    context: MLClientCtx,
    name: str = "dataset",
    table: Union[FeatureSet, DataItem] = None,
    label_column: str = None,
    plots_dest: str = "plots",
    random_state: int = 1,
    problem_type: str = "classification",
    dask_key: str = "dask_key",
    dask_function: str = None,
    dask_client=None,
) -> None:
    """
    Load a dataset, log summary artifacts for it and attach them to the dataset.

    The data frame is taken from the dask cluster when ``dask_function`` or
    ``dask_client`` is given and ``dask_key`` is published there; otherwise it
    is read from ``table``. If the frame holds more than 500,000 cells
    (rows x columns), its rows are sampled randomly so that roughly 500,000
    cells remain.

    The following artifacts are produced (each plot is best-effort: a failure
    is logged as a warning and the remaining plots are still attempted):

    describe csv
    histograms
    scatter-2d
    violin chart
    correlation-matrix chart
    correlation-matrix csv
    imbalance pie chart
    imbalance-weights-vec csv

    :param context:                 The function context
    :param name:                    Key of dataset to database ("dataset" for default)
    :param table:                   MLRun input (DataItem) pointing to a csv/parquet file, a dataset artifact,
                                    a feature vector or a feature set
    :param label_column:            Ground truth column label (ignored when it is not a column of the data)
    :param plots_dest:              Destination folder of summary plots (relative to artifact_path)
                                    ("plots" for default)
    :param random_state:            Seed used when the data is down-sampled
    :param problem_type:            The type of the ML problem the data facing - regression, classification or None
                                    (classification for default)
    :param dask_key:                Key of dataframe in dask client "datasets" attribute
    :param dask_function:           Dask function url (db://..)
    :param dask_client:             Dask client object

    :raises Exception: When a dask client is used, ``dask_key`` is not published on it and no ``table`` was given.
    """
    # data_item  - the frame is a plain dataset (as opposed to a feature set)
    # featureset - the frame has to be read through the feature store
    # creat      - a new dataset artifact must be logged at the end
    # update     - an existing dataset artifact must be updated at the end
    data_item, featureset, creat, update = False, False, False, False
    get_from_table = True
    if dask_function or dask_client:
        data_item, creat = True, True
        # dask_function takes precedence when both are supplied.
        client = (
            mlrun.import_function(dask_function).client
            if dask_function
            else dask_client
        )

        if dask_key in client.datasets:
            df = client.get_dataset(dask_key)
            get_from_table = False
        elif not table:
            context.logger.info(
                f"only these datasets are available {client.datasets} in client {client}"
            )
            raise Exception("dataset not found on dask cluster")
        # Otherwise fall back to reading the frame from `table` below.

    if get_from_table:
        if isinstance(table, DataItem):
            meta = table.meta
            if meta is None:
                data_item, creat, update = True, True, False
            elif meta.kind == "dataset":
                data_item, creat, update = True, False, True
            elif meta.kind == "FeatureVector":
                data_item, creat, update = True, False, False
            elif meta.kind == "FeatureSet":
                featureset, creat, update = True, False, False

        if data_item:
            df = table.as_df()
        elif featureset:
            path_parts = table._path.split("/")
            project_name, set_name = path_parts[2], path_parts[4]
            feature_set = fstore.get_feature_set(
                f"store://feature-sets/{project_name}/{set_name}"
            )
            df = feature_set.to_dataframe()
        else:
            context.logger.error("Wrong table type.")
            return

    if df.size > MAX_SIZE_OF_DF:
        # Cap the number of cells, not the number of rows.
        df = df.sample(n=MAX_SIZE_OF_DF // df.shape[1], random_state=random_state)
    extra_data = {}

    if label_column not in df.columns:
        label_column = None

    extra_data["describe csv"] = context.log_artifact(
        TableArtifact("describe-csv", df=df.describe()),
        local_path=f"{plots_dest}/describe.csv",
    )

    # (what is being created, creator, arguments following context/df/extra_data)
    plot_steps = (
        (
            "histogram matrix artifact",
            _create_histogram_mat_artifact,
            (label_column, plots_dest),
        ),
        (
            "pairplot histograms",
            _create_features_histogram_artifacts,
            (label_column, plots_dest, problem_type),
        ),
        (
            "pairplot 2d_scatter",
            _create_features_2d_scatter_artifacts,
            (label_column, plots_dest, problem_type),
        ),
        (
            "violin distribution plots",
            _create_violin_artifact,
            (plots_dest,),
        ),
        (
            "class imbalance plot",
            _create_imbalance_artifact,
            (label_column, plots_dest, problem_type),
        ),
        (
            "features correlation plot",
            _create_corr_artifact,
            (label_column, plots_dest),
        ),
    )
    for description, create_plot, extra_args in plot_steps:
        # Best effort: one failing plot must not prevent the others.
        try:
            create_plot(context, df, extra_data, *extra_args)
        except Exception as e:
            context.logger.warn(f"Failed to create {description} due to: {e}")

    if not data_item:
        return

    artifact = None
    if creat:  # dataset not stored
        artifact = context.log_artifact(
            DatasetArtifact(key="dataset", stats=True, df=df, extra_data=extra_data),
            db_key=name,
        )
        context.logger.info(f"The data set is logged to the project under {name} name")

    if update:
        if artifact is None:
            # Only touch `table` here: it may be None when the frame came
            # from the dask cluster.
            artifact = table.artifact_url
        update_dataset_meta(artifact, extra_data=extra_data)
        context.logger.info(f"The data set named {name} is updated")

    # TODO : 3-D plot on selected features.
    # TODO : Reintegration plot on selected features.
    # TODO : PCA plot (with options)


def _create_histogram_mat_artifact(
    context: MLClientCtx,
    df: pd.DataFrame,
    extra_data: dict,
    label_column: str,
    plots_dest: str,
):
    """
    Log the "hist" artifact, which only carries a deprecation notice.

    The notice points readers to the scatter-2d and histograms artifacts.
    ``df``, ``extra_data`` and ``label_column`` are accepted but not used.
    """
    deprecation_notice = (
        b" Deprecated, see the artifacts scatter-2d and histograms instead"
    )
    placeholder = Artifact(key="hist", body=deprecation_notice)
    target_path = f"{plots_dest}/hist.html"
    context.log_artifact(item=placeholder, local_path=target_path)


def _create_features_histogram_artifacts(
    context: MLClientCtx,
    df: pd.DataFrame,
    extra_data: dict,
    label_column: str,
    plots_dest: str,
    problem_type: str,
):
    """
    Create and log a single "histograms" figure with one histogram per feature.

    All histograms live in one figure; a drop-down menu toggles which
    feature's traces are visible. For a classification problem with a label
    column every feature gets one trace per label value, otherwise a single
    trace. The logged artifact is stored in ``extra_data["histograms"]``.

    :param context:      The function context used for logging the artifact
    :param df:           The data to plot
    :param extra_data:   Dictionary that receives the logged artifact
    :param label_column: Name of the label column, or None
    :param plots_dest:   Destination folder of the plot
    :param problem_type: "classification" enables the per-label split
    """
    split_by_label = label_column is not None and problem_type == "classification"
    feature_names = [
        column_name for column_name in df.columns if column_name != label_column
    ]
    # Title of the initially visible feature ("" when there are no features).
    first_feature_name = feature_names[0] if feature_names else ""

    # Filter the rows of every label once instead of once per feature.
    label_subsets = (
        [
            (label, df.loc[df[label_column] == label])
            for label in df[label_column].unique()
        ]
        if split_by_label
        else []
    )

    traces = []
    # Feature each trace belongs to, index-aligned with `traces`. Keeping the
    # owner next to the trace avoids parsing it back out of a string key,
    # which breaks for non-string column names.
    trace_owners = []
    for position, column_name in enumerate(feature_names):
        # Only the first feature is shown until the user picks another one.
        visible = position == 0
        if split_by_label:
            for label, subset in label_subsets:
                traces.append(
                    go.Histogram(
                        histfunc="count",
                        x=subset[column_name],
                        name=str(label),
                        visible=visible,
                    )
                )
                trace_owners.append(column_name)
        else:
            traces.append(
                go.Histogram(histfunc="count", x=df[column_name], visible=visible)
            )
            trace_owners.append(column_name)

    fig = go.Figure()
    for trace in traces:
        fig.add_trace(trace)

    buttons = [
        {
            "label": column_name,
            "method": "update",
            "args": [
                {
                    "visible": [owner == column_name for owner in trace_owners],
                    # NOTE(review): the axis range is passed in the
                    # trace-update dict; presumably it was meant as a layout
                    # update - confirm before moving it.
                    "xaxis": {
                        "range": [
                            min(df[column_name]),
                            max(df[column_name]),
                        ]
                    },
                },
                {"title": f"Histogram of {column_name}"},
            ],
        }
        for column_name in feature_names
    ]

    fig.update_layout(
        updatemenus=[
            {
                "buttons": buttons,
                "direction": "down",
                "pad": {"r": 10, "t": 10},
                "showactive": True,
                "x": 0.25,
                "xanchor": "left",
                "y": 1.1,
                "yanchor": "top",
            }
        ],
        annotations=[
            dict(
                text="Select Feature Name ",
                showarrow=False,
                x=0,
                y=1.05,
                yref="paper",
                xref="paper",
                align="left",
                xanchor="left",
                yanchor="top",
                font={
                    "color": "blue",
                },
            )
        ],
    )

    fig.update_layout(
        width=600,
        height=400,
        autosize=False,
        margin=dict(t=100, b=0, l=0, r=0),
        template="plotly_white",
    )

    fig.update_layout(title_text=f"Histograms of {first_feature_name}")
    extra_data["histograms"] = context.log_artifact(
        PlotlyArtifact(key="histograms", figure=fig),
        local_path=f"{plots_dest}/histograms.html",
    )


def _create_features_2d_scatter_artifacts(
    context: MLClientCtx,
    df: pd.DataFrame,
    extra_data: dict,
    label_column: str,
    plots_dest: str,
    problem_type: str,
):
    """
    Create and log a "scatter-2d" figure with selectable x and y features.

    The figure starts with the first feature on both axes; two drop-down
    menus replace the x and the y data with any other feature. For a
    classification problem with a label column there is one trace per label
    value; for regression with a label column the label colours the markers.
    The logged artifact is stored in ``extra_data["scatter-2d"]``.

    :param context:      The function context used for logging the artifact
    :param df:           The data to plot
    :param extra_data:   Dictionary that receives the logged artifact
    :param label_column: Name of the label column, or None
    :param plots_dest:   Destination folder of the plot
    :param problem_type: "classification", "regression" or anything else

    :raises ValueError: When ``df`` has no column other than the label column.
    """
    features = [
        column_name for column_name in df.columns if column_name != label_column
    ]
    if not features:
        raise ValueError("no feature columns available for the scatter-2d plot")
    # str() so that non-string column names (e.g. integers) are supported.
    max_feature_len = float(max(len(str(feature)) for feature in features))
    first_feature = features[0]

    split_by_label = label_column is not None and problem_type == "classification"
    fig = go.Figure()
    # Rows of each label, filtered once and reused by every trace and button.
    label_subsets = []
    if split_by_label:
        label_subsets = [
            (label, df.loc[df[label_column] == label])
            for label in sorted(df[label_column].unique())
        ]
        for label, subset in label_subsets:
            fig.add_trace(
                go.Scatter(
                    x=subset[first_feature],
                    y=subset[first_feature],
                    mode="markers",
                    visible=True,
                    showlegend=True,
                    name=str(label),
                )
            )
    elif label_column is None:
        fig.add_trace(
            go.Scatter(
                x=df[first_feature],
                y=df[first_feature],
                mode="markers",
                visible=True,
            )
        )
    elif problem_type == "regression":
        fig.add_trace(
            go.Scatter(
                x=df[first_feature],
                y=df[first_feature],
                mode="markers",
                marker=dict(
                    color=df[label_column], colorscale="Viridis", showscale=True
                ),
                visible=True,
            )
        )

    x_buttons = []
    y_buttons = []
    trace_count = len(label_subsets)

    for feature in features:
        if split_by_label:
            x_buttons.append(
                dict(
                    method="update",
                    label=feature,
                    args=[
                        {"x": [subset[feature] for _, subset in label_subsets]},
                        list(range(trace_count)),
                    ],
                )
            )
            y_buttons.append(
                dict(
                    method="update",
                    label=feature,
                    args=[
                        {"y": [subset[feature] for _, subset in label_subsets]},
                        list(range(trace_count)),
                    ],
                )
            )
        else:
            x_buttons.append(
                dict(method="update", label=feature, args=[{"x": [df[feature]]}])
            )
            y_buttons.append(
                dict(method="update", label=feature, args=[{"y": [df[feature]]}])
            )

    # Pass buttons to the updatemenus argument
    fig.update_layout(
        updatemenus=[
            dict(buttons=x_buttons, direction="up", x=0.5, y=-0.1),
            # Shift the y menu left in proportion to the longest feature name.
            dict(buttons=y_buttons, direction="down", x=-max_feature_len / 100, y=0.5),
        ]
    )

    fig.update_layout(
        width=600,
        height=400,
        autosize=False,
        margin=dict(t=100, b=0, l=0, r=0),
        template="plotly_white",
    )

    fig.update_layout(title_text="Scatter-2d")
    extra_data["scatter-2d"] = context.log_artifact(
        PlotlyArtifact(key="scatter-2d", figure=fig),
        local_path=f"{plots_dest}/scatter-2d.html",
    )


def _create_violin_artifact(
    context: MLClientCtx, df: pd.DataFrame, extra_data: dict, plots_dest: str
):
    """
    Create and log a "violin" figure with one violin plot per column.

    The plots are laid out on a grid that is five subplots wide and filled
    row by row in column order. The logged artifact is stored in
    ``extra_data["violin"]``.

    :param context:    The function context used for logging the artifact
    :param df:         The data to plot
    :param extra_data: Dictionary that receives the logged artifact
    :param plots_dest: Destination folder of the plot
    """
    cols = 5
    # Ceiling division: just enough rows for all columns (at least one), so
    # no empty row is added when the column count is a multiple of `cols`.
    rows = max(1, (df.shape[1] + cols - 1) // cols)
    fig = make_subplots(rows=rows, cols=cols)

    for plot_num, (column_name, column_data) in enumerate(df.items()):
        grid_row, grid_col = divmod(plot_num, cols)
        fig.add_trace(
            go.Violin(
                x=[column_name] * column_data.shape[0],
                y=column_data,
                name=column_name,
            ),
            # Subplot coordinates are 1-based.
            row=grid_row + 1,
            col=grid_col + 1,
        )

    fig["layout"].update(
        height=(rows + 1) * 200,
        width=(cols + 1) * 200,
        title="Violin Plots",
    )

    fig.update_layout(showlegend=False)
    extra_data["violin"] = context.log_artifact(
        PlotlyArtifact(key="violin", figure=fig),
        local_path=f"{plots_dest}/violin.html",
    )


def _create_imbalance_artifact(
    context: MLClientCtx,
    df: pd.DataFrame,
    extra_data: dict,
    label_column: str,
    plots_dest: str,
    problem_type: str,
):
    """
    Create and log the label imbalance artifacts (csv + plot).

    For classification the plot is a pie chart of the label counts and the
    table holds, per label value, its count ("Total"), the label itself and
    its share of all rows ("weights"). For any other problem type the plot is
    a histogram of the label and the table holds the numpy histogram bin
    edges ("min_val") with their counts. Nothing is logged when there is no
    label column. The logged artifacts are stored in
    ``extra_data["imbalance"]`` and ``extra_data["imbalance-csv"]``.

    :param context:      The function context used for logging the artifacts
    :param df:           The data to analyze
    :param extra_data:   Dictionary that receives the logged artifacts
    :param label_column: Name of the label column, or None
    :param plots_dest:   Destination folder of the plot and the csv
    :param problem_type: "classification" selects the pie chart variant
    """
    if not label_column:
        return

    if problem_type == "classification":
        labels_count = df[label_column].value_counts().sort_index()
        # Name the count column and clear the index name explicitly: how
        # value_counts() names its result differs between pandas versions, so
        # renaming a column called `label_column` is not reliable.
        df_labels_count = labels_count.rename("Total").rename_axis(None).to_frame()
        df_labels_count[label_column] = labels_count.index
        df_labels_count["weights"] = (
            df_labels_count["Total"] / df_labels_count["Total"].sum()
        )

        fig = px.pie(df_labels_count, names=label_column, values="Total")
    else:
        fig = px.histogram(
            histfunc="count",
            x=df[label_column],
        )
        counts, bin_edges = np.histogram(df[label_column])
        # There is one more edge than bins; pad the counts with a trailing 0
        # so both columns have the same length.
        df_labels_count = pd.DataFrame(
            {"min_val": bin_edges, "count": counts.tolist() + [0]}
        )

    fig.update_layout(title_text="Labels Imbalance")
    extra_data["imbalance"] = context.log_artifact(
        PlotlyArtifact(key="imbalance", figure=fig),
        local_path=f"{plots_dest}/imbalance.html",
    )
    extra_data["imbalance-csv"] = context.log_artifact(
        TableArtifact("imbalance-weights-vec", df=df_labels_count),
        local_path=f"{plots_dest}/imbalance-weights-vec.csv",
    )


def _create_corr_artifact(
    context: MLClientCtx,
    df: pd.DataFrame,
    extra_data: dict,
    label_column: str,
    plots_dest: str,
):
    """
    Create and log the correlation-matrix artifacts (csv + plot).

    The correlation is computed between the numeric feature columns; the
    label column, when given, is left out. The matrix is logged as a table
    and as an annotated heatmap, stored in
    ``extra_data["correlation-matrix-csv"]`` and ``extra_data["correlation"]``.

    :param context:      The function context used for logging the artifacts
    :param df:           The data to analyze
    :param extra_data:   Dictionary that receives the logged artifacts
    :param label_column: Name of the label column, or None
    :param plots_dest:   Destination folder of the plot and the csv
    """
    if label_column is not None:
        df = df.drop(columns=[label_column])
    try:
        # Recent pandas no longer drops non-numeric columns by default and
        # raises on them instead, so ask for numeric columns explicitly.
        tblcorr = df.corr(numeric_only=True)
    except TypeError:
        # Older pandas has no `numeric_only` argument for corr() and ignores
        # non-numeric columns on its own.
        tblcorr = df.corr()
    extra_data["correlation-matrix-csv"] = context.log_artifact(
        TableArtifact("correlation-matrix-csv", df=tblcorr, visible=True),
        local_path=f"{plots_dest}/correlation-matrix.csv",
    )

    z = tblcorr.values.tolist()
    # Cell annotations: the coefficients rounded to two decimals.
    z_text = [["{:.2f}".format(value) for value in row] for row in z]
    column_names = list(tblcorr.columns)
    fig = ff.create_annotated_heatmap(
        z,
        x=column_names,
        y=list(column_names),
        annotation_text=z_text,
        colorscale="agsunset",
    )
    fig["layout"]["yaxis"]["autorange"] = "reversed"  # l -> r
    fig.update_layout(title_text="Correlation matrix")
    fig["data"][0]["showscale"] = True

    extra_data["correlation"] = context.log_artifact(
        PlotlyArtifact(key="correlation", figure=fig),
        local_path=f"{plots_dest}/correlation.html",
    )