diff --git a/docssrc/source/tutorials/custom_encoder_rulebased/custom_encoder_rulebased.ipynb b/docssrc/source/tutorials/custom_encoder_rulebased/custom_encoder_rulebased.ipynb index 10f9a14e6..56be888cf 100644 --- a/docssrc/source/tutorials/custom_encoder_rulebased/custom_encoder_rulebased.ipynb +++ b/docssrc/source/tutorials/custom_encoder_rulebased/custom_encoder_rulebased.ipynb @@ -39,7 +39,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "id": "raising-adventure", "metadata": { "execution": { @@ -74,7 +74,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "id": "technical-government", "metadata": { "execution": { @@ -84,118 +84,7 @@ "shell.execute_reply": "2022-02-03T21:30:38.234810Z" } }, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
modelyearpricetransmissionmileagefuelTypetaxmpgengineSize
0A1201712500Manual15735Petrol15055.41.4
1A6201616500Automatic36203Diesel2064.22.0
2A1201611000Manual29946Petrol3055.41.4
3A4201716800Automatic25952Diesel14567.32.0
4A3201917300Manual1998Petrol14549.61.0
\n", - "
" - ], - "text/plain": [ - " model year price transmission mileage fuelType tax mpg engineSize\n", - "0 A1 2017 12500 Manual 15735 Petrol 150 55.4 1.4\n", - "1 A6 2016 16500 Automatic 36203 Diesel 20 64.2 2.0\n", - "2 A1 2016 11000 Manual 29946 Petrol 30 55.4 1.4\n", - "3 A4 2017 16800 Automatic 25952 Diesel 145 67.3 2.0\n", - "4 A3 2019 17300 Manual 1998 Petrol 145 49.6 1.0" - ] - }, - "execution_count": 2, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "filename = 'https://raw.githubusercontent.com/mindsdb/benchmarks/main/benchmarks/datasets/used_car_price/data.csv'\n", "df = pd.read_csv(filename)\n", @@ -224,7 +113,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "id": "absent-maker", "metadata": { "execution": { @@ -234,38 +123,7 @@ "shell.execute_reply": "2022-02-03T21:30:38.968531Z" } }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\u001B[32mINFO:lightwood-1462817:Dropping features: []\u001B[0m\n", - "\u001B[32mINFO:lightwood-1462817:Analyzing a sample of 6920\u001B[0m\n", - "\u001B[32mINFO:lightwood-1462817:from a total population of 10668, this is equivalent to 64.9% of your data.\u001B[0m\n", - "\u001B[32mINFO:lightwood-1462817:Using 7 processes to deduct types.\u001B[0m\n", - "\u001B[32mINFO:lightwood-1462817:Infering type for: model\u001B[0m\n", - "\u001B[32mINFO:lightwood-1462817:Infering type for: year\u001B[0m\n", - "\u001B[32mINFO:lightwood-1462817:Infering type for: price\u001B[0m\n", - "\u001B[32mINFO:lightwood-1462817:Infering type for: transmission\u001B[0m\n", - "\u001B[32mINFO:lightwood-1462817:Infering type for: fuelType\u001B[0m\n", - "\u001B[32mINFO:lightwood-1462817:Infering type for: mileage\u001B[0m\n", - "\u001B[32mINFO:lightwood-1462817:Infering type for: tax\u001B[0m\n", - "\u001B[32mINFO:lightwood-1462817:Column year has data type integer\u001B[0m\n", - "\u001B[32mINFO:lightwood-1462817:Column price has data type integer\u001B[0m\n", - "\u001B[32mINFO:lightwood-1462817:Infering type for: mpg\u001B[0m\n", - "\u001B[32mINFO:lightwood-1462817:Infering type for: engineSize\u001B[0m\n", - "\u001B[32mINFO:lightwood-1462817:Column tax has data type integer\u001B[0m\n", - "\u001B[32mINFO:lightwood-1462817:Column mileage has data type integer\u001B[0m\n", - "\u001B[32mINFO:lightwood-1462817:Column engineSize has data type float\u001B[0m\n", - "\u001B[32mINFO:lightwood-1462817:Column mpg has data type float\u001B[0m\n", - "\u001B[32mINFO:lightwood-1462817:Column transmission has data type categorical\u001B[0m\n", - "\u001B[32mINFO:lightwood-1462817:Column fuelType has data type categorical\u001B[0m\n", - "\u001B[32mINFO:lightwood-1462817:Column model has data type categorical\u001B[0m\n", - "\u001B[32mINFO:lightwood-1462817:Starting statistical analysis\u001B[0m\n", - "\u001B[32mINFO:lightwood-1462817:Finished statistical analysis\u001B[0m\n" - ] - } - ], + "outputs": [], "source": [ "# Create the Problem Definition\n", "pdef = ProblemDefinition.from_dict({\n", @@ -287,7 +145,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "id": "coastal-paragraph", "metadata": { "execution": { @@ -297,134 +155,7 @@ "shell.execute_reply": "2022-02-03T21:30:38.973749Z" } }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{\n", - " \"encoders\": {\n", - " \"price\": {\n", - " \"module\": \"NumericEncoder\",\n", - " \"args\": {\n", - " \"is_target\": \"True\",\n", - " \"positive_domain\": 
\"$statistical_analysis.positive_domain\"\n", - " }\n", - " },\n", - " \"model\": {\n", - " \"module\": \"OneHotEncoder\",\n", - " \"args\": {}\n", - " },\n", - " \"year\": {\n", - " \"module\": \"NumericEncoder\",\n", - " \"args\": {}\n", - " },\n", - " \"transmission\": {\n", - " \"module\": \"OneHotEncoder\",\n", - " \"args\": {}\n", - " },\n", - " \"mileage\": {\n", - " \"module\": \"NumericEncoder\",\n", - " \"args\": {}\n", - " },\n", - " \"fuelType\": {\n", - " \"module\": \"OneHotEncoder\",\n", - " \"args\": {}\n", - " },\n", - " \"tax\": {\n", - " \"module\": \"NumericEncoder\",\n", - " \"args\": {}\n", - " },\n", - " \"mpg\": {\n", - " \"module\": \"NumericEncoder\",\n", - " \"args\": {}\n", - " },\n", - " \"engineSize\": {\n", - " \"module\": \"NumericEncoder\",\n", - " \"args\": {}\n", - " }\n", - " },\n", - " \"dtype_dict\": {\n", - " \"model\": \"categorical\",\n", - " \"year\": \"integer\",\n", - " \"price\": \"integer\",\n", - " \"transmission\": \"categorical\",\n", - " \"mileage\": \"integer\",\n", - " \"fuelType\": \"categorical\",\n", - " \"tax\": \"integer\",\n", - " \"mpg\": \"float\",\n", - " \"engineSize\": \"float\"\n", - " },\n", - " \"dependency_dict\": {},\n", - " \"model\": {\n", - " \"module\": \"BestOf\",\n", - " \"args\": {\n", - " \"submodels\": [\n", - " {\n", - " \"module\": \"Neural\",\n", - " \"args\": {\n", - " \"fit_on_dev\": true,\n", - " \"stop_after\": \"$problem_definition.seconds_per_mixer\",\n", - " \"search_hyperparameters\": true\n", - " }\n", - " },\n", - " {\n", - " \"module\": \"LightGBM\",\n", - " \"args\": {\n", - " \"stop_after\": \"$problem_definition.seconds_per_mixer\",\n", - " \"fit_on_dev\": true\n", - " }\n", - " },\n", - " {\n", - " \"module\": \"Regression\",\n", - " \"args\": {\n", - " \"stop_after\": \"$problem_definition.seconds_per_mixer\"\n", - " }\n", - " }\n", - " ],\n", - " \"args\": \"$pred_args\",\n", - " \"accuracy_functions\": \"$accuracy_functions\",\n", - " \"ts_analysis\": null\n", - " }\n", - " },\n", - " \"problem_definition\": {\n", - " \"target\": \"price\",\n", - " \"pct_invalid\": 2,\n", - " \"unbias_target\": true,\n", - " \"seconds_per_mixer\": 57024.0,\n", - " \"seconds_per_encoder\": null,\n", - " \"expected_additional_time\": 0.5703437328338623,\n", - " \"time_aim\": 259200,\n", - " \"target_weights\": null,\n", - " \"positive_domain\": false,\n", - " \"timeseries_settings\": {\n", - " \"is_timeseries\": false,\n", - " \"order_by\": null,\n", - " \"window\": null,\n", - " \"group_by\": null,\n", - " \"use_previous_target\": true,\n", - " \"horizon\": null,\n", - " \"historical_columns\": null,\n", - " \"target_type\": \"\",\n", - " \"allow_incomplete_history\": true,\n", - " \"eval_cold_start\": true,\n", - " \"interval_periods\": []\n", - " },\n", - " \"anomaly_detection\": false,\n", - " \"use_default_analysis\": true,\n", - " \"ignore_features\": [],\n", - " \"fit_on_all\": true,\n", - " \"strict_mode\": true,\n", - " \"seed_nr\": 420\n", - " },\n", - " \"identifiers\": {},\n", - " \"accuracy_functions\": [\n", - " \"r2_score\"\n", - " ]\n", - "}\n" - ] - } - ], + "outputs": [], "source": [ "print(json_ai.to_json())" ] @@ -484,7 +215,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "id": "e03db1b0", "metadata": { "execution": { @@ -494,15 +225,7 @@ "shell.execute_reply": "2022-02-03T21:30:38.978491Z" } }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Overwriting LabelEncoder.py\n" - ] - } - ], + "outputs": [], "source": [ "%%writefile 
LabelEncoder.py\n", "\n", @@ -533,9 +256,9 @@ " is_prepared: bool\n", "\n", " is_timeseries_encoder: bool = False\n", - " is_trainable_encoder: bool = False\n", + " is_trainable_encoder: bool = True\n", "\n", - " def __init__(self, is_target: bool = False) -> None:\n", + " def __init__(self, is_target: bool = False, stop_after = 10) -> None:\n", " \"\"\"\n", " Initialize the Label Encoder\n", "\n", @@ -548,8 +271,7 @@ " # For LabelEncoder, this is always 1 (1 label per category)\n", " self.output_size = 1\n", "\n", - " # Not all encoders need to be prepared\n", - " def prepare(self, priming_data: pd.Series) -> None:\n", + " def prepare(self, train_data: pd.Series, dev_data: pd.Series) -> None:\n", " \"\"\"\n", " Create a LabelEncoder for categorical data.\n", "\n", @@ -561,7 +283,7 @@ " \"\"\"\n", "\n", " # Find all unique categories in the dataset\n", - " categories = priming_data.unique()\n", + " categories = train_data.unique()\n", "\n", " log.info(\"Categories Detected = \" + str(self.output_size))\n", "\n", @@ -608,7 +330,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "id": "e30866c1", "metadata": { "execution": { @@ -670,7 +392,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "id": "elementary-fusion", "metadata": { "execution": { @@ -699,7 +421,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "id": "inappropriate-james", "metadata": { "execution": { @@ -733,7 +455,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "id": "palestinian-harvey", "metadata": { "execution": { @@ -743,47 +465,7 @@ "shell.execute_reply": "2022-02-03T21:30:39.355539Z" } }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\u001B[32mINFO:lightwood-1462817:Performing statistical analysis on data\u001B[0m\n", - "\u001B[32mINFO:lightwood-1462817:Starting statistical analysis\u001B[0m\n", - "\u001B[32mINFO:lightwood-1462817:Finished statistical analysis\u001B[0m\n", - "\u001B[37mDEBUG:lightwood-1462817: `analyze_data` runtime: 0.14 seconds\u001B[0m\n", - "\u001B[32mINFO:lightwood-1462817:Cleaning the data\u001B[0m\n", - "\u001B[37mDEBUG:lightwood-1462817: `preprocess` runtime: 0.05 seconds\u001B[0m\n", - "\u001B[32mINFO:lightwood-1462817:Splitting the data into train/test\u001B[0m\n", - "\u001B[37mDEBUG:lightwood-1462817: `split` runtime: 0.0 seconds\u001B[0m\n", - "\u001B[32mINFO:lightwood-1462817:Preparing the encoders\u001B[0m\n", - "\u001B[32mINFO:lightwood-1462817:Encoder prepping dict length of: 1\u001B[0m\n", - "\u001B[32mINFO:lightwood-1462817:Encoder prepping dict length of: 2\u001B[0m\n", - "\u001B[32mINFO:lightwood-1462817:Encoder prepping dict length of: 3\u001B[0m\n", - "\u001B[32mINFO:lightwood-1462817:Encoder prepping dict length of: 4\u001B[0m\n", - "\u001B[32mINFO:lightwood-1462817:Encoder prepping dict length of: 5\u001B[0m\n", - "\u001B[32mINFO:lightwood-1462817:Encoder prepping dict length of: 6\u001B[0m\n", - "\u001B[32mINFO:lightwood-1462817:Encoder prepping dict length of: 7\u001B[0m\n", - "\u001B[32mINFO:lightwood-1462817:Encoder prepping dict length of: 8\u001B[0m\n", - "\u001B[32mINFO:lightwood-1462817:Encoder prepping dict length of: 9\u001B[0m\n", - "\u001B[32mINFO:lightwood-1462817:Categories Detected = 1\u001B[0m\n", - "\u001B[32mINFO:lightwood-1462817:Categories Detected = 1\u001B[0m\n", - "\u001B[32mINFO:lightwood-1462817:Categories Detected = 1\u001B[0m\n", - "\u001B[32mINFO:lightwood-1462817:Done running for: 
price\u001B[0m\n", - "\u001B[32mINFO:lightwood-1462817:Done running for: model\u001B[0m\n", - "\u001B[32mINFO:lightwood-1462817:Done running for: year\u001B[0m\n", - "\u001B[32mINFO:lightwood-1462817:Done running for: transmission\u001B[0m\n", - "\u001B[32mINFO:lightwood-1462817:Done running for: mileage\u001B[0m\n", - "\u001B[32mINFO:lightwood-1462817:Done running for: fuelType\u001B[0m\n", - "\u001B[32mINFO:lightwood-1462817:Done running for: tax\u001B[0m\n", - "\u001B[32mINFO:lightwood-1462817:Done running for: mpg\u001B[0m\n", - "\u001B[32mINFO:lightwood-1462817:Done running for: engineSize\u001B[0m\n", - "\u001B[37mDEBUG:lightwood-1462817: `prepare` runtime: 0.16 seconds\u001B[0m\n", - "\u001B[32mINFO:lightwood-1462817:Featurizing the data\u001B[0m\n", - "\u001B[37mDEBUG:lightwood-1462817: `featurize` runtime: 0.0 seconds\u001B[0m\n" - ] - } - ], + "outputs": [], "source": [ "# Perform Stats Analysis\n", "predictor.analyze_data(df)\n", @@ -811,7 +493,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "id": "silent-dealing", "metadata": { "execution": { @@ -821,76 +503,7 @@ "shell.execute_reply": "2022-02-03T21:30:39.392125Z" } }, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
fuelTypeEncData
0Diesel1
1Diesel1
2Diesel1
3Petrol2
4Diesel1
\n", - "
" - ], - "text/plain": [ - " fuelType EncData\n", - "0 Diesel 1\n", - "1 Diesel 1\n", - "2 Diesel 1\n", - "3 Petrol 2\n", - "4 Diesel 1" - ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "# Pick a categorical column name\n", "col_name = \"fuelType\"\n", @@ -916,7 +529,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": null, "id": "superior-mobility", "metadata": { "execution": { @@ -926,15 +539,7 @@ "shell.execute_reply": "2022-02-03T21:30:39.396663Z" } }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'Unknown': 0, 'Diesel': 1, 'Petrol': 2, 'Hybrid': 3}\n" - ] - } - ], + "outputs": [], "source": [ "# Label Name -> Label Number\n", "print(predictor.encoders[col_name].label_dict)" @@ -952,6 +557,11 @@ } ], "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, "language_info": { "codemirror_mode": { "name": "ipython", @@ -967,4 +577,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} \ No newline at end of file +} diff --git a/lightwood/__about__.py b/lightwood/__about__.py index 42dc89079..7605812b6 100644 --- a/lightwood/__about__.py +++ b/lightwood/__about__.py @@ -1,6 +1,6 @@ __title__ = 'lightwood' __package_name__ = 'lightwood' -__version__ = '23.5.1.1' +__version__ = '23.6.2.0' __description__ = "Lightwood is a toolkit for automatic machine learning model building" __email__ = "community@mindsdb.com" __author__ = 'MindsDB Inc' diff --git a/lightwood/analysis/analyze.py b/lightwood/analysis/analyze.py index 408ce317c..87e21b9c9 100644 --- a/lightwood/analysis/analyze.py +++ b/lightwood/analysis/analyze.py @@ -3,7 +3,6 @@ from dataprep_ml import StatisticalAnalysis from lightwood.helpers.log import log -from lightwood.helpers.ts import filter_ds from type_infer.dtype import dtype from lightwood.ensemble import BaseEnsemble from lightwood.analysis.base import BaseAnalysisBlock @@ -60,8 +59,6 @@ def model_analyzer( normal_predictions = None if len(analysis_blocks) > 0: - filtered_df = filter_ds(encoded_val_data, tss) - encoded_val_data = EncodedDs(encoded_val_data.encoders, filtered_df, encoded_val_data.target) normal_predictions = predictor(encoded_val_data, args=PredictionArguments.from_dict(args)) normal_predictions = normal_predictions.set_index(encoded_val_data.data_frame.index) diff --git a/lightwood/analysis/helpers/feature_importance.py b/lightwood/analysis/helpers/feature_importance.py index ce205f388..de01e6888 100644 --- a/lightwood/analysis/helpers/feature_importance.py +++ b/lightwood/analysis/helpers/feature_importance.py @@ -81,6 +81,7 @@ def analyze(self, info: Dict[str, object], **kwargs) -> Dict[str, object]: shuffle_data = deepcopy(ref_data) shuffle_data.clear_cache() shuffle_data.data_frame[col] = shuffle(shuffle_data.data_frame[col].values) + shuffle_data.build_cache() # TODO: bottleneck, add a method to build a single column instead! 
shuffled_preds = ns.predictor(shuffle_data, args=PredictionArguments.from_dict(args)) shuffled_col_accuracy[col] = np.mean(list(evaluate_accuracies( diff --git a/lightwood/api/json_ai.py b/lightwood/api/json_ai.py index 1f21007ed..3a7072130 100644 --- a/lightwood/api/json_ai.py +++ b/lightwood/api/json_ai.py @@ -94,7 +94,7 @@ def lookup_encoder( dtype.binary: "BinaryEncoder", dtype.categorical: "CategoricalAutoEncoder" if statistical_analysis is None - or len(statistical_analysis.histograms[col_name]) > 100 + or len(statistical_analysis.histograms[col_name]['x']) > 16 else "OneHotEncoder", dtype.tags: "MultiHotEncoder", dtype.date: "DatetimeEncoder", @@ -617,7 +617,6 @@ def _add_implicit_values(json_ai: JsonAI) -> JsonAI: mixers[i]["args"]["target_encoder"] = mixers[i]["args"].get( "target_encoder", "$encoders[self.target]" ) - mixers[i]["args"]["use_optuna"] = True elif mixers[i]["module"] == "LightGBMArray": mixers[i]["args"]["input_cols"] = mixers[i]["args"].get( @@ -944,14 +943,17 @@ def code_from_json_ai(json_ai: JsonAI) -> str: parallel_encoding = parallel_encoding_check(data['train'], self.encoders) if parallel_encoding: + log.debug('Preparing in parallel...') for col_name, encoder in self.encoders.items(): if col_name != self.target and not encoder.is_trainable_encoder: prepped_encoders[col_name] = (encoder, concatenated_train_dev[col_name], 'prepare') prepped_encoders = mut_method_call(prepped_encoders) else: + log.debug('Preparing sequentially...') for col_name, encoder in self.encoders.items(): if col_name != self.target and not encoder.is_trainable_encoder: + log.debug(f'Preparing encoder for {{col_name}}...') encoder.prepare(concatenated_train_dev[col_name]) prepped_encoders[col_name] = encoder @@ -997,7 +999,22 @@ def code_from_json_ai(json_ai: JsonAI) -> str: feature_body = f""" log.info('Featurizing the data') -feature_data = {{ key: EncodedDs(self.encoders, data, self.target) for key, data in split_data.items() if key != "stratified_on"}} +tss = self.problem_definition.timeseries_settings + +feature_data = dict() +for key, data in split_data.items(): + if key != 'stratified_on': + + # compute and store two splits - full and filtered (useful for time series post-train analysis) + if key not in self.feature_cache: + featurized_split = EncodedDs(self.encoders, data, self.target) + filtered_subset = EncodedDs(self.encoders, filter_ts(data, tss), self.target) + + for k, s in zip((key, f'{{key}}_filtered'), (featurized_split, filtered_subset)): + self.feature_cache[k] = s + + for k in (key, f'{{key}}_filtered'): + feature_data[k] = self.feature_cache[k] return feature_data @@ -1018,9 +1035,7 @@ def code_from_json_ai(json_ai: JsonAI) -> str: # Extract the featurized data into train/dev/test encoded_train_data = enc_data['train'] encoded_dev_data = enc_data['dev'] -encoded_test_data = enc_data['test'] -filtered_df = filter_ds(encoded_test_data, self.problem_definition.timeseries_settings) -encoded_test_data = EncodedDs(encoded_test_data.encoders, filtered_df, encoded_test_data.target) +encoded_test_data = enc_data['test_filtered'] log.info('Training the mixers') @@ -1174,6 +1189,7 @@ def code_from_json_ai(json_ai: JsonAI) -> str: enc_train_test["dev"]]).data_frame, adjust_args={'learn_call': True}) +self.feature_cache = dict() # empty feature cache to avoid large predictor objects """ learn_body = align(learn_body, 2) # ----------------- # @@ -1208,13 +1224,14 @@ def code_from_json_ai(json_ai: JsonAI) -> str: log.info(f'[Predict phase 3/{{n_phases}}] - Calling ensemble') df = 
self.ensemble(encoded_ds, args=self.pred_args) -if self.pred_args.all_mixers: - return df -else: +if not self.pred_args.all_mixers: log.info(f'[Predict phase 4/{{n_phases}}] - Analyzing output') - insights, global_insights = {call(json_ai.explainer)} + df, global_insights = {call(json_ai.explainer)} self.global_insights = {{**self.global_insights, **global_insights}} - return insights + +self.feature_cache = dict() # empty feature cache to avoid large predictor objects + +return df """ predict_body = align(predict_body, 2) @@ -1252,6 +1269,9 @@ def __init__(self): self.runtime_log = dict() self.global_insights = dict() + # Feature cache + self.feature_cache = dict() + @timed def analyze_data(self, data: pd.DataFrame) -> None: # Perform a statistical analysis on the unprocessed data diff --git a/lightwood/data/encoded_ds.py b/lightwood/data/encoded_ds.py index b7f90993f..d9ba4e498 100644 --- a/lightwood/data/encoded_ds.py +++ b/lightwood/data/encoded_ds.py @@ -1,5 +1,5 @@ import inspect -from typing import List, Tuple +from typing import List, Tuple, Dict import torch import numpy as np import pandas as pd @@ -8,7 +8,7 @@ class EncodedDs(Dataset): - def __init__(self, encoders: List[BaseEncoder], data_frame: pd.DataFrame, target: str) -> None: + def __init__(self, encoders: Dict[str, BaseEncoder], data_frame: pd.DataFrame, target: str) -> None: """ Create a Lightwood datasource from a data frame and some encoders. This class inherits from `torch.utils.data.Dataset`. @@ -21,10 +21,8 @@ def __init__(self, encoders: List[BaseEncoder], data_frame: pd.DataFrame, target self.data_frame = data_frame self.encoders = encoders self.target = target - self.cache_encoded = True - self.cache = [None] * len(self.data_frame) self.encoder_spans = {} - self.input_length = 0 + self.input_length = 0 # feature tensor dim # save encoder span, has to use same iterator as in __getitem__ for correct indeces for col in self.data_frame: @@ -33,6 +31,13 @@ def __init__(self, encoders: List[BaseEncoder], data_frame: pd.DataFrame, target self.input_length + self.encoders[col].output_size) self.input_length += self.encoders[col].output_size + # if cache enabled, we immediately build it + self.use_cache = True + self.cache_built = False + self.X_cache: torch.Tensor = torch.full((len(self.data_frame),), fill_value=torch.nan) + self.Y_cache: torch.Tensor = torch.full((len(self.data_frame),), fill_value=torch.nan) + self.build_cache() + def __len__(self): """ The length of an `EncodedDs` datasource equals the amount of rows of the original dataframe. @@ -44,45 +49,65 @@ def __len__(self): def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]: """ The getter yields a tuple (X, y), where: - - `X `is a concatenation of all encoded representations of the row - - `y` is the encoded target + - `X `is a concatenation of all encoded representations of the row. Size: (B, n_features) + - `y` is the encoded target. Size: (B, n_features) :param idx: index of the row to access. :return: tuple (X, y) with encoded data. """ # noqa - if self.cache_encoded: - if self.cache[idx] is not None: - return self.cache[idx] + if self.use_cache and self.X_cache[idx] is not torch.nan: + X = self.X_cache[idx, :] + Y = self.Y_cache[idx] + else: + X, Y = self._encode_idxs([idx, ]) + if self.use_cache: + self.X_cache[idx, :] = X + self.Y_cache[idx, :] = Y + + return X, Y + + def _encode_idxs(self, idxs: list): + if not isinstance(idxs, list): + raise Exception(f"Passed indexes is not an iterable. Check the type! 
Index: {idxs}") - X = torch.FloatTensor() - Y = torch.FloatTensor() + X = torch.zeros((len(idxs), self.input_length)) + Y = torch.zeros((len(idxs),)) for col in self.data_frame: if self.encoders.get(col, None): kwargs = {} if 'dependency_data' in inspect.signature(self.encoders[col].encode).parameters: - kwargs['dependency_data'] = {dep: [self.data_frame.iloc[idx][dep]] + kwargs['dependency_data'] = {dep: [self.data_frame.iloc[idxs][dep]] for dep in self.encoders[col].dependencies} if hasattr(self.encoders[col], 'data_window'): cols = [self.target] + [f'{self.target}_timestep_{i}' for i in range(1, self.encoders[col].data_window)] - data = [self.data_frame[cols].iloc[idx].tolist()] + data = self.data_frame[cols].iloc[idxs].values else: cols = [col] - data = self.data_frame[cols].iloc[idx].tolist() + data = self.data_frame[cols].iloc[idxs].values.flatten() - encoded_tensor = self.encoders[col].encode(data, **kwargs)[0] + encoded_tensor = self.encoders[col].encode(data, **kwargs) if torch.isnan(encoded_tensor).any() or torch.isinf(encoded_tensor).any(): raise Exception(f'Encoded tensor: {encoded_tensor} contains nan or inf values, this tensor is \ the encoding of column {col} using {self.encoders[col].__class__}') if col != self.target: - X = torch.cat([X, encoded_tensor]) + a, b = self.encoder_spans[col] + X[:, a:b] = torch.squeeze(encoded_tensor, dim=list(range(2, len(encoded_tensor.shape)))) + + # target post-processing else: Y = encoded_tensor - if self.cache_encoded: - self.cache[idx] = (X, Y) + if len(encoded_tensor.shape) > 2: + Y = encoded_tensor.squeeze() + + if len(encoded_tensor.shape) < 2: + Y = encoded_tensor.unsqueeze(1) + + # else: + # Y = encoded_tensor.ravel() return X, Y @@ -102,20 +127,35 @@ def get_encoded_column_data(self, column_name: str) -> torch.Tensor: :param column_name: name of the column. :return: A `torch.Tensor` with the encoded data of the `column_name` column. """ + if self.use_cache and self.cache_built: + if column_name == self.target and self.Y_cache is not None: + return self.Y_cache + elif self.X_cache is not torch.nan: + a, b = self.encoder_spans[column_name] + return self.X_cache[:, a:b] + kwargs = {} if 'dependency_data' in inspect.signature(self.encoders[column_name].encode).parameters: deps = [dep for dep in self.encoders[column_name].dependencies if dep in self.data_frame.columns] - kwargs['dependency_data'] = {dep: self.data_frame[dep].tolist() for dep in deps} + kwargs['dependency_data'] = {dep: self.data_frame[dep] for dep in deps} encoded_data = self.encoders[column_name].encode(self.data_frame[column_name], **kwargs) if torch.isnan(encoded_data).any() or torch.isinf(encoded_data).any(): raise Exception(f'Encoded tensor: {encoded_data} contains nan or inf values') if not isinstance(encoded_data, torch.Tensor): raise Exception( - f'The encoder: {self.encoders[column_name]} for column: {column_name} does not return a Tensor !') + f'The encoder: {self.encoders[column_name]} for column: {column_name} does not return a Tensor!') + + if self.use_cache and not self.cache_built: + if column_name == self.target: + self.Y_cache = encoded_data + else: + a, b = self.encoder_spans[column_name] + self.X_cache = self.X_cache[:, a:b] + return encoded_data - def get_encoded_data(self, include_target=True) -> torch.Tensor: + def get_encoded_data(self, include_target: bool = True) -> torch.Tensor: """ Gets all encoded data. 
@@ -129,17 +169,29 @@ def get_encoded_data(self, include_target=True) -> torch.Tensor: return torch.cat(encoded_dfs, 1) + def build_cache(self): + """ This method builds a cache for the entire dataframe provided at initialization. """ + if not self.use_cache: + raise RuntimeError("Cannot build a cache for EncodedDS with `use_cache` set to False.") + + idxs = list(range(len(self.data_frame))) + X, Y = self._encode_idxs(idxs) + self.X_cache = X + self.Y_cache = Y + self.cache_built = True + def clear_cache(self): - """ - Clears the `EncodedDs` cache. - """ - self.cache = [None] * len(self.data_frame) + """ Clears the `EncodedDs` cache. """ + self.X_cache = torch.full((len(self.data_frame),), fill_value=torch.nan) + self.Y_cache = torch.full((len(self.data_frame),), fill_value=torch.nan) + self.cache_built = False class ConcatedEncodedDs(EncodedDs): """ `ConcatedEncodedDs` abstracts over multiple encoded datasources (`EncodedDs`) as if they were a single entity. """ # noqa + # TODO: We should probably delete this abstraction, it's not really useful and it adds complexity/overhead def __init__(self, encoded_ds_arr: List[EncodedDs]) -> None: # @TODO: missing super() call here? diff --git a/lightwood/encoder/array/ts_num_array.py b/lightwood/encoder/array/ts_num_array.py index b61dae507..5a20a8fac 100644 --- a/lightwood/encoder/array/ts_num_array.py +++ b/lightwood/encoder/array/ts_num_array.py @@ -1,14 +1,13 @@ from typing import List, Dict, Iterable, Optional import torch -import torch.nn.functional as F from lightwood.encoder import BaseEncoder from lightwood.encoder.numeric import TsNumericEncoder class TsArrayNumericEncoder(BaseEncoder): - def __init__(self, timesteps: int, is_target: bool = False, positive_domain: bool = False, grouped_by=None): + def __init__(self, timesteps: int, is_target: bool = False, positive_domain: bool = False, grouped_by=None, nan=0): """ This encoder handles arrays of numerical time series data by wrapping the numerical encoder with behavior specific to time series tasks. @@ -23,6 +22,7 @@ def __init__(self, timesteps: int, is_target: bool = False, positive_domain: boo self.dependencies = grouped_by self.data_window = timesteps self.positive_domain = positive_domain + self.nan_value = nan self.sub_encoder = TsNumericEncoder(is_target=is_target, positive_domain=positive_domain, grouped_by=grouped_by) self.output_size = self.data_window * self.sub_encoder.output_size @@ -52,34 +52,9 @@ def encode(self, data: Iterable[Iterable], dependency_data: Optional[Dict[str, s if not dependency_data: dependency_data = {'__default': [None] * len(data)} - ret = [] - for series in data: - ret.append(self.encode_one(series, dependency_data=dependency_data)) - - return torch.vstack(ret) - - def encode_one(self, data: Iterable, dependency_data: Optional[Dict[str, str]] = {}) -> torch.Tensor: - """ - Encodes a single windowed slice of any given time series. + ret = self.sub_encoder.encode(data, dependency_data=dependency_data) - :param data: windowed slice of a numerical time series. - :param dependency_data: used to determine the correct normalizer for the input. - - :return: an encoded time series array, as per the underlying `TsNumericEncoder` object. - The output of this encoder for all time steps is concatenated, so the final shape of the tensor is (1, NxK), where N: self.data_window and K: sub-encoder # of output features. 
- """ # noqa - ret = [] - - for data_point in data: - ret.append(self.sub_encoder.encode([data_point], dependency_data=dependency_data)) - - ret = torch.hstack(ret) - padding_size = self.output_size - ret.shape[-1] - - if padding_size > 0: - ret = F.pad(ret, (0, padding_size)) - - return ret + return torch.Tensor(ret).nan_to_num(self.nan_value) def decode(self, encoded_values, dependency_data=None) -> List[List]: """ diff --git a/lightwood/encoder/categorical/binary.py b/lightwood/encoder/categorical/binary.py index 77c019a42..8dad1241f 100644 --- a/lightwood/encoder/categorical/binary.py +++ b/lightwood/encoder/categorical/binary.py @@ -7,6 +7,7 @@ from lightwood.encoder.base import BaseEncoder from lightwood.helpers.constants import _UNCOMMON_WORD +from lightwood.helpers.log import log class BinaryEncoder(BaseEncoder): @@ -34,17 +35,20 @@ def __init__( self, is_target: bool = False, target_weights: Dict[str, float] = None, + handle_unknown: str = 'use_encoded_value' ): super().__init__(is_target) """ :param is_target: Whether encoder featurizes target column :param target_weights: Percentage of total population represented by each category (from [0, 1]), as a dictionary. + :param handle_unknown: if set to `use_encoded_value`, will assign all classes with index greater than 1 to a special UNKNOWN index. This doesn't affect the encoded representation of shape (B, 2). During decoding, any unknown or otherwise known but "out-of-bounds" word will be decoded back to the lightwood unknown category token. If this argument is set to `error`, the encoder will raise an error while preparing if there are more than two observed classes. """ # noqa self.map = {} # category name -> index self.rev_map = {} # index -> category name self.output_size = 2 self.encoder_class_type = str + self.handle_unknown = handle_unknown # Weight-balance info if encoder represents target self.target_weights = None @@ -67,13 +71,16 @@ def prepare(self, priming_data: Iterable[str]): self.rev_map = {indx: cat for cat, indx in self.map.items()} # Enforce only binary; map must have exactly 2 classes. - if len(self.map) > 2: - raise ValueError(f'Issue with dtype; data has > 2 classes. All classes are: {self.map}') + if len(self.map) > 2 and self.handle_unknown == 'use_encoded_value': + log.warning('Warning: dtype for binary encoder has > 2 classes. Extra classes will be pointed to an invalid token. Try overriding this encoder with a multi-class categorical encoder, otherwise performance may not be optimal.') # noqa + log.warning(f'Observed classes are: {self.map}.') + elif self.handle_unknown == 'error': + raise ValueError(f'Data has > 2 classes and encoder is in strict mode. Aborting. All classes are: {self.map}.') # noqa # For target-only, report on relative weights of classes if self.is_target: - self.index_weights = torch.Tensor([1, 1]) # Equally wt. both classes + self.index_weights = torch.ones(self.output_size) # Equally wt. both classes # If target weights provided, weight by inverse if self.target_weights is not None: @@ -102,12 +109,14 @@ def encode(self, column_data: Iterable[str]) -> torch.Tensor: 'You need to call "prepare" before calling "encode" or "decode".' 
) - ret = torch.zeros(size=(len(column_data), 2)) + ret = torch.zeros(size=(len(column_data), self.output_size)) for idx, word in enumerate(column_data): index = self.map.get(word, None) - if index is not None: + if index is None or index >= self.output_size: + pass # any unknown value is ignored + else: ret[idx, index] = 1 return torch.Tensor(ret) @@ -130,7 +139,11 @@ def decode(self, encoded_data: torch.Tensor): if not np.any(vector): # Vector of all 0s -> unknown category ret.append(_UNCOMMON_WORD) else: - ret.append(self.rev_map[np.argmax(vector)]) + idx = np.argmax(vector) + if idx >= self.output_size: + ret.append(_UNCOMMON_WORD) # known, but not either of the supported categories + else: + ret.append(self.rev_map[idx]) return ret diff --git a/lightwood/encoder/categorical/onehot.py b/lightwood/encoder/categorical/onehot.py index c25c09879..e72a1f59c 100644 --- a/lightwood/encoder/categorical/onehot.py +++ b/lightwood/encoder/categorical/onehot.py @@ -68,12 +68,12 @@ def prepare(self, priming_data: Iterable[str]): unq_cats = np.unique([i for i in priming_data if i is not None]).tolist() if self.use_unknown: - log.info("Encoding UNKNOWN categories as index 0") + log.debug("Encoding UNKNOWN categories as index 0") self.map = {cat: indx + 1 for indx, cat in enumerate(unq_cats)} self.map.update({_UNCOMMON_WORD: 0}) self.rev_map = {indx: cat for cat, indx in self.map.items()} else: - log.info("Encoding UNKNOWN categories as vector of all 0s") + log.debug("Encoding UNKNOWN categories as vector of all 0s") self.map = {cat: indx for indx, cat in enumerate(unq_cats)} self.rev_map = {indx: cat for cat, indx in self.map.items()} diff --git a/lightwood/encoder/numeric/numeric.py b/lightwood/encoder/numeric/numeric.py index 251fd1ae6..a2b261e3b 100644 --- a/lightwood/encoder/numeric/numeric.py +++ b/lightwood/encoder/numeric/numeric.py @@ -1,12 +1,13 @@ import math -from typing import Iterable, List, Union +from typing import Union + import torch import numpy as np -from torch.types import Number +import pandas as pd +from type_infer.dtype import dtype + from lightwood.encoder.base import BaseEncoder -from lightwood.helpers.log import log from lightwood.helpers.general import is_none -from type_infer.dtype import dtype class NumericEncoder(BaseEncoder): @@ -28,13 +29,12 @@ def __init__(self, data_type: dtype = None, is_target: bool = False, positive_do :param positive_domain: Forces the encoder to always output positive values """ super().__init__(is_target) - self._type = data_type self._abs_mean = None self.positive_domain = positive_domain self.decode_log = False self.output_size = 4 if not self.is_target else 3 - def prepare(self, priming_data: Iterable): + def prepare(self, priming_data: pd.Series): """ "NumericalEncoder" uses a rule-based form to prepare results on training (priming) data. The averages etc. are taken from this distribution. 
@@ -43,109 +43,105 @@ def prepare(self, priming_data: Iterable): if self.is_prepared: raise Exception('You can only call "prepare" once for a given encoder.') - value_type = 'int' - for number in priming_data: - if not is_none(number): - if int(number) != number: - value_type = 'float' - - self._type = value_type if self._type is None else self._type - non_null_priming_data = [x for x in priming_data if not is_none(x)] - self._abs_mean = np.mean(np.abs(non_null_priming_data)) + self._abs_mean = priming_data.abs().mean() self.is_prepared = True - def encode(self, data: Iterable): + def encode(self, data: Union[np.ndarray, pd.Series]): """ - :param data: An iterable data structure containing the numbers to be encoded - + :param data: A pandas series or numpy array containing the numbers to be encoded :returns: A torch tensor with the representations of each number """ if not self.is_prepared: raise Exception('You need to call "prepare" before calling "encode" or "decode".') - ret = [] - for real in data: - try: - real = float(real) - except Exception: - real = None - if self.is_target: - # Will crash if ``real`` is not a float, this is fine, targets should always have a value - vector = [0] * 3 - vector[0] = 1 if real < 0 and not self.positive_domain else 0 - vector[1] = math.log(abs(real)) if abs(real) > 0 else -20 - vector[2] = real / self._abs_mean - - else: - vector = [0] * 4 - try: - if is_none(real): - vector[0] = 0 - else: - vector[0] = 1 - vector[1] = math.log(abs(real)) if abs(real) > 0 else -20 - vector[2] = 1 if real < 0 and not self.positive_domain else 0 - vector[3] = real / self._abs_mean - except Exception as e: - vector = [0] * 4 - log.error(f'Can\'t encode input value: {real}, exception: {e}') - - ret.append(vector) - - return torch.Tensor(ret) - - def decode(self, encoded_values: Union[List[Number], torch.Tensor], decode_log: bool = None) -> list: + if isinstance(data, pd.Series): + data = data.values + + inp_data = np.nan_to_num(data.astype(float), nan=0, posinf=np.finfo(np.float32).max, neginf=np.finfo(np.float32).min) # noqa + if not self.positive_domain: + sign = np.vectorize(self._sign_fn, otypes=[float])(inp_data) + else: + sign = np.zeros(len(data)) + log_value = np.vectorize(self._log_fn, otypes=[float])(inp_data) + log_value = np.nan_to_num(log_value, nan=0, posinf=20, neginf=-20) + + norm = np.vectorize(self._norm_fn, otypes=[float])(inp_data) + norm = np.nan_to_num(norm, nan=0, posinf=20, neginf=-20) + + if self.is_target: + components = [sign, log_value, norm] + else: + nones = np.vectorize(self._none_fn, otypes=[float])(data) + components = [sign, log_value, norm, nones] + + return torch.Tensor(np.asarray(components)).T + + @staticmethod + def _sign_fn(x: float) -> float: + return 0 if x < 0 else 1 + + @staticmethod + def _log_fn(x: float) -> float: + return math.log(abs(x)) if abs(x) > 0 else -20 + + def _norm_fn(self, x: float) -> float: + return x / self._abs_mean + + @staticmethod + def _none_fn(x: float) -> float: + return 1 if is_none(x) else 0 + + def decode(self, encoded_values: torch.Tensor, decode_log: bool = None) -> list: """ :param encoded_values: The encoded values to decode into single numbers :param decode_log: Whether to decode the ``log`` or ``linear`` part of the representation, since the encoded vector contains both a log and a linear part - :returns: The decoded number + :returns: The decoded array """ # noqa + if not self.is_prepared: raise Exception('You need to call "prepare" before calling "encode" or "decode".') if decode_log is None: 
decode_log = self.decode_log - ret = [] - if isinstance(encoded_values, torch.Tensor): - encoded_values = encoded_values.tolist() - - for vector in encoded_values: - if self.is_target: - if np.isnan( - vector[0]) or vector[0] == float('inf') or np.isnan( - vector[1]) or vector[1] == float('inf') or np.isnan( - vector[2]) or vector[2] == float('inf'): - log.error(f'Got weird target value to decode: {vector}') - real_value = pow(10, 63) - else: - if decode_log: - sign = -1 if vector[0] > 0.5 else 1 - try: - real_value = math.exp(vector[1]) * sign - except OverflowError: - real_value = pow(10, 63) * sign - else: - real_value = vector[2] * self._abs_mean - - if self.positive_domain: - real_value = abs(real_value) - - if self._type == 'int': - real_value = int(real_value) - - else: - if vector[0] < 0.5: - ret.append(None) - continue - - real_value = vector[3] * self._abs_mean - - if self._type == 'int': - real_value = round(real_value) - - if isinstance(real_value, torch.Tensor): - real_value = real_value.item() - ret.append(real_value) - return ret + # force = True prevents side effects on the original encoded_values + ev = encoded_values.numpy(force=True) + + # set "divergent" value as default (note: finfo.max() instead of pow(10, 63)) + ret = np.full((ev.shape[0],), dtype=float, fill_value=np.finfo(np.float64).max) + + # `none` filter (if not a target column) + if not self.is_target: + mask_none = ev[:, -1] == 1 + ret[mask_none] = np.nan + else: + mask_none = np.zeros_like(ret) + + # sign component + sign = np.ones(ev.shape[0], dtype=float) + mask_sign = ev[:, 0] < 0.5 + sign[mask_sign] = -1 + + # real component + if decode_log: + real_value = np.exp(ev[:, 1]) * sign + overflow_mask = ev[:, 1] >= 63 + real_value[overflow_mask] = 10 ** 63 + valid_mask = ~overflow_mask + else: + real_value = ev[:, 2] * self._abs_mean + valid_mask = np.ones_like(real_value, dtype=bool) + + # final filters + if self.positive_domain: + real_value = abs(real_value) + + ret[valid_mask] = real_value[valid_mask] + + # set nan back to None + if mask_none.sum() > 0: + ret = ret.astype(object) + ret[mask_none] = None + + return ret.tolist() # TODO: update signature on BaseEncoder and replace all encs to return ndarrays diff --git a/lightwood/encoder/numeric/ts_numeric.py b/lightwood/encoder/numeric/ts_numeric.py index 06127c9a3..d790f5cb5 100644 --- a/lightwood/encoder/numeric/ts_numeric.py +++ b/lightwood/encoder/numeric/ts_numeric.py @@ -1,9 +1,10 @@ -import math +from typing import Union, List, Dict + import torch import numpy as np +import pandas as pd + from lightwood.encoder.numeric import NumericEncoder -from lightwood.helpers.general import is_none -from lightwood.helpers.log import log class TsNumericEncoder(NumericEncoder): @@ -20,101 +21,95 @@ def __init__(self, is_target: bool = False, positive_domain: bool = False, group self.dependencies = grouped_by self.output_size = 1 - def encode(self, data, dependency_data={}): + def encode(self, data: Union[np.ndarray, pd.Series], dependency_data: Dict[str, List[pd.Series]] = {}): """ + :param data: A pandas series containing the numbers to be encoded :param dependency_data: dict with grouped_by column info, to retrieve the correct normalizer for each datum + + :returns: A torch tensor with the representations of each number """ # noqa if not self.is_prepared: raise Exception('You need to call "prepare" before calling "encode" or "decode".') + if not dependency_data: dependency_data = {'__default': [None] * len(data)} - ret = [] - for real, group in zip(data, 
list(zip(*dependency_data.values()))): - try: - real = float(real) - except Exception: - try: - real = float(real.replace(',', '.')) - except Exception: - real = None - if self.is_target: - vector = [0] - if group is not None and self.normalizers is not None: - try: - mean = self.normalizers[tuple(group)].abs_mean - except KeyError: - # novel group-by, we use default normalizer mean - mean = self.normalizers['__default'].abs_mean - else: - mean = self._abs_mean + if isinstance(data, pd.Series): + data = data.values - if not is_none(real): - vector[0] = real / mean if mean != 0 else real + # get array of series-wise observed means + if self.normalizers is None: + means = np.full((len(data)), fill_value=self._abs_mean) + else: + # use global mean as default for novel series + means = np.full((len(data)), fill_value=self.normalizers['__default'].abs_mean) + + def _get_group_mean(group) -> float: + if (group, ) in self.normalizers: + return self.normalizers[(group, )].abs_mean else: - pass - # This should raise an exception *once* we fix the TsEncoder such that this doesn't get feed `nan` - # raise Exception(f'Can\'t encode target value: {real}') - else: - vector = [0] - try: - if not is_none(real): - vector[0] = real / self._abs_mean - except Exception as e: - log.error(f'Can\'t encode input value: {real}, exception: {e}') - - ret.append(vector) - - return torch.Tensor(ret) - - def decode(self, encoded_values, decode_log=None, dependency_data=None): + return self.normalizers['__default'].abs_mean + + for i, group in enumerate(list(zip(*dependency_data.values()))): # TODO: support multigroup + if group[0] is not None: + means = np.vectorize(_get_group_mean, otypes=[float])(group[0].values) + + if len(data.shape) > 1 and data.shape[1] > 1: + if len(means.shape) == 1: + means = np.expand_dims(means, 1) + means = np.repeat(means, data.shape[1], axis=1) + + def _norm_fn(x: float, mean: float) -> float: + return x / mean + + # nones = np.vectorize(self._none_fn, otypes=[float])(data) # TODO + encoded = np.vectorize(_norm_fn, otypes=[float])(data, means) + # encoded[nones] = 0 # if measurement is None, it is zeroed out # TODO + + # TODO: mask for where mean is 0, then pass real as-is + + return torch.Tensor(encoded).unsqueeze(1) + + def decode(self, encoded_values: torch.Tensor, decode_log: bool = None, dependency_data=None): if not self.is_prepared: raise Exception('You need to call "prepare" before calling "encode" or "decode".') - if decode_log is None: - decode_log = self.decode_log + assert isinstance(encoded_values, torch.Tensor), 'It is not a tensor!' 
# TODO: debug purposes + assert not decode_log # TODO: debug purposes - ret = [] if not dependency_data: dependency_data = {'__default': [None] * len(encoded_values)} - if isinstance(encoded_values, torch.Tensor): - encoded_values = encoded_values.tolist() - - for vector, group in zip(encoded_values, list(zip(*dependency_data.values()))): - if self.is_target: - if np.isnan(vector[0]) or vector[0] == float('inf'): - log.error(f'Got weird target value to decode: {vector}') - real_value = pow(10, 63) - else: - if decode_log: - sign = -1 if vector[0] < 0 else 1 - try: - real_value = math.exp(vector[0]) * sign - except OverflowError: - real_value = pow(10, 63) * sign - else: - if group is not None and self.normalizers is not None: - try: - mean = self.normalizers[tuple(group)].abs_mean - except KeyError: - # decode new group with default normalizer - mean = self.normalizers['__default'].abs_mean - else: - mean = self._abs_mean - real_value = vector[0] * mean + # force = True prevents side effects on the original encoded_values + ev = encoded_values.numpy(force=True) - if self.positive_domain: - real_value = abs(real_value) + # set global mean as default + ret = np.full((ev.shape[0],), dtype=float, fill_value=self._abs_mean) + + # TODO: perhaps capture nan, infs, etc and set to pow(10,63)? + + # set means array + if self.normalizers is None: + means = np.full((ev.shape[0],), fill_value=self._abs_mean) + else: + means = np.full((len(encoded_values)), fill_value=self.normalizers['__default'].abs_mean) + for i, group in enumerate(list(zip(*dependency_data.values()))): + if group is not None: + if tuple(group) in self.normalizers: + means[i] = self.normalizers[tuple(group)].abs_mean + else: + means[i] = self.normalizers['__default'].abs_mean + else: + means[i] = self._abs_mean - if self._type == 'int': - real_value = int(round(real_value, 0)) + # set real value + real_value = np.multiply(ev[:].reshape(-1,), means) + valid_mask = np.ones_like(real_value, dtype=bool) - else: - real_value = vector[0] * self._abs_mean + # final filters + if self.positive_domain: + real_value = abs(real_value) - if self._type == 'int': - real_value = round(real_value) + ret[valid_mask] = real_value[valid_mask] # TODO probably not needed - ret.append(real_value) - return ret + return ret.tolist() diff --git a/lightwood/encoder/text/pretrained.py b/lightwood/encoder/text/pretrained.py index b9fcd1bae..4532192bb 100644 --- a/lightwood/encoder/text/pretrained.py +++ b/lightwood/encoder/text/pretrained.py @@ -1,15 +1,12 @@ -""" -""" +import os import time +from typing import Iterable +from collections import deque + +import numpy as np import torch from torch.utils.data import DataLoader -import os import pandas as pd -from lightwood.encoder.text.helpers.pretrained_helpers import TextEmbed -from lightwood.helpers.device import get_device_from_name -from lightwood.encoder.base import BaseEncoder -from lightwood.helpers.log import log -from lightwood.helpers.torch import LightwoodAutocast from type_infer.dtype import dtype from transformers import ( DistilBertModel, @@ -18,8 +15,14 @@ AdamW, get_linear_schedule_with_warmup, ) +from sklearn.model_selection import train_test_split + +from lightwood.encoder.text.helpers.pretrained_helpers import TextEmbed +from lightwood.helpers.device import get_device_from_name +from lightwood.encoder.base import BaseEncoder +from lightwood.helpers.log import log +from lightwood.helpers.torch import LightwoodAutocast from lightwood.helpers.general import is_none -from typing import Iterable 
class PretrainedLangEncoder(BaseEncoder): @@ -48,7 +51,6 @@ def __init__( :param is_target: Whether this encoder represents the target. NOT functional for text generation yet. :param batch_size: size of batch while fine-tuning :param max_position_embeddings: max sequence length of input text - :param custom_train: If True, trains model on target procided :param frozen: If True, freezes transformer layers during training. :param epochs: number of epochs to train model with :param output_type: Data dtype of the target; if categorical/binary, the option to return logits is possible. @@ -64,12 +66,14 @@ def __init__( self._frozen = frozen self._batch_size = batch_size self._epochs = epochs + self._patience = 3 # measured in batches rather than epochs + self._val_loss_every = 5 # how many batches to wait before checking val loss. If -1, will check train loss instead of val for early stopping. # noqa + self._tr_loss_every = 2 # same as above, but only applies if `_val_loss_every` is set to -1 # Model setup self._model = None self.model_type = None - # TODO: Other LMs; Distilbert is a good balance of speed/performance self._classifier_model_class = DistilBertForSequenceClassification self._embeddings_model_class = DistilBertModel self._pretrained_model_name = "distilbert-base-uncased" @@ -90,19 +94,19 @@ def __init__( def prepare( self, - train_priming_data: Iterable[str], - dev_priming_data: Iterable[str], + train_priming_data: pd.Series, + dev_priming_data: pd.Series, encoded_target_values: torch.Tensor, ): """ Fine-tunes a transformer on the priming data. - CURRENTLY WIP; train + dev are placeholders for a validation-based approach. - + Transformer is fine-tuned with weight-decay on training split. + Train + Dev are concatenated together and a transformer is then fine tuned with weight-decay applied on the transformer parameters. The option to freeze the underlying transformer and only train a linear layer exists if `frozen=True`. This trains faster, with the exception that the performance is often lower than fine-tuning on internal benchmarks. 
- + :param train_priming_data: Text data in the train set - :param dev_priming_data: Text data in the dev set (not currently supported; can be empty) + :param dev_priming_data: Text data in the dev set :param encoded_target_values: Encoded target labels in Nrows x N_output_dimension """ # noqa if self.is_prepared: @@ -110,26 +114,30 @@ def prepare( os.environ['TOKENIZERS_PARALLELISM'] = 'true' - # TODO -> we shouldn't be concatenating these together - if len(dev_priming_data) > 0: - priming_data = pd.concat([train_priming_data, dev_priming_data]).values + # remove empty strings (`None`s for dtype `object`) + filtered_tr = train_priming_data[~train_priming_data.isna()] + filtered_dev = dev_priming_data[~dev_priming_data.isna()] + + if filtered_dev.shape[0] > 0: + priming_data = pd.concat([filtered_tr, filtered_dev]).tolist() + val_size = (len(dev_priming_data)) / len(train_priming_data) else: - priming_data = train_priming_data.tolist() + priming_data = filtered_tr.tolist() + val_size = 0.1 # leave out 0.1 for validation + + # Label encode the OHE/binary output for classification + labels = encoded_target_values.argmax(dim=1) - # Replaces empty strings with '' - priming_data = [x if x is not None else "" for x in priming_data] + # Split into train and validation sets + train_texts, val_texts, train_labels, val_labels = train_test_split(priming_data, labels, test_size=val_size) # If classification, then fine-tune - if (self.output_type in (dtype.categorical, dtype.binary)): - log.info("Training model.") + if self.output_type in (dtype.categorical, dtype.binary): + log.info("Training model.\n\tOutput trained is categorical") # Prepare priming data into tokenized form + attention masks - text = self._tokenizer(priming_data, truncation=True, padding=True) - - log.info("\tOutput trained is categorical") - - # Label encode the OHE/binary output for classification - labels = encoded_target_values.argmax(dim=1) + training_text = self._tokenizer(train_texts, truncation=True, padding=True) + validation_text = self._tokenizer(val_texts, truncation=True, padding=True) # Construct the model self._model = self._classifier_model_class.from_pretrained( @@ -138,8 +146,12 @@ def prepare( ).to(self.device) # Construct the dataset for training - xinp = TextEmbed(text, labels) - dataset = DataLoader(xinp, batch_size=self._batch_size, shuffle=True) + xinp = TextEmbed(training_text, train_labels) + train_dataset = DataLoader(xinp, batch_size=self._batch_size, shuffle=True) + + # Construct the dataset for validation + xvalinp = TextEmbed(validation_text, val_labels) + val_dataset = DataLoader(xvalinp, batch_size=self._batch_size, shuffle=True) # Set max length of input string; affects input to the model if self._max_len is None: @@ -148,8 +160,7 @@ def prepare( if self._frozen: log.info("\tFrozen Model + Training Classifier Layers") """ - Freeze the base transformer model and train - a linear layer on top + Freeze the base transformer model and train a linear layer on top """ # Freeze all the transformer parameters for param in self._model.base_model.parameters(): @@ -189,12 +200,12 @@ def prepare( scheduler = get_linear_schedule_with_warmup( optimizer, num_warmup_steps=0, # default value for GLUE - num_training_steps=len(dataset) * self._epochs, + num_training_steps=len(train_dataset) * self._epochs, ) # Train model; declare optimizer earlier if desired. 
             self._tune_model(
-                dataset, optim=optimizer, scheduler=scheduler, n_epochs=self._epochs
+                train_dataset, val_dataset, optim=optimizer, scheduler=scheduler, n_epochs=self._epochs
             )
 
         else:
@@ -206,8 +217,7 @@ def prepare(
             ).to(self.device)
 
         # TODO: Not a great flag
-        # Currently, if the task is not classification, you must have
-        # an embedding generator only.
+        # Currently, if the task is not classification, you must have an embedding generator only
         if self.embed_mode is False:
             log.info("Embedding mode must be ON for non-classification targets.")
             self.embed_mode = True
@@ -216,19 +226,15 @@ def prepare(
         encoded = self.encode(priming_data[0:1])
         self.output_size = len(encoded[0])
 
-    def _tune_model(self, dataset, optim, scheduler, n_epochs=1):
+    def _tune_model(self, train_dataset, val_dataset, optim, scheduler, n_epochs=1):
         """
-        Given a model, train for n_epochs.
-        Specifically intended for tuning; it does NOT use loss/
-        stopping criterion.
-
-        model - torch.nn model;
-        dataset - torch.DataLoader; dataset to train
-        device - torch.device; cuda/cpu
-        log - lightwood.logger.log; log.info output
+        Given a model, tune for n_epochs.
+
+        train_dataset - torch.DataLoader; dataset to train
+        val_dataset - torch.DataLoader; dataset used to compute validation loss + early stopping
        optim - transformers.optimization.AdamW; optimizer
        scheduler - scheduling params
-        n_epochs - number of epochs to train
+        n_epochs - max number of epochs to train for, provided there is no early stopping
        """  # noqa
        self._model.train()
@@ -244,20 +250,21 @@ def _tune_model(self, dataset, optim, scheduler, n_epochs=1):
         else:
             log.info("Scheduler provided.")
 
+        best_tr_loss = best_val_loss = float("inf")
+        tr_loss_queue = deque(maxlen=self._patience)
+        patience_counter = self._patience
+
         started = time.time()
         for epoch in range(n_epochs):
             total_loss = 0
-            for batch in dataset:
+            for bidx, batch in enumerate(train_dataset):
                 optim.zero_grad()
 
                 with LightwoodAutocast():
-                    inpids = batch["input_ids"].to(self.device)
-                    attn = batch["attention_mask"].to(self.device)
-                    labels = batch["labels"].to(self.device)
-                    outputs = self._model(inpids, attention_mask=attn, labels=labels)
-                    loss = outputs[0]
+                    loss = self._call(batch)
+                    tr_loss_queue.append(loss.item())
                     total_loss += loss.item()
 
                 loss.backward()
@@ -267,9 +274,48 @@ def _tune_model(self, dataset, optim, scheduler, n_epochs=1):
                 if time.time() - started > self.stop_after:
                     break
 
+                # val-based early stopping
+                if False and (self._val_loss_every != -1) and (bidx % self._val_loss_every == 0):
+                    self._model.eval()
+                    val_loss = 0
+
+                    for vbatch in val_dataset:
+                        val_loss += self._call(vbatch).item()
+
+                    log.info(f"Epoch {epoch+1} train batch {bidx+1} - Validation loss: {val_loss/len(val_dataset)}")
+                    if val_loss / len(val_dataset) >= best_val_loss:
+                        break
+
+                    best_val_loss = val_loss / len(val_dataset)
+                    self._model.train()
+
+                # train-based early stopping
+                elif False and (bidx + 1) % self._tr_loss_every == 0:
+                    self._model.eval()
+
+                    tr_loss = np.average(tr_loss_queue)
+                    log.info(f"Epoch {epoch} train batch {bidx} - Train loss: {tr_loss}")  # noqa
+                    self._model.train()
+
+                    if tr_loss >= best_tr_loss and patience_counter == 0:
+                        break
+                    elif patience_counter > 0:
+                        patience_counter -= 1
+                    elif tr_loss < best_tr_loss:
+                        best_tr_loss = tr_loss
+                        patience_counter = self._patience
+
             if time.time() - started > self.stop_after:
                 break
 
-            self._train_callback(epoch, total_loss / len(dataset))
+            self._train_callback(epoch, total_loss / len(train_dataset))
+
+    def _call(self, batch):
+        inpids = batch["input_ids"].to(self.device)
+        attn = batch["attention_mask"].to(self.device)
+        labels = batch["labels"].to(self.device)
+        outputs = self._model(inpids, attention_mask=attn, labels=labels)
+        loss = outputs[0]
+        return loss
 
     def _train_callback(self, epoch, loss):
         log.info(f"{self.name} at epoch {epoch+1} and loss {loss}!")
diff --git a/lightwood/encoder/text/short.py b/lightwood/encoder/text/short.py
index e7f68186e..127bb863f 100644
--- a/lightwood/encoder/text/short.py
+++ b/lightwood/encoder/text/short.py
@@ -8,6 +8,8 @@
 
 
 class ShortTextEncoder(BaseEncoder):
+    is_trainable_encoder = False
+
     def __init__(self, is_target=False, mode=None, device=''):
         """
         :param is_target:
@@ -55,7 +57,7 @@ def prepare(self, priming_data):
         unique_tokens = set()
         max_words_per_sent = 0
         for sent in no_null_sentences:
-            tokens = tokenize_text(sent)
+            tokens = list(tokenize_text(sent))
             max_words_per_sent = max(max_words_per_sent, len(tokens))
             for tok in tokens:
                 unique_tokens.add(tok)
@@ -78,7 +80,7 @@ def encode(self, column_data: List[str]) -> torch.Tensor:
         no_null_sentences = (x if x is not None else '' for x in column_data)
         output = []
         for sent in no_null_sentences:
-            tokens = tokenize_text(sent)
+            tokens = list(tokenize_text(sent))
             encoded_words = self.cae.encode(tokens)
             encoded_sent = self._combine_fn(encoded_words)
             output.append(torch.Tensor(encoded_sent))
diff --git a/lightwood/helpers/ts.py b/lightwood/helpers/ts.py
index 445492cf6..c1306157a 100644
--- a/lightwood/helpers/ts.py
+++ b/lightwood/helpers/ts.py
@@ -297,13 +297,12 @@ def min_k(top_k, data):
     return candidate_sps
 
 
-def filter_ds(ds, tss, n_rows=1):
+def filter_ts(df: pd.DataFrame, tss, n_rows=1):
     """
     This method triggers only for timeseries datasets.
 
    It returns a dataframe that filters out all but the first ``n_rows`` per group.
    """  # noqa
-    df = ds.data_frame
     if tss.is_timeseries:
         gby = tss.group_by
         if gby is None:
diff --git a/lightwood/mixer/neural.py b/lightwood/mixer/neural.py
index 90040a3aa..697ef288f 100644
--- a/lightwood/mixer/neural.py
+++ b/lightwood/mixer/neural.py
@@ -1,5 +1,6 @@
 import time
 from copy import deepcopy
+from collections import deque
 from typing import Dict, List, Optional
 
 import torch
@@ -42,7 +43,8 @@ def __init__(
             net: str,
             fit_on_dev: bool,
             search_hyperparameters: bool,
-            n_epochs: Optional[int] = None
+            n_epochs: Optional[int] = None,
+            lr: Optional[float] = None,
     ):
         """
         The Neural mixer trains a fully connected dense network from concatenated encoded outputs of each of the features in the dataset to predicted the encoded output.
@@ -55,13 +57,17 @@ def __init__(
         :param fit_on_dev: If we should fit on the dev dataset
         :param search_hyperparameters: If the network should run a more through hyperparameter search (currently disabled)
         :param n_epochs: amount of epochs that the network will be trained for. Supersedes all other early stopping criteria if specified.
+        :param lr: learning rate for the network. By default, it is automatically selected based on an initial search process.
""" # noqa super().__init__(stop_after) self.dtype_dict = dtype_dict self.target = target self.target_encoder = target_encoder + self.num_hidden = 1 self.epochs_to_best = 0 self.n_epochs = n_epochs + self.lr = lr + self.loss_hist_len = 7 # length of queue to use for early stopping self.fit_on_dev = fit_on_dev self.net_name = net self.supports_proba = dtype_dict[target] in [dtype.binary, dtype.categorical] @@ -106,32 +112,45 @@ def _select_criterion(self) -> torch.nn.Module: return criterion - def _select_optimizer(self) -> Optimizer: - optimizer = ad_optim.Ranger(self.model.parameters(), lr=self.lr, weight_decay=2e-2) + def _select_optimizer(self, model, lr) -> Optimizer: + optimizer = ad_optim.Ranger(model.parameters(), lr=lr, weight_decay=2e-2) return optimizer - def _find_lr(self, dl): - optimizer = self._select_optimizer() + def _find_lr(self, train_data): + lr = 1e-4 # good starting point as search escalates + lrs = deque([5e-4, 1e-3, 2e-3, 3e-3, 5e-3, 1e-2, 5e-2, 1e-1]) + starting_model = deepcopy(self.model) criterion = self._select_criterion() scaler = GradScaler() - running_losses: List[float] = [] - cum_loss = 0 - lr_log = [] + running_losses = deque(maxlen=self.loss_hist_len) + lr_log = deque(maxlen=self.loss_hist_len) best_model = self.model stop = False - batches = 0 - for epoch in range(1, 101): - if stop: - break - for i, (X, Y) in enumerate(dl): - if stop: - break + n_steps = 10 + cum_loss = 0 + + while stop is False: + # overfit learning on first n_steps samples (biased, but we only want an intuition on what LR is decent) + dl = DataLoader(train_data, + batch_size=min(len(train_data.data_frame), 32, self.batch_size), + shuffle=False) + dl_iter = iter(dl) + self.model = deepcopy(starting_model) + self.model.train() + optimizer = self._select_optimizer(self.model, lr=lr) + + for i in range(n_steps): + try: + X, Y = next(dl_iter) + except StopIteration: + dl_iter = iter(dl) + X, Y = next(dl_iter) - batches += len(X) X = X.to(self.model.device) Y = Y.to(self.model.device) + with LightwoodAutocast(): optimizer.zero_grad() Yh = self._net_call(X) @@ -145,22 +164,21 @@ def _find_lr(self, dl): optimizer.step() cum_loss += loss.item() - # Account for ranger lookahead update - if (i + 1) * epoch % 6: - batches = 0 - lr = optimizer.param_groups[0]['lr'] - log.info(f'Loss of {cum_loss} with learning rate {lr}') - running_losses.append(cum_loss) - lr_log.append(lr) - cum_loss = 0 - if len(running_losses) < 2 or np.mean(running_losses[:-1]) > np.mean(running_losses): - optimizer.param_groups[0]['lr'] = lr * 1.4 - # Time saving since we don't have to start training fresh - best_model = deepcopy(self.model) - else: - stop = True + log.info(f'Loss of {cum_loss} with learning rate {lr}') + running_losses.append(cum_loss) + lr_log.append(lr) + cum_loss = 0 + lr = lrs.popleft() + if len(lrs) == 0: + stop = True + + # store model if best so far + inv_running_losses = list(running_losses)[::-1] # invert so when tied we pick the most aggresive LR + best_loss_idx = np.nanargmin(inv_running_losses) # nanargmin ignores nans that may arise + if best_loss_idx == 0: + best_model = deepcopy(self.model) # store model for slight time savings + best_loss_lr = lr_log[-1] - best_loss_lr = lr_log[np.argmin(running_losses)] lr = best_loss_lr log.info(f'Found learning rate of: {lr}') return lr, best_model @@ -168,7 +186,7 @@ def _find_lr(self, dl): def _max_fit(self, train_dl, dev_dl, criterion, optimizer, scaler, stop_after, return_model_after): epochs_to_best = 0 best_dev_error = pow(2, 32) - running_errors = 
[] + running_errors = deque(maxlen=self.loss_hist_len) best_model = self.model for epoch in range(1, return_model_after + 1): @@ -215,10 +233,11 @@ def _max_fit(self, train_dl, dev_dl, criterion, optimizer, scaler, stop_after, r # automated early stopping else: - if len(running_errors) >= 5: - delta_mean = np.average([running_errors[-i - 1] - running_errors[-i] for i in range(1, 5)], - weights=[(1 / 2)**i for i in range(1, 5)]) - if delta_mean <= 0: + if len(running_errors) >= self.loss_hist_len: + delta_mean = np.average([ + running_errors[-i - 1] - running_errors[-i] for i in range(len(running_errors) - 1)], + weights=[(1 / 2)**i for i in range(len(running_errors) - 1)]) + if delta_mean >= 0: break elif (time.time() - self.started) > stop_after: break @@ -243,8 +262,9 @@ def _error(self, dev_dl, criterion) -> float: def _init_net(self, ds: EncodedDs): self.net_class = DefaultNet if self.net_name == 'DefaultNet' else ArNet - net_kwargs = {'input_size': len(ds[0][0]), - 'output_size': len(ds[0][1]), + X, Y = ds[0] + net_kwargs = {'input_size': len(X), + 'output_size': len(Y), 'num_hidden': self.num_hidden, 'dropout': 0} @@ -274,17 +294,13 @@ def _fit(self, train_data: EncodedDs, dev_data: EncodedDs) -> None: dev_dl = DataLoader(dev_data, batch_size=self.batch_size, shuffle=False) train_dl = DataLoader(train_data, batch_size=self.batch_size, shuffle=False) - self.lr = 1e-4 - self.num_hidden = 1 - - # Find learning rate - # keep the weights + # Find learning rate & keep initial weights self._init_net(train_data) if not self.lr: - self.lr, self.model = self._find_lr(train_dl) + self.lr, self.model = self._find_lr(train_data) # Keep on training - optimizer = self._select_optimizer() + optimizer = self._select_optimizer(self.model, lr=self.lr) criterion = self._select_criterion() scaler = GradScaler() @@ -314,7 +330,7 @@ def partial_fit(self, train_data: EncodedDs, dev_data: EncodedDs, args: Optional self.started = time.time() train_dl = DataLoader(train_data, batch_size=self.batch_size, shuffle=True) dev_dl = DataLoader(dev_data, batch_size=self.batch_size, shuffle=True) - optimizer = self._select_optimizer() + optimizer = self._select_optimizer(self.model, lr=self.lr) criterion = self._select_criterion() scaler = GradScaler() diff --git a/lightwood/mixer/neural_ts.py b/lightwood/mixer/neural_ts.py index ef34b53f7..813266cff 100644 --- a/lightwood/mixer/neural_ts.py +++ b/lightwood/mixer/neural_ts.py @@ -7,10 +7,8 @@ import torch from torch import nn -import torch_optimizer as ad_optim from torch.cuda.amp import GradScaler from torch.utils.data import DataLoader -from torch.optim.optimizer import Optimizer from type_infer.dtype import dtype from lightwood.api.types import PredictionArguments @@ -76,10 +74,6 @@ def _select_criterion(self) -> torch.nn.Module: return criterion - def _select_optimizer(self) -> Optimizer: - optimizer = ad_optim.Ranger(self.model.parameters(), lr=self.lr) - return optimizer - def _fit(self, train_data: EncodedDs, dev_data: EncodedDs) -> None: """ :param train_data: The network is fit/trained on this @@ -106,10 +100,10 @@ def _fit(self, train_data: EncodedDs, dev_data: EncodedDs) -> None: # Find learning rate # keep the weights self._init_net(train_data) - self.lr, self.model = self._find_lr(train_dl) + self.lr, self.model = self._find_lr(train_data) # Keep on training - optimizer = self._select_optimizer() + optimizer = self._select_optimizer(self.model, lr=self.lr) criterion = self._select_criterion() scaler = GradScaler() diff --git 
a/lightwood/mixer/random_forest.py b/lightwood/mixer/random_forest.py index 10df3c3f9..89f6ca682 100644 --- a/lightwood/mixer/random_forest.py +++ b/lightwood/mixer/random_forest.py @@ -14,7 +14,7 @@ from type_infer.dtype import dtype from lightwood.helpers.log import log from lightwood.encoder.base import BaseEncoder -from lightwood.data.encoded_ds import ConcatedEncodedDs, EncodedDs +from lightwood.data.encoded_ds import EncodedDs, ConcatedEncodedDs from lightwood.mixer.base import BaseMixer from lightwood.api.types import PredictionArguments @@ -33,8 +33,8 @@ def __init__( target: str, dtype_dict: Dict[str, str], fit_on_dev: bool, - use_optuna: bool, - target_encoder: BaseEncoder + target_encoder: BaseEncoder, + use_optuna: bool = False, ): """ The `RandomForest` mixer supports both regression and classification tasks. @@ -57,7 +57,7 @@ def __init__( self.model = None self.positive_domain = False - self.num_trials = 20 + self.num_trials = 5 self.cv = 3 self.map = {} @@ -100,7 +100,6 @@ def fit(self, train_data: EncodedDs, dev_data: EncodedDs) -> None: init_params = { 'n_estimators': 50, 'max_depth': 5, - 'max_features': 1., 'bootstrap': True, 'n_jobs': -1, 'random_state': 0 @@ -128,15 +127,10 @@ def fit(self, train_data: EncodedDs, dev_data: EncodedDs) -> None: else (mean_squared_error, 'predict') def objective(trial: trial_module.Trial): - criterion = trial.suggest_categorical("criterion", - ["gini", "entropy"]) if self.is_classifier else 'squared_error' + criterion = trial.suggest_categorical("criterion", "gini") if self.is_classifier else 'squared_error' params = { 'n_estimators': trial.suggest_int('n_estimators', 2, 512), - 'max_depth': trial.suggest_int('max_depth', 2, 15), - 'min_samples_split': trial.suggest_int("min_samples_split", 2, 20), - 'min_samples_leaf': trial.suggest_int("min_samples_leaf", 1, 20), - 'max_features': trial.suggest_float("max_features", 0.1, 1), 'criterion': criterion, } @@ -203,7 +197,7 @@ def __call__(self, ds: EncodedDs, :return: dataframe with predictions. 
""" - data = ds.get_encoded_data(include_target=False) + data = ds.get_encoded_data(include_target=False).numpy() if self.is_classifier: predictions = self.model.predict_proba(data) diff --git a/lightwood/mixer/regression.py b/lightwood/mixer/regression.py index 99c2a9905..88b2ab709 100644 --- a/lightwood/mixer/regression.py +++ b/lightwood/mixer/regression.py @@ -87,10 +87,7 @@ def __call__(self, ds: EncodedDs, :returns: A dataframe cotaining the decoded predictions and (depending on the args) additional information such as the probabilites for each target class """ # noqa - X = [] - for x, _ in ds: - X.append(x.tolist()) - + X = ds.get_encoded_data(include_target=False) Yh = self.model.predict(X) decoded_predictions = self.target_encoder.decode(torch.Tensor(Yh)) diff --git a/requirements.txt b/requirements.txt index a07b5a275..335c78f41 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ -type_infer ==0.0.9 -dataprep_ml ==0.0.8 -mindsdb-evaluator >=0.0.7 +type_infer >=0.0.10 +dataprep_ml >=0.0.9 +mindsdb-evaluator >=0.0.9 numpy nltk >=3,<3.6 python-dateutil >=2.8.1 diff --git a/requirements_image.txt b/requirements_image.txt index c35ae2276..a66506e04 100644 --- a/requirements_image.txt +++ b/requirements_image.txt @@ -1,2 +1,2 @@ -torchvision >=0.10.0,<0.11.0 +torchvision pillow >8.3.1 diff --git a/tests/unit_tests/encoder/categorical/test_binary.py b/tests/unit_tests/encoder/categorical/test_binary.py index ad2aff72a..4eb7a8837 100644 --- a/tests/unit_tests/encoder/categorical/test_binary.py +++ b/tests/unit_tests/encoder/categorical/test_binary.py @@ -72,7 +72,7 @@ def test_check_only_binary(self): """ Ensure binary strictly enforces binary typing """ data = ["apple", "apple", "orange", "banana", "apple", "orange"] - enc = BinaryEncoder() + enc = BinaryEncoder(handle_unknown='error') self.assertRaises(ValueError, enc.prepare, data) def test_check_probabilities(self): diff --git a/tests/unit_tests/encoder/numeric/test_numeric.py b/tests/unit_tests/encoder/numeric/test_numeric.py index 93590c071..b81ef7a06 100644 --- a/tests/unit_tests/encoder/numeric/test_numeric.py +++ b/tests/unit_tests/encoder/numeric/test_numeric.py @@ -1,5 +1,6 @@ import unittest import numpy as np +import pandas as pd import torch from lightwood.encoder.numeric import NumericEncoder from lightwood.encoder.numeric import TsNumericEncoder @@ -16,31 +17,38 @@ def _pollute(array): class TestNumericEncoder(unittest.TestCase): def test_encode_and_decode(self): - data = [1, 1.1, 2, -8.6, None, 0] + data = pd.Series([1, 1.1, 2, -8.6, None, 0]) encoder = NumericEncoder() - encoder.prepare(data) encoded_vals = encoder.encode(data) - self.assertTrue(encoded_vals[1][1] > 0) - self.assertTrue(encoded_vals[2][1] > 0) - self.assertTrue(encoded_vals[3][1] > 0) - for i in range(0, 3): - self.assertTrue(encoded_vals[i][2] == 0) - self.assertTrue(encoded_vals[3][2] == 1) - self.assertTrue(encoded_vals[4][3] == 0) + # sign component check + self.assertTrue(encoded_vals[0][0] > 0) + self.assertTrue(encoded_vals[1][0] > 0) + self.assertTrue(encoded_vals[2][0] > 0) + self.assertTrue(encoded_vals[3][0] == 0) - decoded_vals = encoder.decode(encoded_vals) + # none component check + for i in range(0, len(encoded_vals)): + if i != 4: + self.assertTrue(encoded_vals[i][-1] == 0) + else: + self.assertTrue(encoded_vals[i][-1] == 1) - for i in range(len(encoded_vals)): - if decoded_vals[i] is None: - self.assertTrue(decoded_vals[i] == data[i]) + # exp component nan edge case check + self.assertTrue(encoded_vals[4][2] == 0) + + # 
compare decoded v/s real + decoded_vals = encoder.decode(encoded_vals) + for decoded, real in zip(decoded_vals, data.tolist()): + if decoded is None: + self.assertTrue((real is None) or (real != real)) else: - np.testing.assert_almost_equal(round(decoded_vals[i], 10), round(data[i], 10)) + np.testing.assert_almost_equal(round(decoded, 6), round(real, 6)) def test_positive_domain(self): - data = [-1, -2, -100, 5, 10, 15] + data = pd.Series([-1, -2, -100, 5, 10, 15]) for encoder in [NumericEncoder(), TsNumericEncoder()]: encoder.is_target = True # only affects target values encoder.positive_domain = True @@ -51,7 +59,7 @@ def test_positive_domain(self): self.assertTrue(val >= 0) def test_log_overflow_and_none(self): - data = list(range(-2000, 2000, 66)) + data = pd.Series(list(range(-2000, 2000, 66))) encoder = NumericEncoder() encoder.is_target = True @@ -61,7 +69,7 @@ def test_log_overflow_and_none(self): encoder.decode(encoder.encode(data)) for i in range(0, 70, 10): - encoder.decode([[0, pow(2, i), 0]]) + encoder.decode(torch.Tensor([[0, pow(2, i), 0]])) def test_nan_encoding(self): # Generate some numbers @@ -72,10 +80,10 @@ def test_nan_encoding(self): # Prepare with the correct data and decode invalid data encoder = NumericEncoder() - encoder.prepare(data) + encoder.prepare(pd.Series(data)) for array in invalid_data: # Make sure the encoding has no nans or infs - encoded_repr = encoder.encode(array) + encoded_repr = encoder.encode(pd.Series(array)) assert not torch.isnan(encoded_repr).any() assert not torch.isinf(encoded_repr).any() @@ -88,29 +96,17 @@ def test_nan_encoding(self): # Prepare with the invalid data and decode the valid data for array in invalid_data: encoder = NumericEncoder() - encoder.prepare(array) + encoder.prepare(pd.Series(array)) # Make sure the encoding has no nans or infs - encoded_repr = encoder.encode(data) + encoded_repr = encoder.encode(pd.Series(array)) assert not torch.isnan(encoded_repr).any() assert not torch.isinf(encoded_repr).any() # Make sure the invalid value is decoded as `None` and the rest as numbers decoded_repr = encoder.decode(encoded_repr) - for x in decoded_repr: - assert not is_none(x) - - # Prepare with the invalid data and decode invalid data - for array in invalid_data: - encoder = NumericEncoder() - encoder.prepare(array) - # Make sure the encoding has no nans or infs - encoded_repr = encoder.encode(array) - assert not torch.isnan(encoded_repr).any() - assert not torch.isinf(encoded_repr).any() - - # Make sure the invalid value is decoded as `None` and the rest as numbers - decoded_repr = encoder.decode(encoded_repr) - for x in decoded_repr[:-1]: - assert not is_none(x) - assert decoded_repr[-1] is None + for dec, real in zip(decoded_repr, array): + if is_none(real): + assert is_none(dec) + else: + assert not is_none(x) or x != 0.0
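For reference, a minimal sketch (not part of the patch) of the pandas-based NumericEncoder workflow that the updated tests above exercise; it assumes a local lightwood install and mirrors the calls made in test_encode_and_decode:

import pandas as pd
from lightwood.encoder.numeric import NumericEncoder

# NaN/None values are allowed; the encoder reserves a component to flag them.
data = pd.Series([1, 1.1, 2, -8.6, None, 0])

encoder = NumericEncoder()
encoder.prepare(data)

# encode() takes a pandas Series and returns a torch.Tensor whose columns carry
# the sign, magnitude-related and "is-none" components checked in the tests.
encoded = encoder.encode(data)

# decode() accepts that tensor back and approximately recovers the inputs,
# mapping the missing entry to None.
decoded = encoder.decode(encoded)
print(decoded)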