Skip to content

Commit

Permalink
Fix #499 pkcli gets commands from class Commands
Browse files Browse the repository at this point in the history
- Instantiate the class first
- Ignore superclass methods
  • Loading branch information
robnagler committed Aug 23, 2024
1 parent 49d5c46 commit f1433fa
Showing 1 changed file with 36 additions and 26 deletions.
62 changes: 36 additions & 26 deletions pykern/pkcli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,13 @@
import argparse
import importlib
import inspect
import itertools
import os
import os.path
import pkgutil
import re
import sys
import types

# Avoid pykern imports so avoid dependency issues for pkconfig
from pykern import pkconfig
Expand Down Expand Up @@ -139,6 +141,8 @@ def main(root_pkg, argv=None):
cli = _module(root_pkg, module_name)
if not cli:
return 1
if c := getattr(cli, "Commands", None):
cli = c()
prog = prog + " " + module_name
parser = CustomParser(prog)
cmds = _commands(cli)
Expand All @@ -160,8 +164,6 @@ def main(root_pkg, argv=None):
parser.error("too few arguments")
if argv[0][0] != "-":
argv[0] = _module_to_cmd(argv[0])
from pykern.pkdebug import pkdp

try:
res = argh.dispatch(parser, argv=argv)
except CommandError as e:
Expand All @@ -187,20 +189,42 @@ def _argh_name_mapping_policy():


def _commands(cli):
"""Extracts all public functions from `cli`
"""Extracts all public functions or methods from `cli`
Args:
cli (module): where commands are executed from
cli (object): where commands are executed from
Returns:
list of function: public functions sorted alphabetically
list: cmomands sorted alphabetically
"""
res = []
for n, t in inspect.getmembers(cli):
if _is_command(t, cli):
res.append(t)
sorted(res, key=lambda f: f.__name__.lower())
return res

def _exclude():
return itertools.chain(*(dir(b) for b in cli.__class__.__bases__))

def _functions():
return _iter(
lambda t: inspect.isfunction(t)
and hasattr(t, "__module__")
and t.__module__ == cli.__name__
)

def _iter(predicate):
for n, t in inspect.getmembers(cli, predicate=predicate):
if not n.startswith("_"):
yield (t)

def _methods():
x = frozenset(_exclude())
return _iter(
lambda t: inspect.ismethod(t)
and t.__name__ not in x
and t.__name__ in dir(cli)
)

return sorted(
_functions() if isinstance(cli, types.ModuleType) else _methods(),
key=lambda f: f.__name__.lower(),
)


def _default_command(cmds, argv):
Expand Down Expand Up @@ -270,24 +294,10 @@ def _imp(path_list):
return _imp(path + [name])


def _is_command(obj, cli):
"""Is this a valid command function?
Args:
obj (object): candidate
cli (module): module to which function should belong
Returns:
bool: True if obj is a valid command
"""
if not inspect.isfunction(obj) or obj.__name__.startswith("_"):
return False
return hasattr(obj, "__module__") and obj.__module__ == cli.__name__


def _is_help(argv):
"""Does the user want help?
Args:
argv (list): list of args
Expand Down

0 comments on commit f1433fa

Please sign in to comment.