Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New yaml-parameter: Value focus max #164

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion tf/chunkparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ def __init__(self,
buffer_size=1,
batch_size=256,
value_focus_min=1,
value_focus_max=1,
value_focus_slope=0,
workers=None):
"""
Expand Down Expand Up @@ -168,6 +169,7 @@ def __init__(self,
self.sample = sample
# set the min and slope for value focus, defaults accept all positions
self.value_focus_min = value_focus_min
self.value_focus_max = value_focus_max
self.value_focus_slope = value_focus_slope
# set the mini-batch size
self.batch_size = batch_size
Expand Down Expand Up @@ -443,7 +445,7 @@ def sample_record(self, chunkdata):
# if orig_q is NaN, accept, else accept based on value focus
if not np.isnan(orig_q):
diff_q = abs(best_q - orig_q)
thresh_p = self.value_focus_min + self.value_focus_slope * diff_q
thresh_p = min(self.value_focus_max, self.value_focus_min + self.value_focus_slope * diff_q)
if thresh_p < 1.0 and random.random() > thresh_p:
continue

Expand Down
4 changes: 3 additions & 1 deletion tf/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,7 @@ def main(cmd):
ChunkParser.BATCH_SIZE = split_batch_size

value_focus_min = cfg['training'].get('value_focus_min', 1)
value_focus_max = cfg['training'].get('value_focus_max', 1)
value_focus_slope = cfg['training'].get('value_focus_slope', 0)

root_dir = os.path.join(cfg['training']['path'], cfg['name'])
Expand All @@ -422,7 +423,7 @@ def main(cmd):
extractor = select_extractor(tfprocess.INPUT_MODE)

if experimental_parser and (value_focus_min != 1
or value_focus_slope != 0):
or value_focus_max != 1 or value_focus_slope != 0):
raise ValueError(
'Experimental parser does not support non-default value \
focus parameters.')
Expand All @@ -447,6 +448,7 @@ def read(x):
sample=SKIP,
batch_size=ChunkParser.BATCH_SIZE,
value_focus_min=value_focus_min,
value_focus_max=value_focus_max,
value_focus_slope=value_focus_slope,
workers=train_workers)
train_dataset = tf.data.Dataset.from_generator(
Expand Down