# Licence Apache-2.0
from typing import List, Tuple
import numpy as np
from ..transformers.transformer import Transformer
from ..util import util
from gators import DataFrame, Series
class _BaseDatetimeFeature(Transformer):
    """Base datetime transformer class.

    Validates the date format at construction time. Fitting checks that
    the selected columns are datetime columns, then records their positions
    in the dataframe and the day/month/year index bounds derived from
    `date_format`.

    Parameters
    ----------
    columns : List[str]
        List of datetime columns.
    date_format : str
        Permutation of the letters `d`, `m` and `y` (e.g. ``"ymd"``)
        giving the order of the day, month and year.
    column_names : List[str]
        List of column names.

    Raises
    ------
    TypeError
        If `date_format` is not a string.
    ValueError
        If `date_format` is not composed of exactly the letters
        `d`, `m` and `y`.
    """

    def __init__(
        self,
        columns: List[str],
        date_format: str,
        column_names: List[str],
    ):
        if not isinstance(date_format, str):
            raise TypeError("`date_format` should be a string.")
        # Exactly one occurrence of each letter is required, in any order.
        if sorted(date_format) != ["d", "m", "y"]:
            raise ValueError(
                "`date_format` should be a string composed of the letters `d`, `m` and `y`."
            )
        Transformer.__init__(self)
        self.columns = columns
        self.column_names = column_names
        self.date_format = date_format

    def fit(self, X: DataFrame, y: Series = None) -> "Transformer":
        """Fit the transformer on the dataframe `X`.

        Parameters
        ----------
        X : DataFrame
            Input dataframe.
        y : Series, default None.
            Target values. Not used by this method.

        Returns
        -------
        self : Transformer
            Instance of itself.

        Raises
        ------
        TypeError
            If one of the selected columns is not of subtype
            `np.datetime64`.
        """
        self.check_dataframe(X)
        X_datetime_dtype = X[self.columns].dtypes
        for column in self.columns:
            if not np.issubdtype(X_datetime_dtype[column], np.datetime64):
                raise TypeError(
                    """
                    Datetime columns should be of subtype np.datetime64.
                    Use `ConvertColumnDatatype` to convert the dtype.
                """
                )
        # Positions of the selected columns within `X.columns`.
        self.idx_columns = util.get_idx_columns(
            columns=X.columns,
            selected_columns=self.columns,
        )
        self.n_columns = len(self.columns)
        self.idx_day_bounds, self.idx_month_bounds, self.idx_year_bounds = self.get_idx(
            self.date_format
        )
        return self

    @staticmethod
    def get_cyclic_column_names(columns: List[str], pattern: str) -> List[str]:
        """Get the cosine and sine column names of the given columns.

        Parameters
        ----------
        columns : List[str]
            List of datetime columns.
        pattern : str
            Pattern inserted between the column name and the
            `_cos` / `_sin` suffix.

        Returns
        -------
        List[str]
            For each column, in order, the name `<column>__<pattern>_cos`
            followed by `<column>__<pattern>_sin`.
        """
        return [
            f"{c}__{pattern}_{suffix}" for c in columns for suffix in ("cos", "sin")
        ]

    @staticmethod
    def get_idx(date_format: str) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Get the index bounds of the day, month and year.

        The offsets are computed for a layout made of a 2-character day,
        a 2-character month and a 4-character year, ordered as given by
        `date_format` and each separated by a single character. For
        instance, ``"ymd"`` gives ``[8, 10]``, ``[5, 7]`` and ``[0, 4]``
        for the day, month and year respectively.

        Parameters
        ----------
        date_format : str
            Datetime format: a string containing the letters
            `d`, `m` and `y`.

        Returns
        -------
        idx_day_bounds : np.ndarray
            Start and end indices of the day.
        idx_month_bounds : np.ndarray
            Start and end indices of the month.
        idx_year_bounds : np.ndarray
            Start and end indices of the year.
        """
        # Each field takes a slot of 3 characters: 2 digits plus a separator.
        idx_start_day = 3 * date_format.index("d")
        idx_start_month = 3 * date_format.index("m")
        idx_start_year = 3 * date_format.index("y")
        # The year is 4 characters wide, i.e. 2 more than a regular slot,
        # so any field located after the year is shifted by 2.
        if idx_start_year <= idx_start_day:
            idx_start_day += 2
        if idx_start_year <= idx_start_month:
            idx_start_month += 2
        idx_day_bounds = np.array([idx_start_day, idx_start_day + 2])
        idx_month_bounds = np.array([idx_start_month, idx_start_month + 2])
        idx_year_bounds = np.array([idx_start_year, idx_start_year + 4])
        return idx_day_bounds, idx_month_bounds, idx_year_bounds