diff --git a/sdb/command.py b/sdb/command.py index 9d5e09df..e9a97736 100644 --- a/sdb/command.py +++ b/sdb/command.py @@ -24,7 +24,7 @@ import argparse import inspect import textwrap -from typing import Any, Callable, Dict, Iterable, List, Optional, Type, TypeVar +from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Type, TypeVar import drgn @@ -244,6 +244,31 @@ def _call(self, """ raise NotImplementedError() + def _valid_input_types(self) -> Set[str]: + """ + Returns a set of strings which are the canonicalized names of valid input types + for this command + """ + assert self.input_type is not None + return {type_canonicalize_name(self.input_type)} + + def __input_type_check( + self, objs: Iterable[drgn.Object]) -> Iterable[drgn.Object]: + valid_input_types = self._valid_input_types() + print("valid types: ") + print(valid_input_types) + prev_type = None + for obj in objs: + cur_type = type_canonical_name(obj.type_) + if cur_type not in valid_input_types or (prev_type and + cur_type != prev_type): + raise CommandError( + self.name, + f'expected input of type {self.input_type}, but received ' + f'type {obj.type_}') + prev_type = cur_type + yield obj + def __invalid_memory_objects_check(self, objs: Iterable[drgn.Object], fatal: bool) -> Iterable[drgn.Object]: """ @@ -281,7 +306,11 @@ def call(self, objs: Iterable[drgn.Object]) -> Iterable[drgn.Object]: # the command is running. # try: - result = self._call(objs) + if self.input_type and objs: + result = self._call(self.__input_type_check(objs)) + else: + result = self._call(objs) + if result is not None: # # The whole point of the SingleInputCommands are that @@ -289,7 +318,7 @@ def call(self, objs: Iterable[drgn.Object]) -> Iterable[drgn.Object]: # a bad dereference. That's why we check here whether # the command that we are running is a subclass of # SingleInputCommand and we set the `fatal` flag - # accordinly. + # accordingly. # yield from self.__invalid_memory_objects_check( result, not issubclass(self.__class__, SingleInputCommand)) @@ -634,22 +663,6 @@ def pretty_print(self, objs: Iterable[drgn.Object]) -> None: # pylint: disable=missing-docstring raise NotImplementedError - def check_input_type(self, - objs: Iterable[drgn.Object]) -> Iterable[drgn.Object]: - """ - This function acts as a generator, checking that each passed object - matches the input type for the command - """ - assert self.input_type is not None - type_name = type_canonicalize_name(self.input_type) - for obj in objs: - if type_canonical_name(obj.type_) != type_name: - raise CommandError( - self.name, - f'expected input of type {self.input_type}, but received ' - f'type {obj.type_}') - yield obj - def _call( # type: ignore[return] self, objs: Iterable[drgn.Object]) -> Optional[Iterable[drgn.Object]]: @@ -658,7 +671,7 @@ def _call( # type: ignore[return] verifying the types as we go. """ assert self.input_type is not None - self.pretty_print(self.check_input_type(objs)) + self.pretty_print(objs) class Locator(Command): @@ -673,6 +686,25 @@ class Locator(Command): output_type: Optional[str] = None + def _valid_input_types(self) -> Set[str]: + """ + Some Locators support multiple input types. Check for InputHandler + implementations to expand the set of valid input types. + """ + valid_types = [type_canonicalize_name(self.input_type)] + + for (_, method) in inspect.getmembers(self, inspect.ismethod): + if hasattr(method, "input_typename_handled"): + valid_types.append( + type_canonicalize_name(method.input_typename_handled)) + + valid_types += [ + type_canonicalize_name(type_) + for type_, class_ in Walker.allWalkers.items() + ] + + return set(valid_types) + def no_input(self) -> Iterable[drgn.Object]: # pylint: disable=missing-docstring raise CommandError(self.name, 'command requires an input') diff --git a/sdb/commands/pretty_print.py b/sdb/commands/pretty_print.py index bcf23c0f..26cedf91 100644 --- a/sdb/commands/pretty_print.py +++ b/sdb/commands/pretty_print.py @@ -23,7 +23,6 @@ class PrettyPrint(sdb.Command): - names = ["pretty_print", "pp"] def _call(self, objs: Iterable[drgn.Object]) -> None: diff --git a/tests/integration/data/regression_output/core/addr spa_namespace_avl | spa b/tests/integration/data/regression_output/core/addr spa_namespace_avl | spa new file mode 100644 index 00000000..921f0052 --- /dev/null +++ b/tests/integration/data/regression_output/core/addr spa_namespace_avl | spa @@ -0,0 +1,5 @@ +ADDR NAME +------------------------------------------------------------ +0xffffa0894e720000 data +0xffffa089413b8000 meta-domain +0xffffa08955c44000 rpool diff --git a/tests/integration/data/regression_output/linux/addr spa_namespace_avl | cpu_counter_sum b/tests/integration/data/regression_output/linux/addr spa_namespace_avl | cpu_counter_sum index 8faabf08..0de4ee78 100644 --- a/tests/integration/data/regression_output/linux/addr spa_namespace_avl | cpu_counter_sum +++ b/tests/integration/data/regression_output/linux/addr spa_namespace_avl | cpu_counter_sum @@ -1 +1 @@ -sdb: cpu_counter_sum: input is not a percpu_counter +sdb: cpu_counter_sum: expected input of type struct percpu_counter *, but received type avl_tree_t * diff --git a/tests/integration/test_core_generic.py b/tests/integration/test_core_generic.py index 45018c73..4429a63f 100644 --- a/tests/integration/test_core_generic.py +++ b/tests/integration/test_core_generic.py @@ -74,6 +74,9 @@ # locator that receives no input from a filter 'thread | filter obj.comm == \"bogus\" | thread', + # input type checking - locators can use walkers + 'addr spa_namespace_avl | spa', + # member - generic "member no_object", "addr spa_namespace_avl | member avl_root->avl_child[0]->avl_child",