Source code for gators.data_cleaning._base_data_cleaning
# License: Apache-2.0
from typing import List
import numpy as np
from ..transformers.transformer import Transformer
from ..util import util
from gators import DataFrame
class _BaseDataCleaning(Transformer):
    """Base data cleaning transformer.

    Holds the bookkeeping shared by the column-dropping transformers:
    the names of the columns to drop / keep and the positional indices
    of the columns to keep (used by the NumPy code path).
    """

    def __init__(self):
        Transformer.__init__(self)
        self.columns_to_drop: List[str] = []
        self.columns_to_keep: List[str] = []
        # Empty until populated; `transform_numpy` treats an empty array
        # as "no column is kept".
        self.idx_columns_to_keep = np.array([])

    def transform(self, X: DataFrame) -> DataFrame:
        """Transform the dataframe `X`.

        Parameters
        ----------
        X : DataFrame
            Input dataset.

        Returns
        -------
        X : DataFrame
            Transformed dataset: `X` without the columns listed in
            `self.columns`, or `X` itself when that list is empty.
        """
        self.check_dataframe(X)
        # Records the dtypes of the *input* dataframe (before any drop).
        self.dtypes_ = X.dtypes
        # NOTE(review): `self.columns` is not initialised in `__init__` of
        # this class; presumably set by subclasses or by `Transformer`
        # -- confirm.
        if len(self.columns):
            return X.drop(self.columns, axis=1)
        return X

    def transform_numpy(self, X: np.ndarray) -> np.ndarray:
        """Transform the array `X`.

        Parameters
        ----------
        X : np.ndarray
            Input array.

        Returns
        -------
        X : np.ndarray
            Transformed array: the columns of `X` selected by
            `self.idx_columns_to_keep`, or an empty array of shape
            ``(0, X.shape[1])`` when there is no column to keep.
        """
        self.check_array(X)
        if self.idx_columns_to_keep.size == 0:
            # NOTE(review): the empty result has zero rows and keeps the
            # input column count; confirm callers rely on this shape.
            return np.array([]).reshape(0, X.shape[1])
        return X[:, self.idx_columns_to_keep]

    @staticmethod
    def get_idx_columns_to_keep(
        columns: List[str], columns_to_drop: List[str]
    ) -> np.ndarray:
        """Get the column indices to keep.

        Parameters
        ----------
        columns : List[str]
            List of columns of a dataset.
        columns_to_drop : List[str]
            List of columns to drop.

        Returns
        -------
        np.ndarray
            Column indices to keep, as computed by
            `util.exclude_idx_columns`.
        """
        return util.exclude_idx_columns(
            columns=columns,
            excluded_columns=columns_to_drop,
        )