# Source code for gen_class_data.gen_class_data

# Copyright 2019 Iguazio
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import pandas as pd
from typing import Optional, List
from sklearn.datasets import make_classification

from mlrun.execution import MLClientCtx


def gen_class_data(
    context: MLClientCtx,
    n_samples: int,
    m_features: int,
    k_classes: int,
    header: Optional[List[str]] = None,
    label_column: Optional[str] = "labels",
    weight: float = 0.5,
    random_state: int = 1,
    key: str = "classifier-data",
    file_ext: str = "parquet",
    sk_params: Optional[dict] = None,
):
    """Create a synthetic classification dataset and log it through the context.

    Features and labels are generated with
    `sklearn.datasets.make_classification`, combined into a single dataframe
    (feature columns followed by the label column) and handed to
    `context.log_dataset` under `key`.

    Additional scikit-learn parameters can be set using `sk_params`, please see
    https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_classification.html
    for more details.

    :param context:      function context
    :param n_samples:    number of rows/samples
    :param m_features:   number of cols/features
    :param k_classes:    number of classes
    :param header:       column names for the features array; when empty or
                         None the columns are named "feat_0", "feat_1", ...
    :param label_column: column name of ground-truth series
    :param weight:       fraction of samples assigned to class 0
                         (ground-truth=0). A scalar is expanded to one weight
                         per class, with the remainder split evenly between
                         the other classes. Any non-scalar value (e.g. a list
                         of per-class weights, or None) is passed to
                         scikit-learn unchanged.
    :param random_state: rng seed (see https://scikit-learn.org/stable/glossary.html#term-random-state)
    :param key:          key of data in artifact store
    :param file_ext:     format passed to `context.log_dataset`
                         (default "parquet")
    :param sk_params:    additional parameters for
                         `sklearn.datasets.make_classification`

    :raises ValueError:  if `header` is given and its length differs from
                         `m_features`
    """
    # A fresh dict per call: a `{}` default would be shared between calls.
    extra_params = {} if sk_params is None else sk_params

    # Fail before generating any data when the header cannot fit the features.
    if header and len(header) != m_features:
        raise ValueError(
            f"header has {len(header)} names but m_features is {m_features}"
        )

    # make_classification expects `weights` to be array-like (one entry per
    # class, or one fewer), so a scalar fraction has to be expanded first.
    if isinstance(weight, (int, float)):
        n_other = k_classes - 1
        class_weights = [weight] + [(1.0 - weight) / max(n_other, 1)] * n_other
    else:
        class_weights = weight

    features, labels = make_classification(
        n_samples=n_samples,
        n_features=m_features,
        weights=class_weights,
        n_classes=k_classes,
        random_state=random_state,
        **extra_params,
    )

    # make dataframes, add column names, concatenate (X, y)
    X = pd.DataFrame(features)
    if not header:
        X.columns = [f"feat_{i}" for i in range(m_features)]
    else:
        X.columns = header

    y = pd.DataFrame(labels, columns=[label_column])
    data = pd.concat([X, y], axis=1)

    context.log_dataset(key, df=data, format=file_ext, index=False)