-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add arctic sea (environment monitoring ts reg) new api example
- Loading branch information
Showing
1 changed file
with
50 additions
and
48 deletions.
There are no files selected for viewing
98 changes: 50 additions & 48 deletions
98
...eal_world_examples/industrial_examples/enviroment_monitoring/ts_forecasting/arctic_sea.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,70 +1,72 @@ | ||
from pathlib import Path | ||
|
||
import matplotlib | ||
from matplotlib import use | ||
import matplotlib.pyplot as plt | ||
import pandas as pd | ||
from fedot.core.data.data import InputData | ||
from fedot.core.data.data_split import train_test_data_setup | ||
from fedot.core.pipelines.pipeline_builder import PipelineBuilder | ||
from fedot.core.repository.tasks import TsForecastingParams, Task, TaskTypesEnum | ||
from matplotlib import pyplot as plt | ||
from sklearn.metrics import mean_squared_error, mean_absolute_percentage_error | ||
|
||
from fedot_ind.core.architecture.settings.computational import backend_methods as np | ||
from fedot_ind.tools.serialisation.path_lib import PROJECT_PATH | ||
from fedot_ind.tools.serialisation.path_lib import EXAMPLES_DATA_PATH | ||
|
||
matplotlib.use('TKagg') | ||
|
||
horizon = 365 | ||
PATH = Path(PROJECT_PATH, 'examples', 'data', 'ices_areas_ts.csv') | ||
|
||
time_series_df = pd.read_csv(PATH).iloc[:, 1:] | ||
target_series = time_series_df['Карское'].values | ||
if __name__ == '__main__': | ||
use('TKagg') | ||
horizon = 365 | ||
|
||
input_data = InputData.from_numpy_time_series( | ||
target_series, | ||
task=Task( | ||
TaskTypesEnum.ts_forecasting, | ||
task_params=TsForecastingParams( | ||
forecast_length=horizon))) | ||
train_data, test_data = train_test_data_setup(input_data) | ||
time_series_df = pd.read_csv(Path(EXAMPLES_DATA_PATH, | ||
'real_world/ice_forecasting/ices_areas_ts.csv')) | ||
time_series_df = time_series_df.iloc[:, 1:] | ||
target_series = time_series_df['Карское'].values | ||
|
||
pipeline_based = PipelineBuilder().add_node('lagged').add_node('rfr').build() | ||
pipeline_based.fit(train_data) | ||
input_data = InputData.from_numpy_time_series( | ||
target_series, | ||
task=Task(TaskTypesEnum.ts_forecasting, | ||
task_params=TsForecastingParams(forecast_length=horizon))) | ||
train_data, test_data = train_test_data_setup(input_data) | ||
|
||
topological_pipeline = PipelineBuilder().add_node('lagged').add_node( | ||
'topological_features') .add_node('lagged', branch_idx=2).join_branches('rfr').build() | ||
topological_pipeline.fit(train_data) | ||
pipeline_based = ( | ||
PipelineBuilder() | ||
.add_node('lagged') | ||
.add_node('rfr') | ||
.build() | ||
) | ||
pipeline_based.fit(train_data) | ||
|
||
forecast_base = np.ravel(pipeline_based.predict(test_data).predict) | ||
forecast_topo = np.ravel(topological_pipeline.predict(test_data).predict) | ||
topological_pipeline = ( | ||
PipelineBuilder() | ||
.add_node('lagged') | ||
.add_node('topological_features') | ||
.add_node('lagged', branch_idx=2) | ||
.join_branches('rfr') | ||
.build() | ||
) | ||
topological_pipeline.fit(train_data) | ||
|
||
forecast_base[forecast_base < 0] = 0 | ||
forecast_topo[forecast_topo < 0] = 0 | ||
forecast_base = np.ravel(pipeline_based.predict(test_data).predict) | ||
forecast_topo = np.ravel(topological_pipeline.predict(test_data).predict) | ||
|
||
plt.plot(input_data.features, label='real data') | ||
plt.plot(np.arange(len(target_series) - horizon, len(target_series)), | ||
forecast_base, label='forecast base') | ||
plt.plot(np.arange(len(target_series) - horizon, len(target_series)), | ||
forecast_topo, label='forecast topo') | ||
forecast_base[forecast_base < 0] = 0 | ||
forecast_topo[forecast_topo < 0] = 0 | ||
|
||
plt.grid() | ||
plt.legend() | ||
plt.show() | ||
plt.plot(input_data.features, label='real data') | ||
plt.plot(np.arange(len(target_series) - horizon, len(target_series)), | ||
forecast_base, label='forecast base') | ||
plt.plot(np.arange(len(target_series) - horizon, len(target_series)), | ||
forecast_topo, label='forecast topo') | ||
|
||
print('base') | ||
print(mean_squared_error(test_data.target, forecast_base, squared=False)) | ||
print( | ||
mean_absolute_percentage_error( | ||
test_data.target + | ||
1000, | ||
forecast_base + | ||
1000)) | ||
plt.grid() | ||
plt.legend() | ||
plt.show() | ||
|
||
print('topo') | ||
print(mean_squared_error(test_data.target, forecast_topo, squared=False)) | ||
print( | ||
mean_absolute_percentage_error( | ||
test_data.target + | ||
1000, | ||
forecast_topo + | ||
1000)) | ||
print('base') | ||
print(mean_squared_error(test_data.target, forecast_base, squared=False)) | ||
print(mean_absolute_percentage_error(test_data.target + 1000, forecast_base + 1000)) | ||
|
||
print('topo') | ||
print(mean_squared_error(test_data.target, forecast_topo, squared=False)) | ||
print(mean_absolute_percentage_error(test_data.target + 1000, forecast_topo + 1000)) |