From 1a192b2b1ad434fb508d1892cccf5833d292087b Mon Sep 17 00:00:00 2001 From: InnopolisU Date: Fri, 15 Nov 2024 11:36:57 +0300 Subject: [PATCH] Removed old ui based on streamlit --- ui/README.md | 2 - ui/__init__.py | 0 ui/augmentations.json | 1219 ---------------------- ui/control.py | 108 -- ui/dataset/__init__.py | 0 ui/dataset/hdf5.py | 580 ---------- ui/image_ui.py | 136 --- ui/pages/augmentation.py | 286 ----- ui/schema/__init__.py | 2 - ui/schema/base.py | 59 -- ui/schema/kneighbors_classifier.py | 61 -- ui/schema/linear_regression.py | 29 - ui/table_ui.py | 78 -- ui/tmp_pages/2_augmentations_assitant.py | 159 --- ui/tmp_pages/__init__.py | 0 ui/utils.py | 170 --- ui/visuals.py | 114 -- ui/webserver.py | 141 --- 18 files changed, 3144 deletions(-) delete mode 100755 ui/README.md delete mode 100755 ui/__init__.py delete mode 100755 ui/augmentations.json delete mode 100755 ui/control.py delete mode 100755 ui/dataset/__init__.py delete mode 100755 ui/dataset/hdf5.py delete mode 100644 ui/image_ui.py delete mode 100644 ui/pages/augmentation.py delete mode 100644 ui/schema/__init__.py delete mode 100644 ui/schema/base.py delete mode 100644 ui/schema/kneighbors_classifier.py delete mode 100644 ui/schema/linear_regression.py delete mode 100644 ui/table_ui.py delete mode 100755 ui/tmp_pages/2_augmentations_assitant.py delete mode 100644 ui/tmp_pages/__init__.py delete mode 100755 ui/utils.py delete mode 100755 ui/visuals.py delete mode 100755 ui/webserver.py diff --git a/ui/README.md b/ui/README.md deleted file mode 100755 index 6a682db6..00000000 --- a/ui/README.md +++ /dev/null @@ -1,2 +0,0 @@ -This project is a modification of the open-source project -https://github.com/IliaLarchenko/albumentations-demo diff --git a/ui/__init__.py b/ui/__init__.py deleted file mode 100755 index e69de29b..00000000 diff --git a/ui/augmentations.json b/ui/augmentations.json deleted file mode 100755 index 0b1cdd94..00000000 --- a/ui/augmentations.json +++ /dev/null @@ -1,1219 +0,0 @@ -{ - "Blur": [ - { - "defaults": [ - 3, - 7 - ], - "limits_list": [ - 3, - 100 - ], - "param_name": "blur_limit", - "type": "num_interval" - } - ], - "CLAHE": [ - { - "defaults": [ - 1, - 4 - ], - "limits_list": [ - 1, - 100 - ], - "param_name": "clip_limit", - "type": "num_interval" - }, - { - "defaults_list": [ - 8, - 8 - ], - "limits_list": [ - [ - 1, - 100 - ], - [ - 1, - 100 - ] - ], - "param_name": "tile_grid_size", - "subparam_names": [ - "height", - "width" - ], - "type": "several_nums" - } - ], - "CenterCrop": [ - { - "param_name": "height", - "placeholder": { - "defaults": "image_half_height", - "limits_list": [ - 1, - "image_height" - ] - }, - "type": "num_interval" - }, - { - "param_name": "width", - "placeholder": { - "defaults": "image_half_width", - "limits_list": [ - 1, - "image_width" - ] - }, - "type": "num_interval" - } - ], - "ChannelDropout": [ - { - "defaults": [ - 1, - 1 - ], - "limits_list": [ - 1, - 2 - ], - "param_name": "channel_drop_range", - "type": "num_interval" - }, - { - "defaults": 0, - "limits_list": [ - 0, - 255 - ], - "param_name": "fill_value", - "type": "num_interval" - } - ], - "ChannelShuffle": [], - "CoarseDropout": [ - { - "defaults_list": [ - 8, - 8 - ], - "limits_list": [ - 1, - 100 - ], - "min_diff": 0, - "param_name": [ - "min_holes", - "max_holes" - ], - "type": "min_max" - }, - { - "defaults_list": [ - 8, - 8 - ], - "limits_list": [ - 1, - 100 - ], - "min_diff": 0, - "param_name": [ - "min_height", - "max_height" - ], - "type": "min_max" - }, - { - "defaults_list": [ - 8, - 8 - ], 
- "limits_list": [ - 1, - 100 - ], - "min_diff": 0, - "param_name": [ - "min_width", - "max_width" - ], - "type": "min_max" - }, - { - "param_name": "fill_value", - "type": "rgb" - } - ], - "Crop": [ - { - "min_diff": 1, - "param_name": [ - "x_min", - "x_max" - ], - "placeholder": { - "defaults_list": [ - 0, - "image_half_width" - ], - "limits_list": [ - 0, - "image_width" - ] - }, - "type": "min_max" - }, - { - "min_diff": 1, - "param_name": [ - "y_min", - "y_max" - ], - "placeholder": { - "defaults_list": [ - 0, - "image_half_height" - ], - "limits_list": [ - 0, - "image_height" - ] - }, - "type": "min_max" - } - ], - "Cutout": [ - { - "defaults": 8, - "limits_list": [ - 1, - 100 - ], - "param_name": "num_holes", - "type": "num_interval" - }, - { - "defaults": 8, - "limits_list": [ - 1, - 100 - ], - "param_name": "max_h_size", - "type": "num_interval" - }, - { - "defaults": 8, - "limits_list": [ - 1, - 100 - ], - "param_name": "max_w_size", - "type": "num_interval" - }, - { - "param_name": "fill_value", - "type": "rgb" - } - ], - "Downscale": [ - { - "defaults_list": [ - 0.25, - 0.25 - ], - "limits_list": [ - 0.01, - 0.99 - ], - "param_name": [ - "scale_min", - "scale_max" - ], - "type": "min_max" - }, - { - "options_list": [ - 0, - 1, - 2, - 3, - 4 - ], - "param_name": "interpolation", - "type": "radio" - } - ], - "ElasticTransform": [ - { - "defaults": 1.0, - "limits_list": [ - 0.0, - 10.0 - ], - "param_name": "alpha", - "type": "num_interval" - }, - { - "defaults": 50.0, - "limits_list": [ - 0.0, - 200.0 - ], - "param_name": "sigma", - "type": "num_interval" - }, - { - "defaults": 50.0, - "limits_list": [ - 0.0, - 200.0 - ], - "param_name": "alpha_affine", - "type": "num_interval" - }, - { - "options_list": [ - 0, - 1, - 2, - 3, - 4 - ], - "param_name": "interpolation", - "type": "radio" - }, - { - "options_list": [ - 0, - 1, - 2, - 3, - 4 - ], - "param_name": "border_mode", - "type": "radio" - }, - { - "param_name": "value", - "type": "rgb" - } - ], - "Equalize": [ - { - "options_list": [ - "cv", - "pil" - ], - "param_name": "mode", - "type": "radio" - }, - { - "defaults": 1, - "param_name": "by_channels", - "type": "checkbox" - } - ], - "Flip": [], - "GaussNoise": [ - { - "defaults": [ - 10.0, - 50.0 - ], - "limits_list": [ - 0.0, - 500.0 - ], - "param_name": "var_limit", - "type": "num_interval" - }, - { - "defaults": 0.0, - "limits_list": [ - -100.0, - 100.0 - ], - "param_name": "mean", - "type": "num_interval" - } - ], - "GridDistortion": [ - { - "defaults": 5, - "limits_list": [ - 1, - 15 - ], - "param_name": "num_steps", - "type": "num_interval" - }, - { - "defaults": [ - -0.3, - 0.3 - ], - "limits_list": [ - -2.0, - 2.0 - ], - "param_name": "distort_limit", - "type": "num_interval" - }, - { - "options_list": [ - 0, - 1, - 2, - 3, - 4 - ], - "param_name": "interpolation", - "type": "radio" - }, - { - "options_list": [ - 0, - 1, - 2, - 3, - 4 - ], - "param_name": "border_mode", - "type": "radio" - }, - { - "param_name": "value", - "type": "rgb" - } - ], - "HorizontalFlip": [], - "HueSaturationValue": [ - { - "defaults": [ - -20, - 20 - ], - "limits_list": [ - -100, - 100 - ], - "param_name": "hue_shift_limit", - "type": "num_interval" - }, - { - "defaults": [ - -30, - 30 - ], - "limits_list": [ - -100, - 100 - ], - "param_name": "sat_shift_limit", - "type": "num_interval" - }, - { - "defaults": [ - -20, - 20 - ], - "limits_list": [ - -100, - 100 - ], - "param_name": "val_shift_limit", - "type": "num_interval" - } - ], - "ISONoise": [ - { - "defaults": [ - 0.01, - 0.05 - ], - 
"limits_list": [ - 0.0, - 1.0 - ], - "param_name": "color_shift", - "type": "num_interval" - }, - { - "defaults": [ - 0.1, - 0.5 - ], - "limits_list": [ - 0.0, - 2.0 - ], - "param_name": "intensity", - "type": "num_interval" - } - ], - "ImageCompression": [ - { - "options_list": [ - 0, - 1 - ], - "param_name": "compression_type", - "type": "radio" - }, - { - "defaults_list": [ - 80, - 100 - ], - "limits_list": [ - 0, - 100 - ], - "param_name": [ - "quality_lower", - "quality_upper" - ], - "type": "min_max" - } - ], - "InvertImg": [], - "JpegCompression": [ - { - "defaults_list": [ - 80, - 100 - ], - "limits_list": [ - 0, - 100 - ], - "param_name": [ - "quality_lower", - "quality_upper" - ], - "type": "min_max" - } - ], - "LongestMaxSize": [ - { - "defaults": 512, - "limits_list": [ - 1, - 1024 - ], - "param_name": "max_size", - "type": "num_interval" - }, - { - "options_list": [ - 0, - 1, - 2, - 3, - 4 - ], - "param_name": "interpolation", - "type": "radio" - } - ], - "MotionBlur": [ - { - "defaults": [ - 3, - 7 - ], - "limits_list": [ - 3, - 100 - ], - "param_name": "blur_limit", - "type": "num_interval" - } - ], - "MultiplicativeNoise": [ - { - "defaults": [ - 0.9, - 1.1 - ], - "limits_list": [ - 0.1, - 5.0 - ], - "param_name": "multiplier", - "type": "num_interval" - }, - { - "defaults": 1, - "param_name": "per_channel", - "type": "checkbox" - }, - { - "defaults": 1, - "param_name": "elementwise", - "type": "checkbox" - } - ], - "OpticalDistortion": [ - { - "defaults": [ - -0.3, - 0.3 - ], - "limits_list": [ - -2.0, - 2.0 - ], - "param_name": "distort_limit", - "type": "num_interval" - }, - { - "defaults": [ - -0.05, - 0.05 - ], - "limits_list": [ - -1.0, - 1.0 - ], - "param_name": "shift_limit", - "type": "num_interval" - }, - { - "options_list": [ - 0, - 1, - 2, - 3, - 4 - ], - "param_name": "interpolation", - "type": "radio" - }, - { - "options_list": [ - 0, - 1, - 2, - 3, - 4 - ], - "param_name": "border_mode", - "type": "radio" - }, - { - "param_name": "value", - "type": "rgb" - } - ], - "Posterize": [ - { - "defaults_list": [ - 4, - 4, - 4 - ], - "limits_list": [ - [ - 0, - 8 - ], - [ - 0, - 8 - ], - [ - 0, - 8 - ] - ], - "param_name": "num_bits", - "subparam_names": [ - "r", - "g", - "b" - ], - "type": "several_nums" - } - ], - "RGBShift": [ - { - "defaults": [ - -20, - 20 - ], - "limits_list": [ - -255, - 255 - ], - "param_name": "r_shift_limit", - "type": "num_interval" - }, - { - "defaults": [ - -20, - 20 - ], - "limits_list": [ - -255, - 255 - ], - "param_name": "g_shift_limit", - "type": "num_interval" - }, - { - "defaults": [ - -20, - 20 - ], - "limits_list": [ - -255, - 255 - ], - "param_name": "b_shift_limit", - "type": "num_interval" - } - ], - "RandomBrightness": [ - { - "defaults": [ - -0.2, - 0.2 - ], - "limits_list": [ - -1.0, - 1.0 - ], - "param_name": "limit", - "type": "num_interval" - } - ], - "RandomBrightnessContrast": [ - { - "defaults": [ - -0.2, - 0.2 - ], - "limits_list": [ - -1.0, - 1.0 - ], - "param_name": "brightness_limit", - "type": "num_interval" - }, - { - "defaults": [ - -0.2, - 0.2 - ], - "limits_list": [ - -1.0, - 1.0 - ], - "param_name": "contrast_limit", - "type": "num_interval" - }, - { - "defaults": 1, - "param_name": "brightness_by_max", - "type": "checkbox" - } - ], - "RandomContrast": [ - { - "defaults": [ - -0.2, - 0.2 - ], - "limits_list": [ - -1.0, - 1.0 - ], - "param_name": "limit", - "type": "num_interval" - } - ], - "RandomFog": [ - { - "defaults_list": [ - 0.1, - 0.2 - ], - "limits_list": [ - 0.0, - 1.0 - ], - "param_name": [ - 
"fog_coef_lower", - "fog_coef_upper" - ], - "type": "min_max" - }, - { - "defaults": 0.08, - "limits_list": [ - 0.0, - 1.0 - ], - "param_name": "alpha_coef", - "type": "num_interval" - } - ], - "RandomGamma": [ - { - "defaults": [ - 80, - 120 - ], - "limits_list": [ - 0, - 200 - ], - "param_name": "gamma_limit", - "type": "num_interval" - } - ], - "RandomGridShuffle": [ - { - "defaults_list": [ - 3, - 3 - ], - "limits_list": [ - [ - 1, - 10 - ], - [ - 1, - 10 - ] - ], - "param_name": "grid", - "subparam_names": [ - "vertical", - "horizontal" - ], - "type": "several_nums" - } - ], - "RandomRain": [ - { - "defaults_list": [ - -10, - 10 - ], - "limits_list": [ - -20, - 20 - ], - "param_name": [ - "slant_lower", - "slant_upper" - ], - "type": "min_max" - }, - { - "defaults": 20, - "limits_list": [ - 0, - 100 - ], - "param_name": "drop_length", - "type": "num_interval" - }, - { - "defaults": 1, - "limits_list": [ - 1, - 5 - ], - "param_name": "drop_width", - "type": "num_interval" - }, - { - "param_name": "drop_color", - "type": "rgb" - }, - { - "defaults": 7, - "limits_list": [ - 1, - 15 - ], - "param_name": "blur_value", - "type": "num_interval" - }, - { - "defaults": 0.7, - "limits_list": [ - 0.0, - 1.0 - ], - "param_name": "brightness_coefficient", - "type": "num_interval" - }, - { - "options_list": [ - "None", - "drizzle", - "heavy", - "torrential" - ], - "param_name": "rain_type", - "type": "radio" - } - ], - "RandomResizedCrop": [ - { - "param_name": "height", - "placeholder": { - "defaults": "image_height", - "limits_list": [ - 1, - "image_height" - ] - }, - "type": "num_interval" - }, - { - "param_name": "width", - "placeholder": { - "defaults": "image_width", - "limits_list": [ - 1, - "image_width" - ] - }, - "type": "num_interval" - }, - { - "defaults": [ - 0.08, - 1.0 - ], - "limits_list": [ - 0.01, - 1.0 - ], - "param_name": "scale", - "type": "num_interval" - }, - { - "defaults": [ - 0.75, - 1.3333333333333333 - ], - "limits_list": [ - 0.1, - 10.0 - ], - "param_name": "ratio", - "type": "num_interval" - }, - { - "options_list": [ - 0, - 1, - 2, - 3, - 4 - ], - "param_name": "interpolation", - "type": "radio" - } - ], - "RandomRotate90": [], - "RandomScale": [ - { - "defaults": [ - -0.1, - 0.1 - ], - "limits_list": [ - -0.9, - 2.0 - ], - "param_name": "scale_limit", - "type": "num_interval" - }, - { - "options_list": [ - 0, - 1, - 2, - 3, - 4 - ], - "param_name": "interpolation", - "type": "radio" - } - ], - "RandomSizedCrop": [ - { - "param_name": "min_max_height", - "placeholder": { - "defaults": [ - "image_half_height", - "image_height" - ], - "limits_list": [ - 1, - "image_height" - ] - }, - "type": "num_interval" - }, - { - "param_name": "height", - "placeholder": { - "defaults": "image_height", - "limits_list": [ - 1, - "image_height" - ] - }, - "type": "num_interval" - }, - { - "param_name": "width", - "placeholder": { - "defaults": "image_width", - "limits_list": [ - 1, - "image_width" - ] - }, - "type": "num_interval" - }, - { - "defaults": 1.0, - "limits_list": [ - 0.1, - 1.0 - ], - "param_name": "w2h_ratio", - "type": "num_interval" - }, - { - "options_list": [ - 0, - 1, - 2, - 3, - 4 - ], - "param_name": "interpolation", - "type": "radio" - } - ], - "RandomSnow": [ - { - "defaults_list": [ - 0.1, - 0.2 - ], - "limits_list": [ - 0.0, - 1.0 - ], - "param_name": [ - "snow_point_lower", - "snow_point_upper" - ], - "type": "min_max" - }, - { - "defaults": 2.5, - "limits_list": [ - 0.0, - 5.0 - ], - "param_name": "brightness_coeff", - "type": "num_interval" - } - ], - 
"Resize": [ - { - "param_name": "height", - "placeholder": { - "defaults": "image_half_height", - "limits_list": [ - 1, - "image_height" - ] - }, - "type": "num_interval" - }, - { - "param_name": "width", - "placeholder": { - "defaults": "image_half_width", - "limits_list": [ - 1, - "image_width" - ] - }, - "type": "num_interval" - }, - { - "options_list": [ - 0, - 1, - 2, - 3, - 4 - ], - "param_name": "interpolation", - "type": "radio" - } - ], - "Rotate": [ - { - "defaults": [ - -90, - 90 - ], - "limits_list": [ - -360, - 360 - ], - "param_name": "limit", - "type": "num_interval" - }, - { - "options_list": [ - 0, - 1, - 2, - 3, - 4 - ], - "param_name": "interpolation", - "type": "radio" - }, - { - "options_list": [ - 0, - 1, - 2, - 3, - 4 - ], - "param_name": "border_mode", - "type": "radio" - }, - { - "param_name": "value", - "type": "rgb" - } - ], - "ShiftScaleRotate": [ - { - "defaults": [ - -0.06, - 0.06 - ], - "limits_list": [ - -1.0, - 1.0 - ], - "param_name": "shift_limit", - "type": "num_interval" - }, - { - "defaults": [ - -0.1, - 0.1 - ], - "limits_list": [ - -2.0, - 2.0 - ], - "param_name": "scale_limit", - "type": "num_interval" - }, - { - "defaults": [ - -90, - 90 - ], - "limits_list": [ - -360, - 360 - ], - "param_name": "rotate_limit", - "type": "num_interval" - }, - { - "options_list": [ - 0, - 1, - 2, - 3, - 4 - ], - "param_name": "interpolation", - "type": "radio" - }, - { - "options_list": [ - 0, - 1, - 2, - 3, - 4 - ], - "param_name": "border_mode", - "type": "radio" - }, - { - "param_name": "value", - "type": "rgb" - } - ], - "SmallestMaxSize": [ - { - "defaults": 512, - "limits_list": [ - 1, - 1024 - ], - "param_name": "max_size", - "type": "num_interval" - }, - { - "options_list": [ - 0, - 1, - 2, - 3, - 4 - ], - "param_name": "interpolation", - "type": "radio" - } - ], - "Solarize": [ - { - "defaults": 128, - "limits_list": [ - 0, - 255 - ], - "param_name": "threshold", - "type": "num_interval" - } - ], - "ToGray": [], - "ToSepia": [], - "Transpose": [], - "VerticalFlip": [] -} diff --git a/ui/control.py b/ui/control.py deleted file mode 100755 index 04147e8a..00000000 --- a/ui/control.py +++ /dev/null @@ -1,108 +0,0 @@ -import streamlit as st - - -def select_num_interval( - param_name: str, limits_list: list, defaults, n_for_hash, **kwargs -): - st.sidebar.subheader(param_name) - min_max_interval = st.sidebar.slider( - "", - limits_list[0], - limits_list[1], - defaults, - key=hash(param_name + str(n_for_hash)), - ) - return min_max_interval - - -def select_several_nums( - param_name, - subparam_names, - limits_list, - defaults_list, - n_for_hash, - **kwargs, -): - st.sidebar.subheader(param_name) - result = [] - assert len(limits_list) == len(defaults_list) - assert len(subparam_names) == len(defaults_list) - - for name, limits, defaults in zip( - subparam_names, limits_list, defaults_list - ): - result.append( - st.sidebar.slider( - name, - limits[0], - limits[1], - defaults, - key=hash(param_name + name + str(n_for_hash)), - ) - ) - return tuple(result) - - -def select_min_max( - param_name, limits_list, defaults_list, n_for_hash, min_diff=0, **kwargs -): - assert len(param_name) == 2 - result = list( - select_num_interval( - " & ".join(param_name), limits_list, defaults_list, n_for_hash - ) - ) - if result[1] - result[0] < min_diff: - diff = min_diff - result[1] + result[0] - if result[1] + diff <= limits_list[1]: - result[1] = result[1] + diff - elif result[0] - diff >= limits_list[0]: - result[0] = result[0] - diff - else: - result = limits_list - return 
tuple(result) - - -def select_RGB(param_name, n_for_hash, **kwargs): - result = select_several_nums( - param_name, - subparam_names=["Red", "Green", "Blue"], - limits_list=[[0, 255], [0, 255], [0, 255]], - defaults_list=[0, 0, 0], - n_for_hash=n_for_hash, - ) - return tuple(result) - - -def replace_none(string): - if string == "None": - return None - else: - return string - - -def select_radio(param_name, options_list, n_for_hash, **kwargs): - st.sidebar.subheader(param_name) - result = st.sidebar.radio( - "", options_list, key=hash(param_name + str(n_for_hash)) - ) - return replace_none(result) - - -def select_checkbox(param_name, defaults, n_for_hash, **kwargs): - st.sidebar.subheader(param_name) - result = st.sidebar.checkbox( - "True", defaults, key=hash(param_name + str(n_for_hash)) - ) - return result - - -# dict from param name to function showing this param -param2func = { - "num_interval": select_num_interval, - "several_nums": select_several_nums, - "radio": select_radio, - "rgb": select_RGB, - "checkbox": select_checkbox, - "min_max": select_min_max, -} diff --git a/ui/dataset/__init__.py b/ui/dataset/__init__.py deleted file mode 100755 index e69de29b..00000000 diff --git a/ui/dataset/hdf5.py b/ui/dataset/hdf5.py deleted file mode 100755 index 59d99a55..00000000 --- a/ui/dataset/hdf5.py +++ /dev/null @@ -1,580 +0,0 @@ -# import math -import random -import sys -from functools import partial -from typing import List - -import h5py -import numpy as np -import torch -from aeronet.dataset import BandCollection -from aeronet.dataset import parse_directory -from tqdm import tqdm as tqdm - -# import os -# import glob -# import albumentations as albu -# import torch.nn as nn -# import torchvision - - -def _to_tensor(x, **kwargs): - return x.transpose(2, 0, 1).astype("float32") - - -def _get_preprocessing_fn(mean=None, std=None): - def preprocess_input(x, mean, std, **kwargs): - if mean is not None: - mean = np.array(mean) - x = x - mean - if std is not None: - std = np.array(std) - x = x / std - return x - - return partial(preprocess_input, mean=mean, std=std) - - -def _augment_and_preproc(image, mask, augmentations, preprocessing): - if augmentations is not None: - augmented = augmentations(image=image, mask=mask) - image = augmented["image"] - mask = augmented["mask"] - - if preprocessing is not None: - image = preprocessing(image) - - # assert self.path_to_hdf5.endswith('NIR_B11_Red_Green_1-11.hdf5') - # orig = image.copy() - # eps = 1e-6 - # image[:,:,0] = (orig[:,:,0] - orig[:,:,2] + eps) / (orig[:,:,0] + orig[:,:,2] + eps) # NDVI - # image[:,:,1] = (orig[:,:,3] - orig[:,:,0] + eps) / (orig[:,:,3] + orig[:,:,0] + eps) # NDWI - # image[:,:,2] = (orig[:,:,3] - orig[:,:,1] + eps) / (orig[:,:,3] + orig[:,:,1] + eps) # MNDWI - # image[:,:,3] = (orig[:,:,1] - orig[:,:,0] + eps) / (orig[:,:,1] + orig[:,:,0] + eps) # NDBI - - # convert to tensor shape - image = _to_tensor(image) - mask = _to_tensor(mask) - - return image, mask - - -def _get_class_weights(arr): - labels = np.array(arr) - uniq_values = np.unique(labels) - class_weights = np.zeros(labels.shape) - for v in uniq_values: - x = (labels == v).astype("float32") - p = np.sum(x) / len(labels) - class_weights = np.where( - x > 0, - # '0' - background, '1' - quarries, '2' - fp quarries (also background) - (1 / p if v == 1 else 0.75 / p if v == 0 else 0.25 / p) - if len(uniq_values) == 3 - else 1 - p - if len(uniq_values) == 2 - else 1 / p, - # 1 / p if v == 1 else 1 / (p * (len(uniq_values) - 1)), - class_weights, - ) - return 
class_weights - - -class Dataset(torch.utils.data.Dataset): - def __init__( - self, - path_to_hdf5, - in_channels=3, - augmentations=None, - preprocessing=None, - with_mosaic=False, - ): - self.path_to_hdf5 = path_to_hdf5 - self.augmentations = augmentations - self.preprocessing = preprocessing - self.in_channels = in_channels - self.mean = None - self.std = None - self.with_mosaic = with_mosaic - - with h5py.File(self.path_to_hdf5, "r") as f: - self.len = f["len"][0] - try: - self.class_ids = f["class_ids"][:] - self.class_weights = _get_class_weights(self.class_ids) - except Exception as e: - self.class_weights = f["class_weights"][:] - self.class_ids = (self.class_weights > 1).astype("float32") - assert len(self.class_ids) == len(self.class_weights) - - print("dataset:", self.path_to_hdf5) - print("class_ids unique:", np.unique(self.class_ids)) - print("class_weights unique:", np.unique(self.class_weights)) - - try: - self.mean = f["mean"][:] - self.std = f["std"][:] - except Exception as e: - # default RGB-values for hdf5 without corresponding data - self.mean = [0.03935977, 0.06333545, 0.07543217] - self.std = [0.00962875, 0.01025132, 0.00893212] - # assert len(self.mean) == self.in_channels - # assert len(self.mean) == len(self.std) - - self.norm_class_weights = self.class_weights / self.class_weights.sum() - - if self.preprocessing is None: - self.preprocessing = _get_preprocessing_fn(self.mean, self.std) - else: - self.mean = self.preprocessing.keywords["mean"] - self.std = self.preprocessing.keywords["std"] - - # assert (self.mean is None or \ - # len(self.mean) == self.in_channels) - # assert self.len == len(self.class_weights) - - def _get_item_data(self, index): - with h5py.File(self.path_to_hdf5, "r") as f: - data = f[str(index)] - image = data[:, :, : self.in_channels] - mask = data[:, :, self.in_channels :] - try: - raster_name = f[str(index) + "_raster_name"][()].decode( - "utf-8" - ) - geo_bounds = f[str(index) + "_geo_bounds"][:] - assert len(geo_bounds) == 4 - meta = {"raster_name": raster_name, "geo_bounds": geo_bounds} - except Exception: - meta = {} - - assert len(image.shape) == len(mask.shape) - assert image.shape[:2] == mask.shape[:2] - - return image, mask, meta - - def __getitem__(self, index): - image, mask, meta = self._get_item_data(index) - - # https://www.kaggle.com/nvnnghia/awesome-augmentation - if self.with_mosaic and random.randint(0, 1): - s = image.shape[:2] - h, w = image.shape[:2] - # indices = [index] + [np.random.randint(0, self.len) for _ in range(3)] - indices = [index] + [ - ind - for ind in np.random.choice( - self.len, size=3, p=self.norm_class_weights - ) - ] - for i, ind in enumerate(indices): - if i > 0: - image, mask, _ = self._get_item_data(ind) - img, msk = image, mask - if i == 0: # top left - xc, yc = [ - int(random.uniform(s[0] * 0.5, s[1] * 1.5)) - for _ in range(2) - ] # mosaic center x, y - img4 = np.full( - (s[0] * 2, s[1] * 2, img.shape[2]), 0, dtype=np.float32 - ) # base image with 4 tiles - msk4 = np.full( - (s[0] * 2, s[1] * 2, msk.shape[2]), 0, dtype=np.float32 - ) # base mask with 4 tiles - x1a, y1a, x2a, y2a = ( - 0, - 0, - xc, - yc, - ) # xmin, ymin, xmax, ymax (large image) - x1b, y1b, x2b, y2b = ( - w - xc, - h - yc, - w, - h, - ) # xmin, ymin, xmax, ymax (small image) - elif i == 1: # top right - x1a, y1a, x2a, y2a = xc, 0, s[1] * 2, yc - x1b, y1b, x2b, y2b = 0, h - yc, x2a - x1a, h - elif i == 2: # bottom left - x1a, y1a, x2a, y2a = 0, yc, xc, s[0] * 2 - x1b, y1b, x2b, y2b = w - xc, 0, w, y2a - y1a - elif i == 3: # 
bottom right - x1a, y1a, x2a, y2a = xc, yc, s[1] * 2, s[0] * 2 - x1b, y1b, x2b, y2b = 0, 0, x2a - x1a, y2a - y1a - - yb = np.abs(np.arange(y1b, y2b)) - xb = np.abs(np.arange(x1b, x2b)) - - # 101-reflection for indices greater or equal than h (or w) - # transform [..., h-3, h-2, h-1, h+0, h+1, ...] to - # [..., h-3, h-2, h-1, h-2, h-3, ...] - bad_ybi = np.where(yb >= h)[0] - if bad_ybi.any(): - fixed_ybi = [ - y - 2 * (i + 1) for i, y in enumerate(bad_ybi) - ] - yb[bad_ybi] = yb[fixed_ybi] - - bad_xbi = np.where(xb >= w)[0] - if bad_xbi.any(): - fixed_xbi = [ - x - 2 * (i + 1) for i, x in enumerate(bad_xbi) - ] - xb[bad_xbi] = xb[fixed_xbi] - - img4[y1a:y2a, x1a:x2a] = img[np.ix_(yb, xb)] - msk4[y1a:y2a, x1a:x2a] = msk[np.ix_(yb, xb)] - image, mask = img4, msk4 - - image, mask = _augment_and_preproc( - image, mask, self.augmentations, self.preprocessing - ) - - return {"image": image, "mask": mask, "metadata": meta} - - def __len__(self): - return self.len - - -class DatasetUnion(torch.utils.data.Dataset): - def __init__(self, dataset_list): - self.datasets = [] - self.class_ids = [] - self.len = 0 - - for dataset in dataset_list: - self.datasets.append(dataset) - self.class_ids.extend(list(dataset.class_ids)) - self.len += len(dataset) - self.class_weights = _get_class_weights(self.class_ids) - - print("DatasetUnion") - print("class_ids unique:", np.unique(self.class_ids)) - print("class_weights unique:", np.unique(self.class_weights)) - - mean_std_valid = [ - dataset.mean is not None and dataset.std is not None - for dataset in dataset_list - ] - assert all(mean_std_valid) or all(list(~np.array(mean_std_valid))) - - self.mean = None - self.std = None - if all(mean_std_valid): - mean_s = [d.mean for d in dataset_list] - std_s = [d.std for d in dataset_list] - self.mean = np.mean(mean_s, axis=0) - self.std = np.mean(std_s, axis=0) - self.preprocessing = _get_preprocessing_fn(self.mean, self.std) - - def __getitem__(self, index): - for dataset in self.datasets: - if index < len(dataset): - return dataset[index] - else: - index -= len(dataset) - raise Exception(f"DatasetUnion: dataset element {index} not found") - - def __len__(self): - return self.len - - -class WeightedRandomCropDataset(torch.utils.data.Dataset): - def __init__( - self, - tif_folders: List[str], - crop_size=(224, 224), - band_channels=["B04", "B03", "B02"], - band_labels=["100"], - augmentations=None, - is_train=True, - verbose=False, - ): - self.tif_folders = tif_folders - self.augmentations = augmentations - self.preprocessing = None - self.in_channels = len(band_channels) - self.mean = None - self.std = None - self.verbose = verbose - - self.band_channels = band_channels - self.band_labels = band_labels - self.crop_size = crop_size - self.samples_per_image = ( - (10980 + crop_size[0] - 1) // crop_size[0] - ) ** 2 - self.len = self.samples_per_image * len(self.tif_folders) - - self.is_train = is_train - - self.images = [] - self.masks = [] - self.sizes = [] - self.fg_indicies = [] - self.bg_indicies = [] - self.val_indices = [] - with tqdm( - self.tif_folders, - desc="WeightedRandomCrop dataset: " - + ("train" if is_train else "val"), - file=sys.stdout, - ) as iterator: - for tif_folder in iterator: - image = BandCollection( - parse_directory(tif_folder, self.band_channels) - ).ordered(*self.band_channels) - mask = BandCollection( - parse_directory(tif_folder, self.band_labels) - ) - - self.images.append(image) - self.masks.append(mask) - self.sizes.append(mask.numpy().shape) - - m = mask.numpy().flatten() - if 
self.is_train: - self.fg_indicies.append(np.where(m > 0)[0]) - self.bg_indicies.append(np.where(m == 0)[0]) - else: - weights = _get_class_weights(mask.numpy().flatten()) - weights /= weights.sum() - self.val_indices.append( - np.random.choice( - len(weights), self.samples_per_image, p=weights - ) - ) - - def __getitem__(self, index): - image_idx = index // self.samples_per_image - - image = self.images[image_idx] - mask = self.masks[image_idx] - _, height, width = self.sizes[image_idx] - - if self.is_train: - fg_indicies = self.fg_indicies[image_idx] - bg_indicies = self.bg_indicies[image_idx] - if len(fg_indicies) > 0 and np.random.rand() < 0.5: - fg_idx = np.random.randint(len(fg_indicies)) - crop_idx = fg_indicies[fg_idx] - else: - bg_idx = np.random.randint(len(bg_indicies)) - crop_idx = bg_indicies[bg_idx] - else: - crop_idx = self.val_indices[image_idx][ - index % self.samples_per_image - ] - - h_crop, w_crop = self.crop_size - y_crop = crop_idx // width - x_crop = crop_idx % width - dy = np.random.randint(self.crop_size[0]) - dx = np.random.randint(self.crop_size[1]) - - y_crop = np.max(y_crop - dy, 0) - x_crop = np.max(x_crop - dx, 0) - - sample = image.sample(y_crop, x_crop, h_crop, w_crop) - - if self.verbose: - coords = [x_crop, y_crop, w_crop, h_crop] - print("coords:", coords) - - x_sample = sample.numpy().astype("float32") - y_sample = ( - mask.sample(y_crop, x_crop, h_crop, w_crop) - .numpy() - .astype("float32") - ) - - image = x_sample.transpose(1, 2, 0) - image[image < 0] = 0 - image[image > 1] = 1 - - mask = y_sample.transpose(1, 2, 0) - - assert len(image.shape) == len(mask.shape) - assert image.shape[:2] == mask.shape[:2] - - image, mask = _augment_and_preproc( - image, mask, self.augmentations, self.preprocessing - ) - - return { - "image": image, - "mask": mask, - "metadata": { - "raster_name": self.tif_folders[image_idx], - "geo_bounds": torch.tensor(sample.bounds[:]), - }, - } - - def __len__(self): - return self.len - - -def _intersection(a, b): - if a is None or b is None: - return 0 - - # [xl, yl, xr, yr] - rect_a = np.array([a[0], a[1], a[0] + a[2] - 1, a[1] + a[3] - 1]) - rect_b = np.array([b[0], b[1], b[0] + b[2] - 1, b[1] + b[3] - 1]) - - bl = np.max([rect_a, rect_b], axis=0) - ur = np.min([rect_a, rect_b], axis=0) - - wh = ur[2:4] - bl[0:2] + 1 - - return max(0, wh[0] * wh[1]) - - -class _TiledQuarryImage(torch.utils.data.Dataset): - def __init__( - self, - tif_folder: str, - crop_size, - crop_step, - band_channels, - band_labels, - augmentations, - ): - if crop_step[0] < 1 or crop_step[0] > crop_size[0]: - raise ValueError() - if crop_step[1] < 1 or crop_step[1] > crop_size[1]: - raise ValueError() - - self.tif_folder = tif_folder - self.crop_size = crop_size - self.crop_step = crop_step - self.band_channels = band_channels - self.band_labels = band_labels - self.augmentations = augmentations - self.mean = None - self.std = None - self.preprocessing = _get_preprocessing_fn(self.mean, self.std) - - self.image = BandCollection( - parse_directory(tif_folder, self.band_channels) - ) - self.mask = BandCollection( - parse_directory(tif_folder, self.band_labels) - ) - - if self.image.shape[1:] != self.mask.shape[1:]: - raise ValueError( - "Shape of image does not corresponds to shape of mask" - ) - - _, rows, cols = self.mask.shape - - class_ids = [] - crops = [] - y_tiles = ( - rows - self.crop_size[0] + 1 + self.crop_step[0] - 1 - ) // self.crop_step[0] - x_tiles = ( - cols - self.crop_size[1] + 1 + self.crop_step[1] - 1 - ) // self.crop_step[1] - n_tiles = 
x_tiles * y_tiles - - skip = False - with tqdm( - range(n_tiles), desc="TiledQuarryImage", file=sys.stdout - ) as iterator: - for n in iterator: - y = (n // x_tiles) * self.crop_step[0] - x = (n % x_tiles) * self.crop_step[1] - crop = (x, y, self.crop_size[1], self.crop_size[0]) - prev_crop = crops[-1] if len(crops) > 0 else None - s = _intersection(prev_crop, crop) - if skip and s > 0: - continue - crops.append(crop) - mask = self.mask.sample( - y, x, self.crop_size[0], self.crop_size[1] - ).numpy() - m_sum = np.sum(mask) - class_ids.append(1 if m_sum > 0 else 0) - skip = not (m_sum > 0) - - self.crops = np.array(crops) - self.class_ids = np.array(class_ids) - self.class_weights = _get_class_weights(self.class_ids) - - print(np.unique(self.class_ids)) - print(np.unique(self.class_weights)) - - def __getitem__(self, index): - x, y, w, h = self.crops[index] - - mask = self.mask.sample(y, x, h, w).numpy().astype("float32") - - sample = self.image.sample(y, x, h, w) - image = sample.numpy().astype("float32") - - image = image.transpose(1, 2, 0) - mask = mask.transpose(1, 2, 0) - - assert len(image.shape) == len(mask.shape) - assert image.shape[:2] == mask.shape[:2] - - image[image < 0] = 0 - image[image > 1] = 1 - image, mask = _augment_and_preproc( - image, mask, self.augmentations, self.preprocessing - ) - - return { - "image": image, - "mask": mask, - "metadata": { - "raster_name": self.tif_folder, - "geo_bounds": torch.tensor(sample.bounds[:]), - }, - } - - def __len__(self): - return len(self.crops) - - -class TiledDataset(torch.utils.data.ConcatDataset): - def __init__( - self, - tif_folders: List[str], - crop_size=(224, 224), - crop_step=(224, 224), - band_channels=["B04", "B03", "B02"], - band_labels=["100"], - augmentations=None, - ): - self.mean = None - self.std = None - self.preprocessing = _get_preprocessing_fn(self.mean, self.std) - - class_ids = [] - datasets = [] - for n, tif_folder in enumerate(tif_folders): - print(f"Image [{n + 1}/{len(tif_folders)}]: {tif_folder}") - dataset = _TiledQuarryImage( - tif_folder, - crop_size, - crop_step, - band_channels, - band_labels, - augmentations, - ) - class_ids.extend(list(dataset.class_ids)) - datasets.append(dataset) - self.class_ids = np.array(class_ids) - self.class_weights = _get_class_weights(self.class_ids) - - print("TiledDataset") - print("class_ids unique:", np.unique(self.class_ids)) - print("p for 1-class:", np.sum(self.class_ids) / len(self.class_ids)) - print("class_weights unique:", np.unique(self.class_weights)) - - super().__init__(datasets) diff --git a/ui/image_ui.py b/ui/image_ui.py deleted file mode 100644 index 768ca26c..00000000 --- a/ui/image_ui.py +++ /dev/null @@ -1,136 +0,0 @@ -# def image_input_handler(st): -# task = st.selectbox( -# "What task you want to solve?", -# [ -# "segmentation", -# "detection", -# "image generation", -# "classification", -# "regression", -# ], -# ) -# src_data_folderpath = st.text_input("Path to the folder with data") -# -# datamodule_cfg = {} -# model_params = {} -# recreate_compressed_format = st.multiselect( -# "Which datasets should be compressed/recompressed?", -# ["fit", "val", "test"], -# default=[], -# ) # 'fit', 'val', 'test' -# -# bands_names = st.multiselect( -# "Image channel names", -# ["RED", "GRN", "BLU", "NIR"], -# default=["RED", "GRN", "BLU", "NIR"], -# ) -# -# compressed_ds_dst_path = st.text_input("Where to store compressed files?") -# -# datamodule_cfg = { -# "overwrite": recreate_compressed_format, -# "src_data_folderpath": 
make_valid_path(src_data_folderpath), -# "dst_tile_ds_path": make_valid_path(compressed_ds_dst_path), -# "bands_idx": bands_names, -# } -# model = None -# if task == "segmentation": -# model = st.selectbox( -# "What model you want to try?", ["Unet", "Unet++", "DeepLabV3", "DeepLabV3+"] -# ) -# model_params["name"] = model.lower() -# -# model_size = st.selectbox( -# "What size of the model you want?", ["Small", "Medium", "Large"] -# ) -# -# with st.expander("Open to tune parameters"): -# model_params["activation"] = st.selectbox( -# "activation function", ["sigmoid", "None", "softmax"] -# ) -# -# if model_size == "Small": -# model_params["encoder_name"] = st.selectbox( -# "encoder", ["resnet18", "dpn68"] -# ) -# # model_params['n_jobs'] = st.slider("n_jobs", value=1, min_value=-1, max_value=16) -# elif model_size == "Medium": -# model_params["encoder_name"] = st.selectbox( -# "encoder", ["resnet34", "resnet50"] -# ) -# elif model_size == "Large": -# model_params["encoder_name"] = st.selectbox( -# "encoder", ["resnet101", "resnet152"] -# ) -# else: -# raise NotImplementedError -# elif task == "classification": -# # get the number of classes -# model = st.selectbox( -# "Выберите модель:", -# ["resnet18", "alexnet", "vgg16", "googlenet", "inception_v3"], -# ) -# model_params["name"] = model.lower() -# -# model_params["classes"] = st.number_input("number of classes", value=1, min_value=1) -# model_params["in_channels"] = st.number_input( -# "number of channels", value=3, min_value=1, max_value=8 -# ) -# -# if model is not None: -# # dictionary where key is the filename where values are the descriptions -# suitable_metrics = {} -# metrics_files = list(Path("config/metrics/").iterdir()) -# from omegaconf import OmegaConf -# from innofw.utils.framework import ( -# is_suitable_for_framework, -# is_suitable_for_task, -# map_model_to_framework, -# ) -# from innofw import utils -# -# name = model_params["name"] -# model_params["_target_"] = utils.find_suitable_model(name) -# del model_params["name"] -# if task == "segmentation": -# model = hydra.utils.instantiate(model_params) -# elif task == "classification": -# model_params_copy = model_params.copy() -# del model_params_copy["in_channels"] -# del model_params_copy["classes"] -# # del model_params_copy['in_channels'] -# model = hydra.utils.instantiate(model_params_copy) -# model_params["name"] = name -# framework = map_model_to_framework(model) -# for file in metrics_files: -# with open(file, "r") as f: -# contents = OmegaConf.load(f) -# try: -# # st.info(f"image-{task}" in {contents.requirements.task}) -# if is_suitable_for_task( -# contents, f"image-{task}" -# ) and is_suitable_for_framework(contents, framework): -# suitable_metrics[file.stem] = contents.objects[ -# framework -# ].description -# except: -# pass -# -# if len(suitable_metrics) == 0: -# st.warning( -# f"Unable to find required metrics for model: {model_params['name']}" -# ) -# else: -# metrics = st.multiselect("What metrics to measure", suitable_metrics.keys()) -# if len(metrics) != 0: -# st.write(f"Metrics Description: {suitable_metrics[metrics[0]]}") -# -# # ====== Configuration Creation ===== # -# if st.button("Create!"): -# save_config("image", task, datamodule_cfg, model_params) -# st.markdown( -# "[click here for augmentations](/Аугментация)", unsafe_allow_html=True -# ) -# # import streamlit.components.v1 as components -# # components.html("click here") -# diff --git a/ui/pages/augmentation.py b/ui/pages/augmentation.py deleted file mode 100644 index 7d3cc405..00000000 --- 
a/ui/pages/augmentation.py +++ /dev/null @@ -1,286 +0,0 @@ -if __name__ == "__main__": - import math - from pathlib import Path - - # - import numpy as np - import streamlit as st - import albumentations as albu - import matplotlib.pyplot as plt - - from innofw.constants import CLI_FLAGS - from innofw.core.augmentations import Augmentation - - # - from innofw.utils.framework import ( - get_datamodule, - get_obj, - get_model, - map_model_to_framework, - ) - from innofw.utils.getters import get_trainer_cfg - from ui.utils import ( - load_augmentations_config, - ) - - # read from configuration file - # instantiate the augmentations - # show a textarea with config file contents - - import os - import argparse - import streamlit as st - import argparse - import os - - parser = argparse.ArgumentParser( - description="This app provides UI for the framework" - ) - - parser.add_argument("experiments", default="ui.yaml", help="Config name") - try: - args = parser.parse_args() - except SystemExit as e: - # This exception will be raised if --help or invalid command line arguments - # are used. Currently streamlit prevents the program from exiting normally - # so we have to do a hard exit. - os._exit(e.code) - - # set up the env flag - try: - os.environ[CLI_FLAGS.DISABLE.value] = "True" - except: - pass - - # init - if "conf_button_clicked" not in st.session_state: - st.session_state.conf_button_clicked = False - - if "new_batch_btn_clicked" not in st.session_state: - st.session_state.new_batch_btn_clicked = False - - if "indices" not in st.session_state: - st.session_state.indices = None - - def callback(): - st.session_state.conf_button_clicked = True - - def callback_new_batch(): - st.session_state.new_batch_btn_clicked = True - - # if st.sidebar.button("Configure!", on_click=callback) or st.session_state.conf_button_clicked: - - st.header("LOGO") - - # task = st.sidebar.selectbox("task", ["image-classification", "image-segmentation"]) - # framework = st.sidebar.selectbox( - # "framework", list(Frameworks), format_func=lambda x: x.value.lower() - # ) - - # data_path = st.sidebar.text_input("path to data") - # in_channels = st.sidebar.number_input("input channels:", min_value=1, max_value=20) - # batch_size = st.sidebar.slider("batch size:", value=3, min_value=1, max_value=16) - # prep_func_type = st.sidebar.selectbox('preprocessing function:', prep_funcs.keys()) - - # get the list of transformations names - interface_type = "Simple" - # transform_names = select_transformations(augmentations, interface_type) - # get parameters for each transform - transforms = [albu.Resize(300, 300, always_apply=True)] - transforms = albu.ReplayCompose(transforms) - - # apply augmentations - # augmentations = get_training_augmentation() - from hydra import compose, initialize - - # hydra.core.global_hydra.GlobalHydra.instance().clear() - selected_aug = None - aug_conf = None - - # from hydra.core.global_hydra import GlobalHydra - # GlobalHydra.instance().clear() - - # load the config - augmentations = load_augmentations_config( - None, - str(Path("ui/augmentations.json").resolve()), # placeholder_params - ) - - try: - initialize(config_path="../../config/", version_base="1.1") - except: - pass - cfg = compose( - config_name="train", - overrides=[f"experiments={args.experiments}"], - return_hydra_config=True, - ) - - dm_cfg = cfg["datasets"] - # dm_cfg = DictConfig({ - # 'task': [task], - # 'data_path': data_path - # }) - if ( - dm_cfg["train"]["source"] is not None - and dm_cfg["train"]["source"] != "" - ): - 
trainer_cfg = get_trainer_cfg(cfg) - model = get_model(cfg.models, trainer_cfg) - task = cfg.get("task") - framework = map_model_to_framework(model) - batch_size = cfg.get("batch_size") - if batch_size is None: - batch_size = 6 - # st.info(f"creating datamodule") - - if "train_dataloader" in st.session_state: - train_dataloader = st.session_state.train_dataloader - else: - dm = get_datamodule( - dm_cfg, framework, task=task, augmentations=transforms - ) - - dm.setup() - - train_dataloader = dm.train_dataloader() - st.session_state.train_dataloader = train_dataloader - - # st.info(f"getting images") - if ( - "images" in st.session_state and "indices" in st.session_state - ): # or ('preserve' in st.session_state and not st.session_state.preserve) - images = st.session_state.images - dataset_len = images.shape[0] - with_replace = batch_size > dataset_len - indices = st.session_state.indices - else: - batch = iter(train_dataloader).next() - try: - images = batch[0].detach().cpu().numpy() - # images = batch["scenes"].detach().cpu().numpy() - except: - if isinstance(batch, list): - images = ( - batch[0].detach().cpu().numpy() - ) # np.array([img.detach().cpu().numpy() for img in batch]) - else: - images = batch.detach().cpu().numpy() - - dataset_len = images.shape[0] - with_replace = batch_size > dataset_len - indices = np.random.choice( - range(dataset_len), batch_size, with_replace - ) - - if not isinstance(images, np.ndarray): - images = images.detach().cpu().numpy() - - # selected_aug = None - - try: - # st.info("initializing augmentations") - selected_aug = Augmentation( - get_obj(cfg, "augmentations", task, framework) - ) - aug_conf = cfg["augmentations"]["implementations"]["torch"][ - "Compose" - ]["object"] - # selected_aug = hydra.utils.instantiate(aug_conf) - except Exception as e: - st.warning(f"unable to process the train.yaml file. {e}") - - # st.info(f"{augmentations} {aug_conf}") - - if selected_aug is not None and aug_conf is not None: - # st.info("created augmentations") - st.info("Результаты применения трансформации") - - # selected_aug = get_transormations_params(transform_names, augmentations) - # selected_aug = albu.ReplayCompose(selected_aug) - - aug_images = [] - for img in images: - try: - aug_img = selected_aug(img)["image"] - aug_img = aug_img[:3, ...] - except: - aug_img = selected_aug(img) - aug_img = aug_img[:3, ...] 
- - if not isinstance(aug_img, np.ndarray): - aug_img = aug_img.detach().cpu().numpy() - - aug_img = np.moveaxis(aug_img, 0, -1) - aug_images.append(aug_img) - - ncols = 3 if batch_size >= 3 else batch_size - nrows = math.ceil(batch_size / ncols) - - fig, axs = plt.subplots(nrows, ncols) - if nrows > 1: - axs = axs.flatten() - elif ncols == 1: - axs = [axs] - - # apply_aug = True - # data = None - - st.warning("Исходные изображения") - for i, ax in zip(indices, axs): - ax.set_axis_off() - imgs = images[i] - # st.info(f"orig: {imgs.shape}") - - # imgs = np.moveaxis(imgs, 0, -1) # [..., :3] - - ax.imshow(imgs) - st.pyplot(fig) - - fig, axs = plt.subplots(nrows, ncols) - if nrows > 1: - axs = axs.flatten() - elif ncols == 1: - axs = [axs] - - dataset_len = len(aug_images) - with_replace = batch_size > dataset_len - - st.warning("Применение трансформации") - - for i, ax in zip(indices, axs): - ax.set_axis_off() - ax.imshow(aug_images[i]) - - st.pyplot(fig) - - col1, col2 = st.columns(2) - - st.session_state.indices = indices - st.session_state.images = images - - b1 = col1.button("Обновить") - - def clear_img_ind(): - try: - del st.session_state.images - except: - pass - - try: - del st.session_state.indices - except: - pass - - b2 = col2.button("Следующий набор", on_click=clear_img_ind) - from omegaconf import OmegaConf - from pprint import pformat - - formatted_conf_str = pformat(OmegaConf.to_yaml(aug_conf), indent=4) - formatted_conf_str = [ - item.replace("\\n", "\n").replace("'", "") - for item in formatted_conf_str[1:-1].split("\n") - ] - formatted_conf_str = "\n".join(formatted_conf_str) - st.markdown("\n**Конфигурация:**\n") - st.text(formatted_conf_str) diff --git a/ui/schema/__init__.py b/ui/schema/__init__.py deleted file mode 100644 index 8d670e7a..00000000 --- a/ui/schema/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .kneighbors_classifier import KNeighborsClassifierSchema -from .linear_regression import LinearRegressionSchema diff --git a/ui/schema/base.py b/ui/schema/base.py deleted file mode 100644 index 0c57bf26..00000000 --- a/ui/schema/base.py +++ /dev/null @@ -1,59 +0,0 @@ -from pydantic import BaseModel -from pydantic import Field - - -class Model(BaseModel): - pass - - -class SklearnSchema(Model): - n_jobs: int = Field( - None, - ge=-1, - description="The number of parallel jobs to run for neighbors search.", - ) - - # from typing import Set - # - # import streamlit as st - - # - - # class OtherData(BaseModel): - # text: str - # integer: int - # - # - # class SelectionValue(str, Enum): - # FOO = "foo" - # BAR = "bar" - # - # - # class ExampleModel(BaseModel): - # long_text: str = Field( - # ..., format="multi-line", description="Unlimited text property" - # ) - # integer_in_range: int = Field( - # 20, - # ge=10, - # le=30, - # multiple_of=2, - # description="Number property with a limited range.", - # ) - # single_selection: SelectionValue = Field( - # ..., description="Only select a single item from a set." - # ) - # multi_selection: Set[SelectionValue] = Field( - # ..., description="Allows multiple items from a set." 
- # ) - # read_only_text: str = Field( - # "Lorem ipsum dolor sit amet", - # description="This is a ready only text.", - # readOnly=True, - # ) - # single_object: OtherData = Field( - # ..., - # description="Another object embedded into this model.", - # ) - # - # diff --git a/ui/schema/kneighbors_classifier.py b/ui/schema/kneighbors_classifier.py deleted file mode 100644 index cf5bbbe3..00000000 --- a/ui/schema/kneighbors_classifier.py +++ /dev/null @@ -1,61 +0,0 @@ -from enum import Enum -from typing import ClassVar - -from pydantic import Field - -from .base import SklearnSchema -from innofw.constants import TaskType - - -class WeightValues(str, Enum): - UNIFORM = "uniform" - DISTANCE = "distance" - - -class AlgorithmValues(str, Enum): - AUTO = "auto" - BALL_TREE = "ball_tree" - KD_TREE = "kd_tree" - BRUTE = "brute" - - -class MetricValues(str, Enum): - MINKOWSKI = "minkowski" - - -class KNeighborsClassifierSchema(SklearnSchema): - task: ClassVar[str] = TaskType.CLASSIFICATION - name: ClassVar[str] = "knn_classifier" - target: ClassVar[str] = "sklearn.neighbors.KNeighborsClassifier" - - n_neighbors: int = Field( - default=5, - gt=0, - description="Number of neighbors to use by default for kneighbors queries.", - ) - weights: WeightValues = Field( - default=WeightValues.UNIFORM, - description="Weight function used in prediction.", - ) - algorithm: AlgorithmValues = Field( - default=AlgorithmValues.AUTO, - description="Algorithm used to compute the nearest neighbors.", - ) - leaf_size: int = Field( - 30, - gt=0, - description="Leaf size passed to BallTree or KDTree. This can affect the speed of the construction and query, as well as the memory required to store the tree. The optimal value depends on the nature of the problem.", - ) - p: int = Field( - 2, - ge=1, - description="Power parameter for the Minkowski metric. When p = 1, this is equivalent to using manhattan_distance (l1), and euclidean_distance (l2) for p = 2. For arbitrary p, minkowski_distance (l_p) is used.", - ) - metric: MetricValues = Field( - MetricValues.MINKOWSKI, - description="Metric to use for distance computation.", - ) - # metric_params: dict = Field( - # None, - # description="Additional keyword arguments for the metric function." - # ) diff --git a/ui/schema/linear_regression.py b/ui/schema/linear_regression.py deleted file mode 100644 index 0b726a38..00000000 --- a/ui/schema/linear_regression.py +++ /dev/null @@ -1,29 +0,0 @@ -from typing import ClassVar - -from pydantic import Field - -from .base import SklearnSchema -from innofw.constants import TaskType - - -class LinearRegressionSchema(SklearnSchema): - task: ClassVar[TaskType] = TaskType.REGRESSION - name: ClassVar[str] = "linear_regression" - target: ClassVar[str] = "sklearn.linear_model.LinearRegression" - - fit_intercept: bool = Field( - True, - description="Whether to calculate the intercept for this model. If set to False, no intercept will be used in calculations (i.e. data is expected to be centered).", - ) - normalize: bool = Field( - False, - description="This parameter is ignored when fit_intercept is set to False.", - ) - copy_X: bool = Field( - True, - description="If True, X will be copied; else, it may be overwritten.", - ) - positive: bool = Field( - False, - description="When set to True, forces the coefficients to be positive. 
This option is only supported for dense arrays.", - ) diff --git a/ui/table_ui.py b/ui/table_ui.py deleted file mode 100644 index 129f0807..00000000 --- a/ui/table_ui.py +++ /dev/null @@ -1,78 +0,0 @@ -import hydra.utils -import pandas as pd -import streamlit as st -import streamlit_pydantic as sp - -from innofw.constants import TaskType -from innofw.utils.extra import is_intersecting -from innofw.utils.framework import map_model_to_framework - -# iterate over a module and find all model schemas -# find schemas which match the task - - -def find_model_schema(task: TaskType): - import inspect - from ui import schema - - # search for suitable datamodule from codebase - clsmembers = inspect.getmembers(schema, inspect.isclass) - objects, classes, class_paths = [], [], [] - for _, cls in clsmembers: - obj = cls() - # st.success(f"{obj.task.const}") - if is_intersecting(task, obj.task): - objects.append(obj) - classes.append(cls) - class_paths.append(".".join([cls.__module__, cls.__name__])) - - if len(objects) == 0: - raise ValueError(f"Could not find model schema for the {task}") - - return objects, classes, class_paths - - -def table_input_handler(task, project_tab, model_tab, data_tab): - objects, classes, class_paths = find_model_schema(task) - model = project_tab.selectbox( - label="Select model", - options=classes, - format_func=lambda cls: cls().name, - ) - idx = classes.index(model) - with model_tab.container(): - model_cfg = sp.pydantic_input(key="my_form", model=model) - - model_cfg["_target_"] = objects[idx].target - - # if model_cfg: - # st.json(model_cfg.json()) - st.success(f"{model_cfg}") - model = hydra.utils.instantiate(model_cfg) - framework = map_model_to_framework(model) - - model_cfg_copy = model_cfg.copy() - model_cfg_copy.update(name="some name", description="some description") - - data_path = project_tab.text_input("Provide location to the file:") - if data_path: - df = pd.read_csv(data_path) - data_tab.dataframe(df) - - columns = list(df.columns) - target_feature = data_tab.selectbox( - "What is the target feature?", columns[::-1] - ) - - dataset_cfg = { - "target_col": target_feature, - "train": {"source": data_path}, - "task": f"table-{task.value}", - "framework": framework, - "name": "something", - "description": "something", - "markup_info": "something", - "date_time": "something", - } - - return {"model_cfg": model_cfg_copy, "dataset_cfg": dataset_cfg} diff --git a/ui/tmp_pages/2_augmentations_assitant.py b/ui/tmp_pages/2_augmentations_assitant.py deleted file mode 100755 index c6cc2da0..00000000 --- a/ui/tmp_pages/2_augmentations_assitant.py +++ /dev/null @@ -1,159 +0,0 @@ -import math -from pathlib import Path - -import albumentations as albu -import matplotlib.pyplot as plt -import numpy as np -import streamlit as st -from omegaconf import DictConfig - -from innofw.utils import find_suitable_datamodule -from ui.utils import load_augmentations_config -from ui.utils import select_transformations -from ui.visuals import ( - get_transormations_params, -) - -# -# - - -# init -if "conf_button_clicked" not in st.session_state: - st.session_state.conf_button_clicked = False - -if "new_batch_btn_clicked" not in st.session_state: - st.session_state.new_batch_btn_clicked = False - -if "indices" not in st.session_state: - st.session_state.indices = None - - -def callback(): - st.session_state.conf_button_clicked = True - - -def callback_new_batch(): - st.session_state.new_batch_btn_clicked = True - - -if ( - st.sidebar.button("Configure!", on_click=callback) - or
st.session_state.conf_button_clicked -): - task = st.sidebar.selectbox( - "task", ["image-classification", "image-segmentation"] - ) - framework = st.sidebar.selectbox("framework", ["torch", "sklearn"]) - - data_path = st.sidebar.text_input("path to data") - in_channels = st.sidebar.number_input( - "input channels:", min_value=1, max_value=20 - ) - batch_size = st.sidebar.slider("batch size:", min_value=1, max_value=16) - # prep_func_type = st.sidebar.selectbox('preprocessing function:', prep_funcs.keys()) - - # /home/qazybek/Projects/InnoFramework3/tests/data/images/segmentation/arable - - # load the config - augmentations = load_augmentations_config( - None, - str(Path("ui/augmentations.json").resolve()), # placeholder_params - ) - - # get the list of transformations names - interface_type = "Simple" - transform_names = select_transformations(augmentations, interface_type) - # get parameters for each transform - transforms = [albu.Resize(300, 300, always_apply=True)] - transforms = albu.ReplayCompose(transforms) - - # apply augmentations - # augmentations = get_training_augmentation() - - dm_cfg = DictConfig({"task": [task], "data_path": data_path}) - if data_path is not None and data_path != "": - dm = find_suitable_datamodule(task, framework, dm_cfg, aug=transforms) - dm.setup() - train_dataloader = dm.train_dataloader() - - if ( - "images" in st.session_state and "indices" in st.session_state - ): # or ('preserve' in st.session_state and not st.session_state.preserve) - images = st.session_state.images - dataset_len = images.shape[0] - with_replace = batch_size > dataset_len - indices = st.session_state.indices - else: - batch = iter(train_dataloader).next() - images = batch[0].detach().cpu().numpy() - dataset_len = images.shape[0] - with_replace = batch_size > dataset_len - indices = np.random.choice( - range(dataset_len), batch_size, with_replace - ) - - selected_aug = get_transormations_params( - transform_names, augmentations - ) - selected_aug = albu.ReplayCompose(selected_aug) - - aug_images = [] - for img in images: - aug_img = selected_aug(image=img)["image"] - aug_images.append(aug_img) - - ncols = 3 if batch_size >= 3 else batch_size - nrows = math.ceil(batch_size / ncols) - - fig, axs = plt.subplots(nrows, ncols) - if nrows > 1: - axs = axs.flatten() - elif ncols == 1: - axs = [axs] - - # apply_aug = True - # data = None - - st.title("Original Images") - for i, ax in zip(indices, axs): - ax.set_axis_off() - ax.imshow(images[i]) - st.pyplot(fig) - - fig, axs = plt.subplots(nrows, ncols) - if nrows > 1: - axs = axs.flatten() - elif ncols == 1: - axs = [axs] - - dataset_len = len(aug_images) - with_replace = batch_size > dataset_len - - st.title("Augmented Images") - - for i, ax in zip(indices, axs): - ax.set_axis_off() - ax.imshow(aug_images[i]) - - st.pyplot(fig) - - col1, col2 = st.columns(2) - - st.session_state.indices = indices - st.session_state.images = images - - b1 = col1.button("Обновить") - - def clear_img_ind(): - try: - del st.session_state.images - except: - pass - - try: - del st.session_state.indices - except: - pass - - b2 = col2.button("Следующий набор", on_click=clear_img_ind) diff --git a/ui/tmp_pages/__init__.py b/ui/tmp_pages/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/ui/utils.py b/ui/utils.py deleted file mode 100755 index 799ad198..00000000 --- a/ui/utils.py +++ /dev/null @@ -1,170 +0,0 @@ -import argparse -import json -import os -import uuid - -import cv2 -import numpy as np -import streamlit as st - - -@st.cache -def 
-def get_arguments():
-    """Return the values of CLI params"""
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--image_folder", default="images")
-    parser.add_argument("--image_width", default=400, type=int)
-    args = parser.parse_args()
-    return getattr(args, "image_folder"), getattr(args, "image_width")
-
-
-@st.cache
-def get_images_list(path_to_folder: str) -> list:
-    """Return the list of images from the folder
-    Args:
-        path_to_folder (str): absolute or relative path to the folder with images
-    """
-    image_names_list = [
-        x
-        for x in os.listdir(path_to_folder)
-        if x[-3:] in ["jpg", "peg", "png"]
-    ]
-    return image_names_list
-
-
-@st.cache
-def load_image(image_name: str, path_to_folder: str, bgr2rgb: bool = True):
-    """Load the image
-    Args:
-        image_name (str): name of the image
-        path_to_folder (str): path to the folder with the image
-        bgr2rgb (bool): converts BGR image to RGB if True
-    """
-    path_to_image = os.path.join(path_to_folder, image_name)
-    image = cv2.imread(path_to_image)
-    if bgr2rgb:
-        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-    return image
-
-
-def upload_image(bgr2rgb: bool = True):
-    """Upload the image
-    Args:
-        bgr2rgb (bool): converts BGR image to RGB if True
-    """
-    file = st.sidebar.file_uploader(
-        "Upload your image (jpg, jpeg, or png)", ["jpg", "jpeg", "png"]
-    )
-    image = cv2.imdecode(np.frombuffer(file.read(), np.uint8), 1)
-    if bgr2rgb:
-        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-    return image
-
-
-# @st.cache
-def load_augmentations_config(
-    placeholder_params: dict,
-    path_to_config: str = "configs/augmentations.json",
-) -> dict:
-    """Load the json config with params of all transforms
-    Args:
-        placeholder_params (dict): dict with values of placeholders
-        path_to_config (str): path to the json config file
-    """
-    with open(path_to_config, "r") as config_file:
-        augmentations = json.load(config_file)
-    # for name, params in augmentations.items():
-    #     params = [fill_placeholders(param, placeholder_params) for param in params]
-    return augmentations
-
-
-def fill_placeholders(params: dict, placeholder_params: dict) -> dict:
-    """Fill the placeholder values in the config file
-    Args:
-        params (dict): original params dict with placeholders
-        placeholder_params (dict): dict with values of placeholders
-    """
-    if "placeholder" in params:
-        placeholder_dict = params["placeholder"]
-        for k, v in placeholder_dict.items():
-            if isinstance(v, list):
-                params[k] = []
-                for element in v:
-                    if element in placeholder_params:
-                        params[k].append(placeholder_params[element])
-                    else:
-                        params[k].append(element)
-            else:
-                if v in placeholder_params:
-                    params[k] = placeholder_params[v]
-                else:
-                    params[k] = v
-        params.pop("placeholder")
-    return params
-
-
-def get_params_string(param_values: dict) -> str:
-    """Generate the string from the dict with parameters
-    Args:
-        param_values (dict): dict of "param_name" -> "param_value"
-    """
-    params_string = ", ".join(
-        [k + "=" + str(param_values[k]) for k in param_values.keys()]
-    )
-    return params_string
-
-
-def get_placeholder_params(image):
-    return {
-        "image_width": image.shape[1],
-        "image_height": image.shape[0],
-        "image_half_width": int(image.shape[1] / 2),
-        "image_half_height": int(image.shape[0] / 2),
-    }
-
-
-def select_transformations(augmentations: dict, interface_type: str) -> list:
-    # in the Simple mode you can choose only one transform
-    if interface_type == "Simple":
-        transform_names = [
-            st.sidebar.selectbox(
-                "Select a transformation:", sorted(list(augmentations.keys()))
-            )
-        ]
-    # in the Professional mode you can choose several transforms
-    elif interface_type == "Professional":
-        transform_names = [
-            st.sidebar.selectbox(
-                "Select transformation №1:", sorted(list(augmentations.keys()))
-            )
-        ]
-        while transform_names[-1] != "None":
-            transform_names.append(
-                st.sidebar.selectbox(
-                    f"Select transformation №{len(transform_names) + 1}:",
-                    ["None"] + sorted(list(augmentations.keys())),
-                )
-            )
-        transform_names = transform_names[:-1]
-    else:
-        raise NotImplementedError()
-    return transform_names
-
-
-def show_random_params(data: dict, interface_type: str = "Professional"):
-    """Show the random params used for the transformation (from A.ReplayCompose)"""
-    if interface_type == "Professional":
-        st.subheader("Random params used")
-        random_values = {}
-        for applied_params in data["replay"]["transforms"]:
-            random_values[
-                applied_params["__class_fullname__"].split(".")[-1]
-            ] = applied_params["params"]
-        st.write(random_values)
-
-
-def get_uuid() -> str:
-    """Generate a short uuid (the first 8-character segment of a uuid4)"""
-    gen_uuid = uuid.uuid4()
-    gen_uuid = str(gen_uuid).split("-")[0]
-    return gen_uuid
diff --git a/ui/visuals.py b/ui/visuals.py
deleted file mode 100755
index 9c72cb25..00000000
--- a/ui/visuals.py
+++ /dev/null
@@ -1,114 +0,0 @@
-import albumentations as A
-import cv2
-import streamlit as st
-
-from .control import param2func
-from .utils import get_images_list
-from .utils import load_image
-from .utils import upload_image
-
-
-def show_logo():
-    st.image(load_image("logo.png", "../images"), format="PNG")
-
-
-def select_image(path_to_images: str, interface_type: str = "Simple"):
-    """Show the interface to choose the image, and load it
-    Args:
-        path_to_images (str): path to the folder with images
-        interface_type (str): mode of the interface used
-    Returns:
-        (status, image)
-        status (int):
-            0 - if everything is ok
-            1 - if there is an error during loading of the image file
-            2 - if the user hasn't uploaded a photo yet
-    """
-    image_names_list = get_images_list(path_to_images)
-    if len(image_names_list) < 1:
-        return 1, 0
-    else:
-        if interface_type == "Professional":
-            image_name = st.sidebar.selectbox(
-                "Select an image:", image_names_list + ["Upload my image"]
-            )
-        else:
-            image_name = st.sidebar.selectbox(
-                "Select an image:", image_names_list
-            )
-
-        if image_name != "Upload my image":
-            try:
-                image = load_image(image_name, path_to_images)
-                return 0, image
-            except cv2.error:
-                return 1, 0
-        else:
-            try:
-                image = upload_image()
-                return 0, image
-            except cv2.error:
-                return 1, 0
-            except AttributeError:
-                return 2, 0
-
-
-def show_transform_control(transform_params: dict, n_for_hash: int) -> dict:
-    param_values = {"p": 1.0}
-    if len(transform_params) == 0:
-        st.sidebar.text("Transform has no parameters")
-    else:
-        for param in transform_params:
-            control_function = param2func[param["type"]]
-            if isinstance(param["param_name"], list):
-                returned_values = control_function(
-                    **param, n_for_hash=n_for_hash
-                )
-                for name, value in zip(param["param_name"], returned_values):
-                    param_values[name] = value
-            else:
-                param_values[param["param_name"]] = control_function(
-                    **param, n_for_hash=n_for_hash
-                )
-    return param_values
-
-
-def show_credentials():
-    st.markdown("* * *")
-    st.subheader("Credentials:")
-    st.markdown(
-        (
-            "Source: [github.com/IliaLarchenko/albumentations-demo]"
-            "(https://github.com/IliaLarchenko/albumentations-demo)"
-        )
-    )
-    st.markdown(
-        (
-            "Albumentations library: [github.com/albumentations-team/albumentations]"
"(https://github.com/albumentations-team/albumentations)" - ) - ) - st.markdown( - ( - "Image Source: [pexels.com/royalty-free-images]" - "(https://pexels.com/royalty-free-images/)" - ) - ) - - -def get_transormations_params( - transform_names: list, augmentations: dict -) -> list: - transforms = [] - for i, transform_name in enumerate(transform_names): - # select the params values - st.sidebar.subheader("Params of the " + transform_name) - param_values = show_transform_control(augmentations[transform_name], i) - transforms.append(getattr(A, transform_name)(**param_values)) - return transforms - - -def show_docstring(obj_with_ds): - st.markdown("* * *") - st.subheader("Docstring for " + obj_with_ds.__class__.__name__) - st.text(obj_with_ds.__doc__) diff --git a/ui/webserver.py b/ui/webserver.py deleted file mode 100755 index 17741b16..00000000 --- a/ui/webserver.py +++ /dev/null @@ -1,141 +0,0 @@ -# standard libraries -import argparse -import os -from datetime import datetime -from enum import Enum - -import streamlit as st -from omegaconf import DictConfig -from table_ui import table_input_handler - -from innofw.constants import TaskType -from innofw.schema.experiment import ExperimentConfig -from ui.utils import get_uuid - -# third-party libraries -# local modules - -parser = argparse.ArgumentParser( - description="This app provides UI for the framework" -) - -parser.add_argument("experiments", default="ui.yaml", help="Config name") -try: - args = parser.parse_args() -except SystemExit as e: - # This exception will be raised if --help or invalid command line arguments - # are used. Currently streamlit prevents the program from exiting normally - # so we have to do a hard exit. - os._exit(e.code) - - -def input_type_handler(input_type, task, *args, **kwargs): - if input_type == "table": - return table_input_handler(task, *args, **kwargs) - # elif input_type == "image": - # image_input_handler(st) - - -class ClassificationMetrics(Enum): - ACCURACY = "ACCURACY" - PRECISION = "PRECISION" - RECALL = "RECALL" - - -class RegressionMetrics(Enum): - MAE = "mean_absolute_error" - MSE = "mean_squared_error" - R2 = "R2" - - -def get_task_metrics(task: TaskType): - if task == TaskType.REGRESSION: - return RegressionMetrics - else: - return ClassificationMetrics - - -# ===== USER INTERFACE ===== # -# 1. Project info -if "project_info" not in st.session_state: - st.session_state["project_info"] = { - "title": "Random Title", - "author": "Random Author", - "uuid": get_uuid(), - "date": datetime.today().strftime("%d-%m-%Y"), - } - -st.title(f"{st.session_state.project_info['title']}") -st.header(f"by {st.session_state.project_info['author']}") -st.subheader( - f"Date: {st.session_state.project_info['date']} Uuid: {st.session_state.project_info['uuid']}" -) - -project_tab, model_tab, data_tab = st.tabs(["Project", "Model", "Data"]) - -# tab 1 -st.session_state.project_info["title"] = project_tab.text_input("Title") -st.session_state.project_info["author"] = project_tab.text_input("Author") -clear_ml = project_tab.checkbox("Use ClearML") -task_name = None -queue = None -if clear_ml: - task_name = project_tab.text_input("What is the name of the experiment?") - queue = project_tab.text_input( - "Do you want to execute experiment in the agent? If yes, specify queue." 
-    )
-
-task = project_tab.selectbox(
-    "What task do you want to solve?",
-    list(TaskType),
-    format_func=lambda x: x.value.lower(),
-)
-task_metrics = get_task_metrics(task)
-
-metrics = project_tab.multiselect(
-    "Which metrics do you want to measure?",
-    list(task_metrics),
-    format_func=lambda x: x.value,
-)
-
-# 2. Project configuration
-input_type = project_tab.selectbox(
-    "What is your input type?", ["table", "image"]
-)
-
-user_input = input_type_handler(
-    input_type, task, project_tab, model_tab, data_tab
-)
-
-if (
-    st.button("save")
-    and user_input and "model_cfg" in user_input
-    and "dataset_cfg" in user_input
-):
-    cfg = DictConfig(
-        {
-            "metrics": metrics,
-            "models": user_input["model_cfg"],
-            "datasets": user_input["dataset_cfg"],
-            "task": f"{input_type}-{task}",
-            "accelerator": "cpu",
-            "project": st.session_state.project_info["title"],
-        }
-    )
-
-    # st.success(cfg)
-    exp = ExperimentConfig(**cfg)
-    exp.to_yaml()
-
-# ====== Launching the Training Code ===== #
-# @hydra.main(config_path="config/", config_name="ui")
-# def start(cfg: DictConfig):
-#     # task = setup_clear_ml(cfg)
-#     # if task:
-#     #     st.text_input(f'Link to task in ClearMl ui: {task.get_output_log_web_page()}')
-#     metric_results = run_pipeline(cfg, train=True)  # , ui=True
-#
-#
-# with st.form(key="training"):
-#     if st.form_submit_button("Start training!"):
-#         start()