Skip to content

Commit

Permalink
Add the ability to select which keys the output should have
Browse files Browse the repository at this point in the history
This allows users to manually select which blocks are included in a calculation
  • Loading branch information
hurricane642 authored and Luthaf committed Dec 1, 2022
1 parent e152794 commit 5a2c2d7
Show file tree
Hide file tree
Showing 6 changed files with 251 additions and 15 deletions.
1 change: 1 addition & 0 deletions python/rascaline/_c_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ class rascal_calculation_options_t(ctypes.Structure):
("use_native_system", ctypes.c_bool),
("selected_samples", rascal_labels_selection_t),
("selected_properties", rascal_labels_selection_t),
("selected_keys", POINTER(eqs_labels_t)),
]


Expand Down
33 changes: 24 additions & 9 deletions python/rascaline/calculators.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def _options_to_c(
use_native_system,
selected_samples,
selected_properties,
selected_keys,
):
if gradients is None:
gradients = []
Expand Down Expand Up @@ -87,6 +88,13 @@ def _options_to_c(
f"{type(selected_properties)} instead"
)

if selected_keys is None:
# nothing to do, all pointers are already NULL
pass
elif isinstance(selected_keys, Labels):
selected_keys = selected_keys._as_eqs_labels_t()
c_options.selected_keys = ctypes.pointer(selected_keys)
c_options.__keepalive["selected_keys"] = selected_keys
return c_options


Expand Down Expand Up @@ -141,6 +149,7 @@ def compute(
use_native_system: bool = True,
selected_samples: Optional[Union[Labels, TensorMap]] = None,
selected_properties: Optional[Union[Labels, TensorMap]] = None,
selected_keys: Optional[Labels] = None,
) -> TensorMap:
r"""Runs a calculation with this calculator on the given ``systems``.
Expand All @@ -155,10 +164,10 @@ def compute(
faster than having to cross the FFI boundary often when accessing
the neighbor list. Otherwise the Python neighbor list is used.
:param gradients: List of gradients to compute. If this is ``None``
or an empty list ``[]``, no gradients are computed.
Gradients are stored inside the different blocks, and can be
accessed with ``descriptor.block(...).gradient(<parameter>)``, where
:param gradients: List of gradients to compute. If this is ``None`` or
an empty list ``[]``, no gradients are computed. Gradients are
stored inside the different blocks, and can be accessed with
``descriptor.block(...).gradient(<parameter>)``, where
``<parameter>`` is ``"positions"`` or ``"cell"``. The following
gradients are available:
Expand All @@ -171,15 +180,15 @@ def compute(
\frac{\partial \langle q \vert A \rangle}
{\partial \mathbf{h}}
where :math:`\mathbf{h}` is the cell matrix and
:math:`\langle q \vert A \rangle` indicates each of the
components of the representation.
where :math:`\mathbf{h}` is the cell matrix and :math:`\langle q
\vert A \rangle` indicates each of the components of the
representation.
**Note**: When computing the virial, one often needs to evaluate
the gradient of the representation with respect to the strain
:math:`\epsilon`. To recover the typical expression from the cell
gradient one has to multiply the cell gradients with the
cell matrix :math:`\mathbf{h}`
gradient one has to multiply the cell gradients with the cell
matrix :math:`\mathbf{h}`
.. math::
-\frac{\partial \langle q \vert A \rangle}
Expand Down Expand Up @@ -223,6 +232,11 @@ def compute(
properties, then only properties from the default set with the same
values for these variables as one of the entries in
``selected_properties`` will be used.
:param selected_keys: Selection for the keys to include in the output.
If this is ``None``, the default set of keys (as determined by the
calculator) will be used. Note that this default set of keys can
depend on which systems we are running the calculation on.
"""

c_systems = _convert_systems(systems)
Expand All @@ -233,6 +247,7 @@ def compute(
use_native_system=use_native_system,
selected_samples=selected_samples,
selected_properties=selected_properties,
selected_keys=selected_keys,
)
self._lib.rascal_calculator_compute(
self, tensor_map_ptr, c_systems, c_systems._length_, c_options
Expand Down
166 changes: 166 additions & 0 deletions python/tests/calculators.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,172 @@ def test_errors(self):
)


