# Source code for gen_class_data.gen_class_data
# Copyright 2019 Iguazio
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import numbers
from typing import Optional, List

import pandas as pd
from sklearn.datasets import make_classification
from mlrun.execution import MLClientCtx
def gen_class_data(
    context: MLClientCtx,
    n_samples: int,
    m_features: int,
    k_classes: int,
    header: Optional[List[str]] = None,
    label_column: Optional[str] = "labels",
    weight: float = 0.5,
    random_state: int = 1,
    key: str = "classifier-data",
    file_ext: str = "parquet",
    sk_params: Optional[dict] = None,
):
    """Create a synthetic classification dataset and log it as an artifact.

    Features and labels are drawn with
    ``sklearn.datasets.make_classification``. The features become columns
    named by ``header`` (or ``feat_0`` ... ``feat_{m_features - 1}`` when no
    header is given), and the labels are appended as a final column named
    ``label_column``. The resulting DataFrame is logged through
    ``context.log_dataset`` under ``key``, without the index.

    Additional scikit-learn parameters can be passed as a dict in
    ``sk_params``, see
    https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_classification.html
    for more details.

    :param context: function context
    :param n_samples: number of rows/samples
    :param m_features: number of cols/features
    :param k_classes: number of classes
    :param header: column names for the features array; must contain exactly
                   ``m_features`` names. When empty/None, ``feat_<i>`` is used.
    :param label_column: column name of the ground-truth series
    :param weight: approximate fraction of samples assigned to class 0
                   (ground-truth=0); the remaining fraction is split evenly
                   across the other classes. Must be within [0, 1]. A
                   sequence of per-class weights (length ``k_classes`` or
                   ``k_classes - 1``) or None is forwarded to sklearn as-is.
                   Actual proportions may deviate slightly, e.g. when
                   sklearn's ``flip_y`` is non-zero.
    :param random_state: rng seed (see https://scikit-learn.org/stable/glossary.html#term-random-state)
    :param key: key of data in artifact store
    :param file_ext: storage format of the dataset, passed as ``format`` to
                     ``context.log_dataset`` (e.g. "parquet", "csv")
    :param sk_params: additional keyword arguments for
                      `sklearn.datasets.make_classification`

    :raises ValueError: if ``weight`` is a number outside [0, 1], or if
                        ``header`` does not have ``m_features`` entries
    """
    # Avoid a shared mutable default; an empty dict adds no extra kwargs.
    if sk_params is None:
        sk_params = {}

    # make_classification expects ``weights`` as a sequence of per-class
    # proportions (a bare float is rejected). Expand a scalar into
    # k_classes - 1 proportions; sklearn infers the last class's share.
    if isinstance(weight, numbers.Real):
        if not 0.0 <= weight <= 1.0:
            raise ValueError(f"weight must be within [0, 1], got {weight}")
        if k_classes >= 2:
            rest = (1.0 - weight) / (k_classes - 1)
            weights = [float(weight)] + [rest] * (k_classes - 2)
        else:
            # With a single class there are no proportions to set.
            weights = None
    else:
        weights = None if weight is None else list(weight)

    # Resolve feature column names up front so a bad header fails before
    # any data is generated.
    if header:
        if len(header) != m_features:
            raise ValueError(
                f"header has {len(header)} names but m_features={m_features}"
            )
        columns = list(header)
    else:
        columns = [f"feat_{i}" for i in range(m_features)]

    features, labels = make_classification(
        n_samples=n_samples,
        n_features=m_features,
        weights=weights,
        n_classes=k_classes,
        random_state=random_state,
        **sk_params,
    )

    # make dataframes, add column names, concatenate (X, y)
    X = pd.DataFrame(features, columns=columns)
    y = pd.DataFrame(labels, columns=[label_column])
    data = pd.concat([X, y], axis=1)
    context.log_dataset(key, df=data, format=file_ext, index=False)