diff --git a/src/pytorch_tabular/config/config.py b/src/pytorch_tabular/config/config.py index 502ffabc..3003ff83 100644 --- a/src/pytorch_tabular/config/config.py +++ b/src/pytorch_tabular/config/config.py @@ -65,8 +65,9 @@ class DataConfig: categorical_cols (List): Column names of the categorical fields to treat differently. Defaults to [] - date_columns (List): (Column names, Freq) tuples of the date fields. For eg. a field named - introduction_date and with a monthly frequency should have an entry ('intro_date','M'} + date_columns (List): (Column name, Freq, Format) tuples of the date fields. For eg. a field named + introduction_date and with a monthly frequency like "2023-12" should have + an entry ('intro_date','M','%Y-%m') encode_date_columns (bool): Whether or not to encode the derived variables from date @@ -115,7 +116,8 @@ class DataConfig: default_factory=list, metadata={ "help": "(Column names, Freq) tuples of the date fields. For eg. a field named" - " `introduction_date` and with a monthly frequency should have an entry ('intro_date','M'}" + " introduction_date and with a monthly frequency like '2023-12' should have" + " an entry ('intro_date','M','%Y-%m')" }, ) diff --git a/src/pytorch_tabular/tabular_datamodule.py b/src/pytorch_tabular/tabular_datamodule.py index 49294967..a20ee3c6 100644 --- a/src/pytorch_tabular/tabular_datamodule.py +++ b/src/pytorch_tabular/tabular_datamodule.py @@ -287,8 +287,8 @@ def do_leave_one_out_encoder(self) -> bool: def _encode_date_columns(self, data: DataFrame) -> DataFrame: added_features = [] - for field_name, freq in self.config.date_columns: - data = self.make_date(data, field_name) + for field_name, freq, format in self.config.date_columns: + data = self.make_date(data, field_name, format) data, _new_feats = self.add_datepart(data, field_name, frequency=freq, prefix=None, drop=True) added_features += _new_feats return data, added_features @@ -630,7 +630,7 @@ def time_features_from_frequency_str(cls, freq_str: str) -> List[str]: # adapted from fastai @classmethod - def make_date(cls, df: DataFrame, date_field: str) -> DataFrame: + def make_date(cls, df: DataFrame, date_field: str, date_format: str = "ISO8601") -> DataFrame: """Make sure `df[date_field]` is of the right date type. Args: @@ -645,7 +645,7 @@ def make_date(cls, df: DataFrame, date_field: str) -> DataFrame: if isinstance(field_dtype, DatetimeTZDtype): field_dtype = np.datetime64 if not np.issubdtype(field_dtype, np.datetime64): - df[date_field] = to_datetime(df[date_field], infer_datetime_format=True) + df[date_field] = to_datetime(df[date_field], format=date_format) return df # adapted from fastai diff --git a/tests/test_datamodule.py b/tests/test_datamodule.py index 1f80bb2c..133ab0aa 100644 --- a/tests/test_datamodule.py +++ b/tests/test_datamodule.py @@ -117,7 +117,7 @@ def test_date_encoding(timeseries_data, freq): target=target + ["Occupancy"], continuous_cols=["Temperature", "Humidity", "Light", "CO2", "HumidityRatio"], categorical_cols=[], - date_columns=[("date", freq)], + date_columns=[("date", freq, "%Y-%m-%d %H:%M:%S")], encode_date_columns=True, ) model_config_params = {"task": "regression"}