diff --git a/absl/flags/__init__.py b/absl/flags/__init__.py index 21e05c47..bf46b3db 100644 --- a/absl/flags/__init__.py +++ b/absl/flags/__init__.py @@ -59,6 +59,7 @@ 'DEFINE_multi_enum', 'DEFINE_multi_enum_class', 'DEFINE_alias', + 'DEFINE_doc', # Flag validators. 'register_validator', 'validator', @@ -143,6 +144,7 @@ DEFINE_multi_enum = _defines.DEFINE_multi_enum DEFINE_multi_enum_class = _defines.DEFINE_multi_enum_class DEFINE_alias = _defines.DEFINE_alias +DEFINE_doc = _defines.DEFINE_doc # pylint: enable=invalid-name # Flag validators. diff --git a/absl/flags/_defines.py b/absl/flags/_defines.py index c7b102f2..1655439c 100644 --- a/absl/flags/_defines.py +++ b/absl/flags/_defines.py @@ -1684,3 +1684,11 @@ def value(self, value): flag.default, help_msg, boolean=flag.boolean), flag_values, module_name) + + +def DEFINE_doc( # pylint: disable=invalid-name + doc: Text, + flag_values: _flagvalues.FlagValues = _flagvalues.FLAGS, +) -> None: + module = _helpers.get_calling_module() + flag_values.register_doc_by_module(module, doc) diff --git a/absl/flags/_flagvalues.py b/absl/flags/_flagvalues.py index e25f1d3e..92b8843b 100644 --- a/absl/flags/_flagvalues.py +++ b/absl/flags/_flagvalues.py @@ -102,6 +102,9 @@ def __init__(self): # Dictionary: module id (int) -> list of Flag objects that are defined by # that module. self.__dict__['__flags_by_module_id'] = {} + # Dictionary: module name (string) -> string that is the document defined by + # that module. + self.__dict__['__doc_by_module'] = {} # Dictionary: module name (string) -> list of Flag objects that are # key for that module. self.__dict__['__key_flags_by_module'] = {} @@ -167,6 +170,15 @@ def flags_by_module_id_dict(self) -> Dict[int, List[Flag]]: """ return self.__dict__['__flags_by_module_id'] + def doc_by_module_dict(self) -> Dict[Text, Text]: + """Returns the dictionary of module_name -> document string. + + Returns: + A dictionary. Its keys are module names (strings). Its values + is document (string). + """ + return self.__dict__['__doc_by_module'] + def key_flags_by_module_dict(self) -> Dict[Text, List[Flag]]: """Returns the dictionary of module_name -> list of key flags. @@ -199,6 +211,16 @@ def register_flag_by_module_id(self, module_id: int, flag: Flag) -> None: flags_by_module_id = self.flags_by_module_id_dict() flags_by_module_id.setdefault(module_id, []).append(flag) + def register_doc_by_module(self, module_name: Text, doc: Text) -> None: + """Records the module that has a specific doc. + + Args: + module_name: str, the name of a Python module. + doc: str, the document of the module. + """ + doc_by_module = self.doc_by_module_dict() + doc_by_module[module_name] = doc + def register_key_flag_for_module(self, module_name: Text, flag: Flag) -> None: """Specifies that a flag is a key flag for a module. @@ -273,6 +295,22 @@ def get_flags_for_module(self, module: Union[Text, Any]) -> List[Flag]: return list(self.flags_by_module_dict().get(module, [])) + def get_doc_for_module(self, module: Union[Text, Any]) -> Text: + """Returns the doc defined by a module. + + Args: + module: module|str, the module to get document from. + + Returns: + str, document of the module. + """ + if not isinstance(module, str): + module = module.__name__ + if module == '__main__': + module = sys.argv[0] + + return self.doc_by_module_dict().get(module, '') + def get_key_flags_for_module(self, module: Union[Text, Any]) -> List[Flag]: """Returns the list of key flags for a module. @@ -947,6 +985,12 @@ def _render_module_flags(self, module, flags, output_lines, prefix=''): if not isinstance(module, str): module = module.__name__ output_lines.append('\n%s%s:' % (prefix, module)) + doc = self.get_doc_for_module(module) + if doc: + doc = _helpers.text_wrap( + doc, indent=prefix + ' ', firstline_indent=prefix + ) + output_lines.append('%s\n' % (doc)) self._render_flag_list(flags, output_lines, prefix + ' ') def _render_our_module_flags(self, module, output_lines, prefix=''): diff --git a/absl/flags/tests/_flagvalues_test.py b/absl/flags/tests/_flagvalues_test.py index 09071d7e..f3b3b84d 100644 --- a/absl/flags/tests/_flagvalues_test.py +++ b/absl/flags/tests/_flagvalues_test.py @@ -397,7 +397,8 @@ def test_get_help(self): (default: '')''', fv.get_help()) module_foo.define_flags(fv) - self.assertMultiLineEqual(''' + self.assertMultiLineEqual( + """ absl.flags.tests.module_bar: --tmod_bar_t: Sample int flag. (default: '4') @@ -416,6 +417,8 @@ def test_get_help(self): (default: 'false') absl.flags.tests.module_foo: +test module foo + --[no]tmod_foo_bool: Boolean flag from module foo. (default: 'true') --tmod_foo_int: Sample int flag. @@ -431,9 +434,12 @@ def test_get_help(self): the command line even if the program does not define a flag with that name. IMPORTANT: flags in this list that have arguments MUST use the --flag=value format. - (default: '')''', fv.get_help()) + (default: '')""", + fv.get_help(), + ) - self.assertMultiLineEqual(''' + self.assertMultiLineEqual( + """ xxxxabsl.flags.tests.module_bar: xxxx --tmod_bar_t: Sample int flag. xxxx (default: '4') @@ -452,6 +458,8 @@ def test_get_help(self): xxxx (default: 'false') xxxxabsl.flags.tests.module_foo: +xxxxtest module foo + xxxx --[no]tmod_foo_bool: Boolean flag from module foo. xxxx (default: 'true') xxxx --tmod_foo_int: Sample int flag. @@ -468,9 +476,12 @@ def test_get_help(self): xxxx on the command line even if the program does not define a flag with that xxxx name. IMPORTANT: flags in this list that have arguments MUST use the xxxx --flag=value format. -xxxx (default: '')''', fv.get_help(prefix='xxxx')) +xxxx (default: '')""", + fv.get_help(prefix='xxxx'), + ) - self.assertMultiLineEqual(''' + self.assertMultiLineEqual( + """ absl.flags.tests.module_bar: --tmod_bar_t: Sample int flag. (default: '4') @@ -489,13 +500,17 @@ def test_get_help(self): (default: 'false') absl.flags.tests.module_foo: +test module foo + --[no]tmod_foo_bool: Boolean flag from module foo. (default: 'true') --tmod_foo_int: Sample int flag. (default: '3') (an integer) --tmod_foo_str: String flag. - (default: 'default')''', fv.get_help(include_special_flags=False)) + (default: 'default')""", + fv.get_help(include_special_flags=False), + ) def test_str(self): fv = _flagvalues.FlagValues() diff --git a/absl/flags/tests/module_foo.py b/absl/flags/tests/module_foo.py index 649047c1..94f13979 100644 --- a/absl/flags/tests/module_foo.py +++ b/absl/flags/tests/module_foo.py @@ -42,6 +42,7 @@ def define_flags(flag_values=FLAGS): flag_values=flag_values) flags.DEFINE_integer('tmod_foo_int', 3, 'Sample int flag.', flag_values=flag_values) + flags.DEFINE_doc('test module foo', flag_values=flag_values) def declare_key_flags(flag_values=FLAGS):