Skip to content

Commit

Permalink
Add input type-checking to Commands
Browse files Browse the repository at this point in the history
  • Loading branch information
sara hartse committed Mar 10, 2020
1 parent 0ea73b6 commit f304dc8
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 22 deletions.
71 changes: 51 additions & 20 deletions sdb/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import argparse
import inspect
import textwrap
from typing import Any, Callable, Dict, Iterable, List, Optional, Type, TypeVar
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Type, TypeVar

import drgn

Expand Down Expand Up @@ -244,6 +244,29 @@ def _call(self,
"""
raise NotImplementedError()

def _valid_input_types(self) -> Set[str]:
"""
Returns a set of strings which are the canonicalized names of valid input types
for this command
"""
assert self.input_type is not None
return {type_canonicalize_name(self.input_type)}

def __input_type_check(
self, objs: Iterable[drgn.Object]) -> Iterable[drgn.Object]:
valid_input_types = self._valid_input_types()
prev_type = None
for obj in objs:
cur_type = type_canonical_name(obj.type_)
if cur_type not in valid_input_types or (prev_type and
cur_type != prev_type):
raise CommandError(
self.name,
f'expected input of type {self.input_type}, but received '
f'type {obj.type_}')
prev_type = cur_type
yield obj

def __invalid_memory_objects_check(self, objs: Iterable[drgn.Object],
fatal: bool) -> Iterable[drgn.Object]:
"""
Expand Down Expand Up @@ -281,15 +304,19 @@ def call(self, objs: Iterable[drgn.Object]) -> Iterable[drgn.Object]:
# the command is running.
#
try:
result = self._call(objs)
if self.input_type and objs:
result = self._call(self.__input_type_check(objs))
else:
result = self._call(objs)

if result is not None:
#
# The whole point of the SingleInputCommands are that
# they don't stop executing in the first encounter of
# a bad dereference. That's why we check here whether
# the command that we are running is a subclass of
# SingleInputCommand and we set the `fatal` flag
# accordinly.
# accordingly.
#
yield from self.__invalid_memory_objects_check(
result, not issubclass(self.__class__, SingleInputCommand))
Expand Down Expand Up @@ -634,22 +661,6 @@ def pretty_print(self, objs: Iterable[drgn.Object]) -> None:
# pylint: disable=missing-docstring
raise NotImplementedError

def check_input_type(self,
objs: Iterable[drgn.Object]) -> Iterable[drgn.Object]:
"""
This function acts as a generator, checking that each passed object
matches the input type for the command
"""
assert self.input_type is not None
type_name = type_canonicalize_name(self.input_type)
for obj in objs:
if type_canonical_name(obj.type_) != type_name:
raise CommandError(
self.name,
f'expected input of type {self.input_type}, but received '
f'type {obj.type_}')
yield obj

def _call( # type: ignore[return]
self,
objs: Iterable[drgn.Object]) -> Optional[Iterable[drgn.Object]]:
Expand All @@ -658,7 +669,7 @@ def _call( # type: ignore[return]
verifying the types as we go.
"""
assert self.input_type is not None
self.pretty_print(self.check_input_type(objs))
self.pretty_print(objs)


class Locator(Command):
Expand All @@ -673,6 +684,26 @@ class Locator(Command):

output_type: Optional[str] = None

def _valid_input_types(self) -> Set[str]:
"""
Some Locators support multiple input types. Check for InputHandler
implementations to expand the set of valid input types.
"""
assert self.input_type is not None
valid_types = [type_canonicalize_name(self.input_type)]

for (_, method) in inspect.getmembers(self, inspect.ismethod):
if hasattr(method, "input_typename_handled"):
valid_types.append(
type_canonicalize_name(method.input_typename_handled))

valid_types += [
type_canonicalize_name(type_)
for type_, class_ in Walker.allWalkers.items()
]

return set(valid_types)

def no_input(self) -> Iterable[drgn.Object]:
# pylint: disable=missing-docstring
raise CommandError(self.name, 'command requires an input')
Expand Down
1 change: 0 additions & 1 deletion sdb/commands/pretty_print.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@


class PrettyPrint(sdb.Command):

names = ["pretty_print", "pp"]

def _call(self, objs: Iterable[drgn.Object]) -> None:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
ADDR NAME
------------------------------------------------------------
0xffffa0894e720000 data
0xffffa089413b8000 meta-domain
0xffffa08955c44000 rpool
Original file line number Diff line number Diff line change
@@ -1 +1 @@
sdb: cpu_counter_sum: input is not a percpu_counter
sdb: cpu_counter_sum: expected input of type struct percpu_counter *, but received type avl_tree_t *
3 changes: 3 additions & 0 deletions tests/integration/test_core_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,9 @@
# locator that receives no input from a filter
'thread | filter obj.comm == \"bogus\" | thread',

# input type checking - locators can use walkers
'addr spa_namespace_avl | spa',

# member - generic
"member no_object",
"addr spa_namespace_avl | member avl_root->avl_child[0]->avl_child",
Expand Down

0 comments on commit f304dc8

Please sign in to comment.