Skip to content

Commit

Permalink
Add DEFINE_doc to support module level document
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 583230711
  • Loading branch information
Abseil Team authored and copybara-github committed Nov 17, 2023
1 parent 6929bf0 commit 3d4c55d
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 6 deletions.
2 changes: 2 additions & 0 deletions absl/flags/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
'DEFINE_multi_enum',
'DEFINE_multi_enum_class',
'DEFINE_alias',
'DEFINE_doc',
# Flag validators.
'register_validator',
'validator',
Expand Down Expand Up @@ -143,6 +144,7 @@
DEFINE_multi_enum = _defines.DEFINE_multi_enum
DEFINE_multi_enum_class = _defines.DEFINE_multi_enum_class
DEFINE_alias = _defines.DEFINE_alias
DEFINE_doc = _defines.DEFINE_doc
# pylint: enable=invalid-name

# Flag validators.
Expand Down
8 changes: 8 additions & 0 deletions absl/flags/_defines.py
Original file line number Diff line number Diff line change
Expand Up @@ -1684,3 +1684,11 @@ def value(self, value):
flag.default,
help_msg,
boolean=flag.boolean), flag_values, module_name)


def DEFINE_doc( # pylint: disable=invalid-name
doc: Text,
flag_values: _flagvalues.FlagValues = _flagvalues.FLAGS,
) -> None:
module = _helpers.get_calling_module()
flag_values.register_doc_by_module(module, doc)
44 changes: 44 additions & 0 deletions absl/flags/_flagvalues.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,9 @@ def __init__(self):
# Dictionary: module id (int) -> list of Flag objects that are defined by
# that module.
self.__dict__['__flags_by_module_id'] = {}
# Dictionary: module name (string) -> string that is the document defined by
# that module.
self.__dict__['__doc_by_module'] = {}
# Dictionary: module name (string) -> list of Flag objects that are
# key for that module.
self.__dict__['__key_flags_by_module'] = {}
Expand Down Expand Up @@ -167,6 +170,15 @@ def flags_by_module_id_dict(self) -> Dict[int, List[Flag]]:
"""
return self.__dict__['__flags_by_module_id']

def doc_by_module_dict(self) -> Dict[Text, Text]:
"""Returns the dictionary of module_name -> document string.
Returns:
A dictionary. Its keys are module names (strings). Its values
is document (string).
"""
return self.__dict__['__doc_by_module']

def key_flags_by_module_dict(self) -> Dict[Text, List[Flag]]:
"""Returns the dictionary of module_name -> list of key flags.
Expand Down Expand Up @@ -199,6 +211,16 @@ def register_flag_by_module_id(self, module_id: int, flag: Flag) -> None:
flags_by_module_id = self.flags_by_module_id_dict()
flags_by_module_id.setdefault(module_id, []).append(flag)

def register_doc_by_module(self, module_name: Text, doc: Text) -> None:
"""Records the module that has a specific doc.
Args:
module_name: str, the name of a Python module.
doc: str, the document of the module.
"""
doc_by_module = self.doc_by_module_dict()
doc_by_module[module_name] = doc

def register_key_flag_for_module(self, module_name: Text, flag: Flag) -> None:
"""Specifies that a flag is a key flag for a module.
Expand Down Expand Up @@ -273,6 +295,22 @@ def get_flags_for_module(self, module: Union[Text, Any]) -> List[Flag]:

return list(self.flags_by_module_dict().get(module, []))

def get_doc_for_module(self, module: Union[Text, Any]) -> Text:
"""Returns the doc defined by a module.
Args:
module: module|str, the module to get document from.
Returns:
str, document of the module.
"""
if not isinstance(module, str):
module = module.__name__
if module == '__main__':
module = sys.argv[0]

return self.doc_by_module_dict().get(module, '')

def get_key_flags_for_module(self, module: Union[Text, Any]) -> List[Flag]:
"""Returns the list of key flags for a module.
Expand Down Expand Up @@ -947,6 +985,12 @@ def _render_module_flags(self, module, flags, output_lines, prefix=''):
if not isinstance(module, str):
module = module.__name__
output_lines.append('\n%s%s:' % (prefix, module))
doc = self.get_doc_for_module(module)
if doc:
doc = _helpers.text_wrap(
doc, indent=prefix + ' ', firstline_indent=prefix
)
output_lines.append('%s\n' % (doc))
self._render_flag_list(flags, output_lines, prefix + ' ')

def _render_our_module_flags(self, module, output_lines, prefix=''):
Expand Down
27 changes: 21 additions & 6 deletions absl/flags/tests/_flagvalues_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,8 @@ def test_get_help(self):
(default: '')''', fv.get_help())

module_foo.define_flags(fv)
self.assertMultiLineEqual('''
self.assertMultiLineEqual(
"""
absl.flags.tests.module_bar:
--tmod_bar_t: Sample int flag.
(default: '4')
Expand All @@ -416,6 +417,8 @@ def test_get_help(self):
(default: 'false')
absl.flags.tests.module_foo:
test module foo
--[no]tmod_foo_bool: Boolean flag from module foo.
(default: 'true')
--tmod_foo_int: Sample int flag.
Expand All @@ -431,9 +434,12 @@ def test_get_help(self):
the command line even if the program does not define a flag with that name.
IMPORTANT: flags in this list that have arguments MUST use the --flag=value
format.
(default: '')''', fv.get_help())
(default: '')""",
fv.get_help(),
)

self.assertMultiLineEqual('''
self.assertMultiLineEqual(
"""
xxxxabsl.flags.tests.module_bar:
xxxx --tmod_bar_t: Sample int flag.
xxxx (default: '4')
Expand All @@ -452,6 +458,8 @@ def test_get_help(self):
xxxx (default: 'false')
xxxxabsl.flags.tests.module_foo:
xxxxtest module foo
xxxx --[no]tmod_foo_bool: Boolean flag from module foo.
xxxx (default: 'true')
xxxx --tmod_foo_int: Sample int flag.
Expand All @@ -468,9 +476,12 @@ def test_get_help(self):
xxxx on the command line even if the program does not define a flag with that
xxxx name. IMPORTANT: flags in this list that have arguments MUST use the
xxxx --flag=value format.
xxxx (default: '')''', fv.get_help(prefix='xxxx'))
xxxx (default: '')""",
fv.get_help(prefix='xxxx'),
)

self.assertMultiLineEqual('''
self.assertMultiLineEqual(
"""
absl.flags.tests.module_bar:
--tmod_bar_t: Sample int flag.
(default: '4')
Expand All @@ -489,13 +500,17 @@ def test_get_help(self):
(default: 'false')
absl.flags.tests.module_foo:
test module foo
--[no]tmod_foo_bool: Boolean flag from module foo.
(default: 'true')
--tmod_foo_int: Sample int flag.
(default: '3')
(an integer)
--tmod_foo_str: String flag.
(default: 'default')''', fv.get_help(include_special_flags=False))
(default: 'default')""",
fv.get_help(include_special_flags=False),
)

def test_str(self):
fv = _flagvalues.FlagValues()
Expand Down
1 change: 1 addition & 0 deletions absl/flags/tests/module_foo.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def define_flags(flag_values=FLAGS):
flag_values=flag_values)
flags.DEFINE_integer('tmod_foo_int', 3, 'Sample int flag.',
flag_values=flag_values)
flags.DEFINE_doc('test module foo', flag_values=flag_values)


def declare_key_flags(flag_values=FLAGS):
Expand Down

0 comments on commit 3d4c55d

Please sign in to comment.