Implement validation #13

Open · wants to merge 15 commits into master
2 changes: 1 addition & 1 deletion src/Core/Database/BaseDatabaseConfig.py
@@ -31,7 +31,7 @@ def __init__(self,
if not isdir(existing_dir):
raise ValueError(f"[{self.name}] The given 'existing_dir'={existing_dir} does not exist.")
if len(existing_dir.split(sep)) > 1 and existing_dir.split(sep)[-1] == 'dataset':
existing_dir = join(*existing_dir.split(sep)[:-1])
existing_dir = sep.join(existing_dir.split(sep)[:-1])

# Check storage variables
if mode is not None:
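The one-line fix above matters for absolute paths: splitting on sep leaves an empty first element, which os.path.join silently drops, while sep.join restores the leading separator. A minimal sketch (the example path is hypothetical):

from os.path import join
from os import sep

# Hypothetical existing_dir ending in 'dataset', as checked in BaseDatabaseConfig
existing_dir = '/data/sessions/demo/dataset'
parts = existing_dir.split(sep)          # ['', 'data', 'sessions', 'demo', 'dataset']

print(join(*parts[:-1]))                 # 'data/sessions/demo'  -> no longer absolute
print(sep.join(parts[:-1]))              # '/data/sessions/demo' -> leading separator kept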
6 changes: 6 additions & 0 deletions src/Core/Manager/DataManager.py
@@ -168,6 +168,12 @@ def get_prediction(self,
self.pipeline.network_manager.compute_online_prediction(instance_id=instance_id,
normalization=self.normalization)

def set_eval(self):
self.database_manager.set_eval()

def set_train(self):
self.database_manager.set_train()

def close(self) -> None:
"""
Launch the closing procedure of the DataManager.
19 changes: 14 additions & 5 deletions src/Core/Manager/DatabaseManager.py
@@ -92,7 +92,8 @@ def __init__(self,
self.create_partition()
# Complete a Database in a new session --> copy and load the existing directory
else:
copy_dir(src_dir=database_config.existing_dir, dest_dir=session, sub_folders='dataset')
copy_dir(src_dir=database_config.existing_dir, dest_dir=session,
sub_folders='dataset')
self.load_directory(rename_partitions=True)
# Complete a Database in the same session --> load the directory
else:
@@ -111,7 +112,8 @@ def __init__(self,
self.create_partition()
# Complete a Database in a new session --> copy and load the existing directory
else:
copy_dir(src_dir=database_config.existing_dir, dest_dir=session, sub_folders='dataset')
copy_dir(src_dir=database_config.existing_dir, dest_dir=session,
sub_folders='dataset')
self.load_directory()
# Complete a Database in the same directory --> load the directory
else:
@@ -285,8 +287,8 @@ def change_mode(self,

:param mode: Name of the Database mode.
"""

pass
self.mode = mode
self.index_samples()

##########################################################################################
##########################################################################################
@@ -416,6 +418,7 @@ def index_samples(self) -> None:
Create a new indexing list of samples. Samples are identified by [partition_id, line_id].
"""

self.sample_indices = empty((0, 2), dtype=int)
# Create the indices for each sample such as [partition_id, line_id]
for i, nb_sample in enumerate(self.json_content['nb_samples'][self.mode]):
partition_indices = empty((nb_sample, 2), dtype=int)
@@ -507,7 +510,7 @@ def compute_normalization(self) -> Dict[str, List[float]]:
for field in self.json_content['data_shape']:
table_name, field_name = field.split('.')
fields += [field_name] if table_name == 'Training' else []
normalization = {field: [0., 0.] for field in fields}
normalization = {field: [0., 1.] for field in fields}

# 2. Compute the mean of samples for each field
means = {field: [] for field in fields}
@@ -607,6 +610,12 @@ def load_partitions_fields(partition: Database,
##########################################################################################
##########################################################################################

def set_eval(self):
self.change_mode('validation')

def set_train(self):
self.change_mode('training')

def close(self):
"""
Launch the closing procedure of the DatabaseManager.
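Two of the changes above work together: change_mode() now actually switches self.mode and rebuilds the sample indices, and compute_normalization() initializes each field to [0., 1.] instead of [0., 0.], so a field whose statistics are never computed is passed through unchanged rather than divided by zero. A minimal sketch, assuming normalization is applied as (data - mean) / std (the field name is illustrative):

from numpy import asarray

def normalize(data, field_norm):
    # field_norm = [mean, std]; with the default [0., 1.] the data is returned unchanged
    mean, std = field_norm
    return (asarray(data) - mean) / std

normalization = {'ground_truth': [0., 1.]}                     # default before statistics exist
print(normalize([1.0, 2.0], normalization['ground_truth']))   # [1. 2.]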
5 changes: 5 additions & 0 deletions src/Core/Manager/NetworkManager.py
@@ -186,6 +186,9 @@ def compute_prediction_and_loss(self,
lines_id=data_lines)
# Apply normalization and convert to tensor
for field in batch.keys():
# batch can contain dicts if fields refer to joined tables
if isinstance(batch[field], dict):
batch[field] = batch[field][field]
batch[field] = array(batch[field])
if field in normalization:
batch[field] = self.normalize_data(data=batch[field],
@@ -228,6 +231,8 @@ def compute_online_prediction(self,

# Apply normalization and convert to tensor
for field in sample.keys():
if isinstance(sample[field], dict):
sample[field] = sample[field][field]
sample[field] = array([sample[field]])
if field in normalization.keys():
sample[field] = self.normalize_data(data=sample[field],
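The added isinstance checks unwrap fields that the database handler returns as single-entry dicts, for example when the field comes from a joined table. A small sketch of the pattern with a hypothetical batch:

from numpy import array

batch = {'input': [[0.1, 0.2]],
         'ground_truth': {'ground_truth': [[1.0, 0.0]]}}   # dict-wrapped field

for field in batch.keys():
    if isinstance(batch[field], dict):
        batch[field] = batch[field][field]   # keep only the field's own values
    batch[field] = array(batch[field])

print(batch['ground_truth'].shape)   # (1, 2)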
12 changes: 10 additions & 2 deletions src/Core/Manager/StatsManager.py
@@ -37,8 +37,16 @@ def __init__(self,

# Open Tensorboard
tb = program.TensorBoard()
tb.configure(argv=[None, '--logdir', self.log_dir])
url = tb.launch()
port = 6006
tb.configure(argv=[None, '--logdir', self.log_dir, '--port', str(port)])
# Try successive ports until TensorBoard launches or the range 6006-6999 is exhausted
while port < 7000:
try:
url = tb.launch()
break
except Exception:
port += 1
tb.configure(argv=[None, '--logdir', self.log_dir, '--port', str(port)])
w_open(url)

# Values
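The loop retries tb.launch() on ports 6006-6999 until one succeeds. An alternative sketch (not part of this PR) would ask the OS which port is free before configuring TensorBoard:

import socket

def find_free_port(start: int = 6006, end: int = 7000) -> int:
    # Return the first port in [start, end) that can be bound locally
    for port in range(start, end):
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            try:
                s.bind(('localhost', port))
                return port
            except OSError:
                continue
    raise RuntimeError('No free port available for TensorBoard')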
6 changes: 4 additions & 2 deletions src/Core/Network/BaseNetwork.py
@@ -1,6 +1,8 @@
from typing import Any, Dict
from numpy import ndarray
from collections import namedtuple
from torch import as_tensor
import torch


class BaseNetwork:
@@ -115,7 +117,7 @@ def numpy_to_tensor(self,
:return: Converted tensor.
"""

return data.astype(self.config.data_type)
return as_tensor(data, dtype=getattr(torch, self.config.data_type)).requires_grad_(grad)

def tensor_to_numpy(self,
data: Any) -> ndarray:
@@ -126,7 +128,7 @@ def tensor_to_numpy(self,
:return: Converted array.
"""

return data.astype(self.config.data_type)
return data.detach().cpu().numpy()

def __str__(self) -> str:

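Both converters now go through torch instead of a plain dtype cast. A round-trip sketch, assuming config.data_type holds a torch dtype name such as 'float32' and that numpy_to_tensor receives a grad flag (as the requires_grad_ call suggests):

import torch
from numpy import array, ndarray
from torch import as_tensor

data_type = 'float32'          # assumed value of self.config.data_type

def numpy_to_tensor(data: ndarray, grad: bool = True) -> torch.Tensor:
    return as_tensor(data, dtype=getattr(torch, data_type)).requires_grad_(grad)

def tensor_to_numpy(data: torch.Tensor) -> ndarray:
    return data.detach().cpu().numpy()

x = array([[1.0, 2.0]])
t = numpy_to_tensor(x)         # tensor([[1., 2.]], requires_grad=True)
print(tensor_to_numpy(t))      # [[1. 2.]]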
12 changes: 10 additions & 2 deletions src/Core/Network/BaseNetworkConfig.py
@@ -23,7 +23,11 @@ def __init__(self,
lr: Optional[float] = None,
require_training_stuff: bool = True,
loss: Optional[Any] = None,
optimizer: Optional[Any] = None):
loss_parameters: Optional[dict] = None,
optimizer: Optional[Any] = None,
optimizer_parameters: Optional[dict] = None,
scheduler_class: Optional[Any] = None,
scheduler_parameters: Optional[dict] = None):
"""
BaseNetworkConfig is a configuration class to parameterize and create BaseNetwork, BaseOptimization and
BaseTransformation for the NetworkManager.
@@ -88,7 +92,11 @@ def __init__(self,
configuration_name='optimization_config',
loss=loss,
lr=lr,
optimizer=optimizer)
optimizer=optimizer,
optimizer_parameters=optimizer_parameters,
loss_parameters=loss_parameters,
scheduler_class=scheduler_class,
scheduler_parameters=scheduler_parameters)
self.training_stuff: bool = (loss is not None) and (optimizer is not None) or (not require_training_stuff)

# NetworkManager parameterization
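A hypothetical instantiation showing how the new keyword arguments might be forwarded (the torch classes and values below are illustrative, not taken from the PR; the import path of BaseNetworkConfig is assumed):

from torch.nn import MSELoss
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR
from Core.Network.BaseNetworkConfig import BaseNetworkConfig   # assumed import path

network_config = BaseNetworkConfig(lr=1e-3,
                                   loss=MSELoss,
                                   loss_parameters={'reduction': 'mean'},
                                   optimizer=Adam,
                                   optimizer_parameters={'weight_decay': 1e-5},
                                   scheduler_class=StepLR,
                                   scheduler_parameters={'step_size': 10, 'gamma': 0.5})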
41 changes: 41 additions & 0 deletions src/Core/Pipelines/BaseTraining.py
@@ -143,6 +143,7 @@ def epoch_begin(self) -> None:
Called once at the beginning of each epoch.
"""

self.set_train()
self.batch_id = 0

def batch_condition(self) -> bool:
@@ -174,6 +175,23 @@ def optimize(self) -> None:
normalization=self.data_manager.normalization,
optimize=True)

def execute_validation(self):
"""
Run a prediction step on each batch of the validation set.
"""

self.set_eval()
id_batch = 0
while id_batch < self.nb_validation_batches:
self.validate()
id_batch += 1

def validate(self):
"""
Pull data from the DataManager and run a prediction step.
"""
self.data_manager.get_data(epoch=0)
self.loss_dict = self.network_manager.compute_prediction_and_loss(
data_lines=self.data_manager.data_lines,
normalization=self.data_manager.normalization,
optimize=False)

def batch_count(self) -> None:
"""
Increment the batch counter.
@@ -210,6 +228,9 @@ def epoch_end(self) -> None:
if self.stats_manager is not None:
self.stats_manager.add_train_epoch_loss(self.loss_dict['loss'], self.epoch_id)
self.network_manager.save_network()
if self.do_validation:
self.execute_validation()
self.stats_manager.add_test_loss(self.loss_dict['loss'], self.epoch_id)

def train_end(self) -> None:
"""
@@ -221,6 +242,26 @@ def train_end(self) -> None:
if self.stats_manager is not None:
self.stats_manager.close()

def set_eval(self):
# Set DBManager mode, build indices
self.data_manager.set_eval()
# Set network to eval mode
self.network_manager.set_eval()
# Connect the handler to the validation partition
self.data_manager.connect_handler(self.network_manager.get_database_handler())
# Create the links
self.network_manager.link_clients(self.data_manager.nb_environment)

def set_train(self):
# Set DBManager mode, build indices
self.data_manager.set_train()
# Set network to train mode
self.network_manager.set_train()
# Connect the handler to the training partition
self.data_manager.connect_handler(self.network_manager.get_database_handler())
# Create the links
self.network_manager.link_clients(self.data_manager.nb_environment)

def save_info_file(self) -> None:
"""
Save a .txt file that provides a template for user notes and the description of all the components.
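Taken together, one training epoch now runs roughly as follows (a simplified control-flow sketch; the method names come from the diff, the surrounding loop is paraphrased):

def run_epoch(trainer):
    trainer.epoch_begin()            # calls set_train(): training partition + network in train mode
    while trainer.batch_condition():
        trainer.optimize()           # forward/backward on one training batch
        trainer.batch_count()
    trainer.epoch_end()              # logs the train loss, saves the network and, if
                                     # do_validation is set, runs execute_validation()
                                     # (set_eval() + one validate() per validation batch)
                                     # before logging the validation loss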
112 changes: 112 additions & 0 deletions src/Core/Utils/yamlUtils.py
@@ -0,0 +1,112 @@
import copy
import yaml
import importlib


def BaseYamlExporter(filename: str = None, var_dict: dict = None):
"""
Export variables to a yaml file, excluding classes, modules and functions. Additionally, variables whose name is
listed under the 'excluded' key will not be exported.

:param str filename: Path to the file in which var_dict will be saved after filtering.
:param dict var_dict: Dictionary containing the key:value pairs to save. Each key is a variable name and each value its value.
"""
export_dict = copy.deepcopy(var_dict)
recursive_convert_type_to_type_str(export_dict)
if filename is not None:
# Export to the yaml file
with open(filename, 'w') as f:
print(f"[BaseYamlExporter] Saving conf to {filename}.")
yaml.dump(export_dict, f)
return export_dict


def BaseYamlLoader(filename: str):
"""Loads a yaml file and converts the str representation of types to types using convert_type_str_to_type."""
with open(filename, 'r') as f:
loaded_dict = yaml.load(f, yaml.Loader)
recursive_convert_type_str_to_type(loaded_dict)
return loaded_dict


def recursive_convert_type_to_type_str(var_container):
"""Recursively converts types in a nested dict or iterable to their str representation using
convert_type_to_type_str."""
var_container_type = type(var_container)
if isinstance(var_container, dict):
keys = list(var_container.keys())
if 'excluded' in keys:  # Special keyword that specifies which keys should be removed
for exclude_key in var_container['excluded']:
if exclude_key in var_container:
var_container.pop(exclude_key)  # Remove the keys listed in 'excluded'
keys = list(var_container.keys())  # Update the keys
elif isinstance(var_container, (tuple, list, set)):  # Not a dict, but iterable
keys = range(len(var_container))
var_container = list(var_container)  # Allow changing elements in var_container
else:
raise ValueError("BaseYamlExporter: encountered an object to convert which is not a dict, tuple or list.")
for k in keys:
v = var_container[k]
if isinstance(v, type): # Object is just a type, not an instance
new_val = convert_type_to_type_str(v)
new_val = dict(type=new_val)
elif hasattr(v, '__iter__') and not isinstance(v, str): # Object contains other objects
new_val = recursive_convert_type_to_type_str(v)
else: # Object is assumed to not contain other objects
new_val = v
var_container[k] = new_val
if var_container_type in (tuple, set):
var_container = var_container_type(var_container) #Convert back to original type
return var_container


def recursive_convert_type_str_to_type(var_container):
"""Recursively converts str representation of types in a nested dict or iterable using convert_type_str_to_type."""
var_container_type = type(var_container)
if isinstance(var_container, dict):
keys = list(var_container.keys())
elif isinstance(var_container, (tuple, list, set)): # Is not a dict but is iterable.
keys = range(len(var_container))
var_container = list(var_container) #Allows to change elements in var_container
else:
raise ValueError(f"recursive_convert_type_str_to_type: "
f"encountered an object to convert which is not a dict, tuple or list.")
for k in keys:
v = var_container[k]
# Detection of a type object that was converted to str
if isinstance(v,dict) and len(v) == 1 and 'type' in v and isinstance(v['type'], str):
new_val = convert_type_str_to_type(v['type'])
elif hasattr(v, '__iter__') and not isinstance(v, str): # Object contains other objects
new_val = recursive_convert_type_str_to_type(v)
else:
new_val = v
var_container[k] = new_val
if var_container_type in (tuple, set):
var_container = var_container_type(var_container) #Convert back to original type
return var_container


def convert_type_str_to_type(name: str):
"""Converts a str representation of a type to a type."""
module = importlib.import_module('.'.join(name.split('.')[:-1]))
object_name_in_module = name.split('.')[-1]
return getattr(module, object_name_in_module)


def convert_type_to_type_str(type_to_convert: type):
"""Converts a type to its str representation."""
repr_str = repr(type_to_convert)
if "<class " in repr_str: # Class object, not an instance
return repr_str.split("<class '")[1].split("'>")[0]
else:
raise ValueError(f"BaseYamlExporter: {repr_str} could not be converted to an object name.")


def unpack_pipeline_config(pipeline_config):
"""Initializes the network, environment and dataset config objects from the pipeline config."""
nested_configs_keys = ['network_config', 'environment_config', 'database_config']
unpacked = {}
for key in pipeline_config:
if key in nested_configs_keys and pipeline_config[key] is not None:
unpacked.update({key: pipeline_config[key][0](**pipeline_config[key][1])})
else:
unpacked.update({key: pipeline_config[key]})
return unpacked
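A short usage sketch of the exporter/loader pair (the configuration dictionary is hypothetical, the import path is assumed from the file location, and the exact module path stored for a type depends on the installed library version):

from torch.optim import Adam
from Core.Utils.yamlUtils import BaseYamlExporter, BaseYamlLoader   # assumed import path

config = {'lr': 1e-3,
          'optimizer': Adam,                 # exported as {'type': 'torch.optim.adam.Adam'}
          'debug_token': 'abc',
          'excluded': ['debug_token']}       # keys listed here are dropped before export

BaseYamlExporter('config.yml', config)
restored = BaseYamlLoader('config.yml')
assert restored['optimizer'] is Adam         # the str representation is turned back into the class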