Commit

small documentation fixes and touchups
thehrh committed Sep 19, 2024
1 parent 8496f00 commit 217df0c
Showing 3 changed files with 84 additions and 71 deletions.
126 changes: 68 additions & 58 deletions pisa/core/container.py
@@ -8,9 +8,7 @@
from __future__ import absolute_import, print_function

from collections.abc import Sequence
from collections import OrderedDict, defaultdict
import copy
from itertools import chain
from collections import defaultdict

import numpy as np

@@ -20,10 +18,9 @@
from pisa.core.translation import histogram, lookup, resample
from pisa.utils.comparisons import ALLCLOSE_KW
from pisa.utils.log import logging
from pisa.utils.profiler import line_profile, profile


class ContainerSet(object):
class ContainerSet():
"""
Class to hold a set of container objects
@@ -51,10 +48,12 @@ def __repr__(self):

@property
def is_map(self):
'''Is current representation a map/grid'''
if len(self.containers):
return self.containers[0].is_map

def add_container(self, container):
'''Append a container whose name mustn't exist yet'''
if container.name in self.names:
raise ValueError('container with name %s already exists'%container.name)
self.containers.append(container)
@@ -97,8 +96,10 @@ def get_shared_keys(self, rep_indep=True):
return ()

return tuple(
set.intersection(*[set(c.all_keys_incl_aux_data if rep_indep else c.keys_incl_aux_data)
for c in self.containers])
set.intersection(*[
set(c.all_keys_incl_aux_data if rep_indep else c.keys_incl_aux_data)
for c in self.containers
])
)

def link_containers(self, key, names):
@@ -118,8 +119,10 @@ def link_containers(self, key, names):

link_names = set(names) & set(self.names)
if len(link_names) < len(names):
logging.warning("Skipping containers %s in linking, as those are not present"%(set(names) - set(self.names)))

logging.warning(
"Skipping containers %s in linking, as those are not present"
% (set(names) - set(self.names))
)
containers = [self.__getitem__(name) for name in link_names]
logging.trace('Linking containers %s into %s'%(link_names, key))
new_container = VirtualContainer(key, containers)
@@ -139,7 +142,7 @@ def __getitem__(self, key):
if len(self.linked_containers) > 0:
linked_names = [c.name for c in self.linked_containers]
if key in linked_names:
return self.linked_containers[linked_names.index(key)]
return self.linked_containers[linked_names.index(key)]
raise KeyError(f"No name `{key}` in container")

def __iter__(self):
@@ -150,7 +153,7 @@ def __iter__(self):
return iter(containers_to_be_iterated)

def get_mapset(self, key, error=None):
"""For a given key, get a PISA MapSet
"""For a given key, get a MapSet
Parameters
----------
@@ -170,16 +173,16 @@ def get_mapset(self, key, error=None):
return MapSet(name=self.name, maps=maps)


class VirtualContainer(object):
class VirtualContainer():
"""
Class providing a virtual container for linked individual containers
It should just behave like a normal container
For reading, it just uses one container as a representative (no checkng at the mment
if the others actually contain the same data)
For reading, it just uses one container as a representative
(no checking at the moment if the others actually contain the same data)
For writting, it creates one object that is added to all containers
For writing, it creates one object that is added to all containers
Parameters
----------
@@ -194,15 +197,18 @@ def __init__(self, name, containers):
# check and set link flag
for container in containers:
if container.linked:
raise ValueError('Cannot link container %s since it is already linked'%container.name)
raise ValueError(
'Cannot link container %s since it is already linked'
% container.name
)
container.linked = True
self.containers = containers

def __repr__(self):
return f'VirtualContainer containing {[c.name for c in self]}'

def unlink(self):
'''Reset flag and copy all accessed keys'''
'''Reset link flag and copy all accessed keys'''
# reset flag
for container in self:
container.linked = False
@@ -219,28 +225,32 @@ def __setitem__(self, key, value):
container[key] = value

def set_aux_data(self, key, val):
'''See `Container.set_aux_data`'''
for container in self:
container.set_aux_data(key, val)

def mark_changed(self, key):
# copy all
'''Copy data under this key from representative container into
all others and then mark all as changed (see `Container.mark_changed`)'''
for container in self.containers[1:]:
container[key] = np.copy(self.containers[0][key])
for container in self:
container.mark_changed(key)

def mark_valid(self, key):
'''See `Container.mark_valid`'''
for container in self:
container.mark_valid(key)
container.mark_valid(key)

@property
def representation(self):
return self.containers[0].representation

@representation.setter
def representation(self, representation):
for container in self:
container.representation = representation

@property
def shape(self):
return self.containers[0].shape
@@ -249,59 +259,59 @@ def shape(self):
def size(self):
return np.product(self.shape)


class Container():

"""
Container to hold data in multiple representations
Parameters:
-----------
name : str
name of container
representation : hashable object, e.g. str or MultiDimBinning
Representation in which to initialize the container
"""

default_translation_mode = "average"
translation_modes = ("average", "sum", None)
array_representations = ("events", "log_events")


def __init__(self, name, representation='events'):

'''
Container to hold data in multiple representations
Parameters:
-----------
name : str
name of container
representation : hashable object, e.g. str or MultiDimBinning
Representation in which to initialize the container
'''

self.name = name
self._representation = None

self.linked = False
# ToDo: simple auxillary data like scalars

# ToDo: simple auxiliary data like scalars
# dict of form [variable]
self._aux_data = {}

# validity bit
# dict of form [variable][representation_hash]
self.validity = defaultdict(dict)

# translation mode
# dict of form [variable]
self.tranlation_modes = {}

# Actual data
# dict of form [representation_hash][variable]
self.data = defaultdict(dict)

# Representation objects
# dict of form [representation_hash]
self._representations = {}

# Precedence of representation (lower number = higher precedence)
# dict of form [representation_hash]
self.precedence = defaultdict(int)

self.representation = representation

def __repr__(self):
return f'Container containing keys {self.all_keys}'

@@ -311,7 +321,8 @@ def representation(self):


def set_aux_data(self, key, val):
'''Add any auxillary data, which will not be translated or tied to a specific representation'''
'''Add any auxiliary data, which will not be translated or
tied to a specific representation'''
if key in self.all_keys:
raise KeyError(f'Key {key} already exsits')

@@ -349,8 +360,7 @@ def size(self):
def num_dims(self):
if self.is_map:
return self.representation.num_dims
else:
return 1
return 1

@property
def representations(self):
@@ -369,6 +379,7 @@ def keys(self):

@property
def keys_incl_aux_data(self):
'''same as keys, but including auxiliary data'''
return list(self.keys) + list(self._aux_data.keys())

@property
@@ -378,7 +389,7 @@ def all_keys(self):

@property
def all_keys_incl_aux_data(self):
'''same as `all_keys`, but including auxilliary data'''
'''same as `all_keys`, but including auxiliary data'''
return self.all_keys + list(self._aux_data.keys())

@property
@@ -481,8 +492,7 @@ def __get_data(self, key):
else:
if key in self._aux_data.keys():
return self._aux_data[key]
else:
raise KeyError(f'Data {key} not present in Container')
raise KeyError(f'Data {key} not present in Container')

valid = self.validity[key][hash(self.representation)]
if not valid:
@@ -809,7 +819,7 @@ def test_container_set():

# still in events rep.
shared_aux_key = 'AmIEvil'
# add auxilliary data to both containers
# add auxiliary data to both containers
container1.set_aux_data(key=shared_aux_key, val=False)
container2.set_aux_data(key=shared_aux_key, val=True)
# get shared keys across all reps. and for the current one
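For context, the aux-data, shared-key, and linking behaviour documented in the hunks above can be exercised roughly as in the `test_container_set` hunk. The sketch below assumes a working `pisa` installation; the `ContainerSet` constructor arguments (`name`, `containers`) are an assumption inferred from the attributes this diff references, not a confirmed signature.

import numpy as np

from pisa.core.container import Container, ContainerSet

# Two containers in the 'events' representation sharing one array key
container1 = Container('nue', representation='events')
container2 = Container('numu', representation='events')
container1['weights'] = np.ones(5)
container2['weights'] = np.ones(7)

# Auxiliary data is never translated and is not tied to a representation
container1.set_aux_data(key='AmIEvil', val=False)
container2.set_aux_data(key='AmIEvil', val=True)

# Assumed constructor signature, inferred from the `name` and `containers`
# attributes used in this diff
containers = ContainerSet(name='demo', containers=[container1, container2])

# Keys (incl. aux data) present in every container, either across all
# representations (rep_indep=True) or only the current one
print(containers.get_shared_keys(rep_indep=True))
print(containers.get_shared_keys(rep_indep=False))

# While linked, iteration yields one VirtualContainer in place of its members
containers.link_containers('nu', ['nue', 'numu'])
print([c.name for c in containers])
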
9 changes: 3 additions & 6 deletions pisa/core/stage.py
@@ -9,7 +9,6 @@
from collections.abc import Mapping
import inspect
from time import time
import numpy as np

from pisa.core.binning import MultiDimBinning
from pisa.core.container import Container, ContainerSet
@@ -154,9 +153,9 @@ def __init__(
if supported_reps is None:
supported_reps = {}
assert isinstance(supported_reps, Mapping)
if not 'calc_mode' in supported_reps:
if 'calc_mode' not in supported_reps:
supported_reps['calc_mode'] = list(Container.array_representations) + [MultiDimBinning]
if not 'apply_mode' in supported_reps:
if 'apply_mode' not in supported_reps:
supported_reps['apply_mode'] = list(Container.array_representations) + [MultiDimBinning]
self.supported_reps = supported_reps

@@ -351,7 +350,6 @@ def _check_exp_keys_in_data(self, error_on_missing=False):
'Service %s.%s is not specifying expected container keys.'
% (self.stage_name, self.service_name)
)
return
exp_k = set(self.expected_container_keys)
got_k = set(self.data.get_shared_keys(rep_indep=True))
missing = exp_k.difference(got_k)
@@ -449,6 +447,7 @@ def error_method(self):

@property
def is_map(self):
"""See ContainerSet.is_map for documentation"""
return self.data.is_map

def setup(self):
@@ -524,5 +523,3 @@ def apply_function(self):
def run(self):
self.compute()
self.apply()
return None
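The `_check_exp_keys_in_data` hunk above reduces to a set comparison between the keys a service declares and the keys actually shared by all containers in the data. A minimal standalone sketch of that check (function and variable names here are illustrative, not part of the pisa API):

def check_expected_keys(expected_keys, shared_keys, error_on_missing=False):
    """Warn (or raise) if declared container keys are missing from the data."""
    exp_k = set(expected_keys)
    got_k = set(shared_keys)
    missing = exp_k.difference(got_k)
    if missing:
        msg = f'Expected container keys not found in data: {sorted(missing)}'
        if error_on_missing:
            raise KeyError(msg)
        print('Warning:', msg)

# e.g. a service expecting reconstructed energy that the data lacks
check_expected_keys(
    expected_keys=['true_energy', 'reco_energy', 'weights'],
    shared_keys=['true_energy', 'weights'],
)
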

