Skip to content

Commit

Permalink
Fix parsing error when handling validation expressions with inequality
Browse files Browse the repository at this point in the history
Currently, when there an expression such as "a >= 1", the parser will think that the operator is ">" and the right hand size is "= 1". However, the operator should be ">=".

PiperOrigin-RevId: 592669272
  • Loading branch information
LouYu2015 authored and tensorflower-gardener committed Dec 20, 2023
1 parent b092458 commit e923b8a
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 10 deletions.
20 changes: 10 additions & 10 deletions official/modeling/hyperparams/params_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,24 +297,17 @@ def _get_kvs(tokens, params_dict):
raise KeyError(
'Found inconsistency between key `{}` and key `{}`.'.format(
tokens[0], tokens[1]))
elif '<' in restriction:
tokens = restriction.split('<')
_, left_v, _, right_v = _get_kvs(tokens, params_dict)
if left_v >= right_v:
raise KeyError(
'Found inconsistency between key `{}` and key `{}`.'.format(
tokens[0], tokens[1]))
elif '<=' in restriction:
tokens = restriction.split('<=')
_, left_v, _, right_v = _get_kvs(tokens, params_dict)
if left_v > right_v:
raise KeyError(
'Found inconsistency between key `{}` and key `{}`.'.format(
tokens[0], tokens[1]))
elif '>' in restriction:
tokens = restriction.split('>')
elif '<' in restriction:
tokens = restriction.split('<')
_, left_v, _, right_v = _get_kvs(tokens, params_dict)
if left_v <= right_v:
if left_v >= right_v:
raise KeyError(
'Found inconsistency between key `{}` and key `{}`.'.format(
tokens[0], tokens[1]))
Expand All @@ -325,6 +318,13 @@ def _get_kvs(tokens, params_dict):
raise KeyError(
'Found inconsistency between key `{}` and key `{}`.'.format(
tokens[0], tokens[1]))
elif '>' in restriction:
tokens = restriction.split('>')
_, left_v, _, right_v = _get_kvs(tokens, params_dict)
if left_v <= right_v:
raise KeyError(
'Found inconsistency between key `{}` and key `{}`.'.format(
tokens[0], tokens[1]))
else:
raise ValueError('Unsupported relation in restriction.')

Expand Down
5 changes: 5 additions & 0 deletions official/modeling/hyperparams/params_dict_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ def test_validate(self):
'a': 10
}
}, ['b == c'])
params.validate()

# Raise error due to inconsistency
with self.assertRaises(KeyError):
Expand Down Expand Up @@ -198,6 +199,10 @@ def test_validate(self):
}, ['a == None', 'c.a == 1'])
params.validate()

# Valid restrictions with inequality.
params = params_dict.ParamsDict({'a': 1}, ['a >= 1'])
params.validate()


class ParamsDictIOTest(tf.test.TestCase):

Expand Down

0 comments on commit e923b8a

Please sign in to comment.