class TestComputeSelectedKeys(unittest.TestCase):
    """Tests for the ``selected_keys`` option of ``calculator.compute``."""

    def setUp(self):
        # every test runs the same dummy calculator on the same test system
        self.system = TestSystem()
        self.calculator = DummyCalculator(cutoff=3.2, delta=2, name="")

    @staticmethod
    def _species_center_keys(values):
        """Build ``Labels`` with a single ``species_center`` variable.

        :param values: nested list with one ``[species]`` entry per key
        """
        return Labels(
            names=["species_center"],
            values=np.array(values, dtype=np.int32),
        )

    @classmethod
    def _predefined_properties_selection(cls):
        """Build the TensorMap-based properties selection shared by the tests
        combining ``selected_properties`` with ``selected_keys``.

        The selection contains one block for each of the keys ``1`` and ``8``.
        """
        keys = cls._species_center_keys([[1], [8]])

        # selection from TensorMap
        selected = [
            Labels(
                names=["index_delta", "x_y_z"],
                values=np.array([[1, 0]], dtype=np.int32),
            ),
            Labels(
                names=["index_delta", "x_y_z"],
                values=np.array([[0, 1]], dtype=np.int32),
            ),
        ]
        return _tensor_map_selection("properties", keys, selected)

    def _compute(self, **selection):
        """Run the calculator on the test system with the given selections."""
        return self.calculator.compute(
            self.system, use_native_system=False, **selection
        )

    def test_selection_existing(self):
        # Manually select the keys
        selected_keys = self._species_center_keys([[1]])
        descriptor = self._compute(selected_keys=selected_keys)

        self.assertEqual(len(descriptor.keys), 1)
        self.assertEqual(tuple(descriptor.keys[0]), (1,))

    def test_select_key_not_in_systems(self):
        # Manually select the keys
        selected_keys = self._species_center_keys([[4]])
        descriptor = self._compute(selected_keys=selected_keys)

        # the block for the requested key exists, but contains no samples
        C_block = descriptor.block(species_center=4)
        self.assertEqual(C_block.values.shape, (0, 2))

    def test_predefined_selection(self):
        selected_keys = self._species_center_keys([[1]])
        selected_properties = self._predefined_properties_selection()

        descriptor = self._compute(
            selected_properties=selected_properties,
            selected_keys=selected_keys,
        )

        self.assertEqual(len(descriptor.keys), 1)
        H_block = descriptor.block(species_center=1)
        self.assertEqual(H_block.values.shape, (2, 1))
        self.assertTrue(np.all(H_block.values[0] == (2,)))
        self.assertTrue(np.all(H_block.values[1] == (3,)))

    def test_name_errors(self):
        # a name which is not a valid label name
        selected_keys = Labels(
            names=["bad name"],
            values=np.array([0, 3, 1], dtype=np.int32).reshape(3, 1),
        )

        with self.assertRaises(RascalError) as cm:
            self._compute(selected_keys=selected_keys)

        self.assertEqual(
            str(cm.exception),
            "invalid parameter: 'bad name' is not a valid label name",
        )

        # a valid label name, but not matching the keys of the calculator
        selected_keys = Labels(
            names=["bad_name"],
            values=np.array([0, 3, 1], dtype=np.int32).reshape(3, 1),
        )

        with self.assertRaises(RascalError) as cm:
            self._compute(selected_keys=selected_keys)

        self.assertEqual(
            str(cm.exception),
            "invalid parameter: names for the keys of the calculator "
            "[species_center] and selected keys [bad_name] do not match",
        )

    def test_key_errors(self):
        # an empty set of selected keys is rejected
        selected_keys = Labels(
            names=["species_center"],
            values=np.empty((0, 1), dtype=np.int32),
        )

        with self.assertRaises(RascalError) as cm:
            self._compute(selected_keys=selected_keys)

        self.assertEqual(
            str(cm.exception),
            "invalid parameter: selected keys can not be empty",
        )

        # in the case of selected_properties/selected_samples and selected_keys
        # the selected keys must be in the keys of the predefined tensor_map
        selected_keys = self._species_center_keys([[4]])
        selected_properties = self._predefined_properties_selection()

        with self.assertRaises(RascalError) as cm:
            self._compute(
                selected_properties=selected_properties,
                selected_keys=selected_keys,
            )

        self.assertEqual(
            str(cm.exception),
            "invalid parameter: expected a key [4] in predefined properties "
            "selection",
        )


class TestSortedDistances(unittest.TestCase):
def test_name(self):
calculator = SortedDistances(
Expand Down
7 changes: 7 additions & 0 deletions rascaline-c-api/include/rascaline.h
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,13 @@ typedef struct rascal_calculation_options_t {
* Selection of properties to compute for the samples
*/
struct rascal_labels_selection_t selected_properties;
/**
* Selection for the keys to include in the output. Set this parameter to
* `NULL` to use the default set of keys, as determined by the calculator.
* Note that this default set of keys can depend on which systems we are
* running the calculation on.
*/
const eqs_labels_t *selected_keys;
} rascal_calculation_options_t;

#ifdef __cplusplus
Expand Down
19 changes: 19 additions & 0 deletions rascaline-c-api/src/calculator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,18 @@ fn convert_labels_selection<'a>(
}
}

/// Convert the raw `selected_keys` pointer coming from the C options into an
/// optional reference to Rust `Labels`.
///
/// If `value` is null, no key selection was requested: `Ok(None)` is returned
/// and `labels` is left untouched. Otherwise, the pointed-to `eqs_labels_t` is
/// converted with `Labels::try_from`, the result is stored inside `labels`
/// (which owns the converted data), and a reference to it is returned.
///
/// # Errors
///
/// Forwards any error produced by `Labels::try_from` on the input labels.
fn key_selection(value: *const eqs_labels_t, labels: &'_ mut Option<Labels>) -> Result<Option<&'_ Labels>, rascaline::Error> {
    if value.is_null() {
        return Ok(None);
    }

    // SAFETY: `value` was checked to be non-null just above. This relies on
    // the caller passing a pointer to a valid `eqs_labels_t` which stays
    // alive for the duration of this call.
    let raw_keys = unsafe { &*value };

    *labels = Some(Labels::try_from(raw_keys)?);
    Ok(labels.as_ref())
}

/// Options that can be set to change how a calculator operates.
#[repr(C)]
pub struct rascal_calculation_options_t {
Expand Down Expand Up @@ -251,6 +263,11 @@ pub struct rascal_calculation_options_t {
selected_samples: rascal_labels_selection_t,
/// Selection of properties to compute for the samples
selected_properties: rascal_labels_selection_t,
/// Selection for the keys to include in the output. Set this parameter to
/// `NULL` to use the default set of keys, as determined by the calculator.
/// Note that this default set of keys can depend on which systems we are
/// running the calculation on.
selected_keys: *const eqs_labels_t,
}

#[allow(clippy::doc_markdown)]
Expand Down Expand Up @@ -300,12 +317,14 @@ pub unsafe extern fn rascal_calculator_compute(

let mut selected_samples = None;
let mut selected_properties = None;
let mut selected_keys = None;

let rust_options = CalculationOptions {
gradients: &gradients,
use_native_system: options.use_native_system,
selected_samples: convert_labels_selection(&options.selected_samples, &mut selected_samples)?,
selected_properties: convert_labels_selection(&options.selected_properties, &mut selected_properties)?,
selected_keys: key_selection(options.selected_keys, &mut selected_keys)?,
};

let tensor = (*calculator).compute(&mut systems, rust_options)?;
Expand Down
40 changes: 34 additions & 6 deletions rascaline/src/calculator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,15 @@ impl<'a> LabelsSelection<'a> {
tensor.keys().names().join(", ")
)));
}
for key in keys.iter() {
if !tensor.keys().contains(key){
return Err(Error::InvalidParameter(format!(
"expected a key [{}] in predefined {} selection",
key.iter().map(|v| v.to_string()).collect::<Vec<_>>().join(", "),
label_kind,
)));
}
}
let default_names = get_default_names();

let mut results = Vec::new();
Expand Down Expand Up @@ -185,6 +194,11 @@ pub struct CalculationOptions<'a> {
pub selected_samples: LabelsSelection<'a>,
/// Selection of properties to compute for the samples
pub selected_properties: LabelsSelection<'a>,
/// Selection for the keys to include in the output. If this is `None`, the
/// default set of keys (as determined by the calculator) will be used. Note
/// that this default set of keys can depend on which systems we are running
/// the calculation on.
pub selected_keys: Option<&'a Labels>,
}

impl<'a> Default for CalculationOptions<'a> {
Expand All @@ -194,6 +208,7 @@ impl<'a> Default for CalculationOptions<'a> {
use_native_system: false,
selected_samples: LabelsSelection::All,
selected_properties: LabelsSelection::All,
selected_keys: None,
}
}
}
Expand Down Expand Up @@ -245,15 +260,28 @@ impl Calculator {
&self.parameters
}

/// Get the set of keys this calculator would produce for the given systems.
///
/// This forwards `systems` to the `keys` function of the underlying
/// calculator implementation, and returns its result unchanged.
///
/// # Errors
///
/// Returns the error produced by the implementation's `keys` function, if
/// any.
pub fn default_keys(&self, systems: &mut [Box<dyn System>]) -> Result<Labels, Error> {
self.implementation.keys(systems)
}

#[time_graph::instrument(name="Calculator::prepare")]
fn prepare(&mut self, systems: &mut [Box<dyn System>], options: CalculationOptions,) -> Result<TensorMap, Error> {
// TODO: allow selecting a subset of keys?
let keys = self.implementation.keys(systems)?;

let default_keys = self.implementation.keys(systems)?;
let keys = match options.selected_keys {
Some(keys) if keys.is_empty() => {
return Err(Error::InvalidParameter("selected keys can not be empty".into()));
}
Some(keys) => {
if default_keys.names() == keys.names() {
keys.clone()
} else {
return Err(Error::InvalidParameter(format!(
"names for the keys of the calculator [{}] and selected keys [{}] do not match",
default_keys.names().join(", "),
keys.names().join(", "))
));
}
}
None => default_keys,
};

let samples = options.selected_samples.select(
"samples",
Expand Down

0 comments on commit 5a2c2d7

Please sign in to comment.