Skip to content

Commit

Permalink
Oops, need to pass the transformer args - otherwise --use_bert withou…
Browse files Browse the repository at this point in the history
…t the model name is not correctly testing the dev/test sets
  • Loading branch information
AngledLuffa committed Nov 4, 2023
1 parent 1c30703 commit 51e00b1
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions stanza/utils/training/run_ner.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def add_ner_args(parser):
parser.add_argument('--use_bert', default=False, action="store_true", help='Use the default transformer for this language')


def build_pretrain_args(language, dataset, charlm="default", extra_args=None, model_dir=DEFAULT_MODEL_DIR):
def build_pretrain_args(language, dataset, charlm="default", command_args=None, extra_args=None, model_dir=DEFAULT_MODEL_DIR):
"""
Returns one list with the args for this language & dataset's charlm and pretrained embedding
"""
Expand All @@ -61,7 +61,9 @@ def build_pretrain_args(language, dataset, charlm="default", extra_args=None, mo
wordvec_pretrain = find_wordvec_pretrain(language, default_pretrains, ner_pretrains, dataset, model_dir=model_dir)
wordvec_args = ['--wordvec_pretrain_file', wordvec_pretrain]

return charlm_args + wordvec_args
bert_args = common.choose_transformer(language, command_args, extra_args, warn=False)

return charlm_args + wordvec_args + bert_args


# TODO: refactor? tagger and depparse should be pretty similar
Expand All @@ -70,14 +72,13 @@ def build_model_filename(paths, short_name, command_args, extra_args):

# TODO: can avoid downloading the charlm at this point, since we
# might not even be training
pretrain_args = build_pretrain_args(short_language, dataset, command_args.charlm, extra_args)
bert_args = common.choose_transformer(short_language, command_args, extra_args, warn=False)
pretrain_args = build_pretrain_args(short_language, dataset, command_args.charlm, command_args, extra_args)

dataset_args = DATASET_EXTRA_ARGS.get(short_name, [])

train_args = ["--shorthand", short_name,
"--mode", "train"]
train_args = train_args + pretrain_args + bert_args + dataset_args + extra_args
train_args = train_args + pretrain_args + dataset_args + extra_args
if command_args.save_name is not None:
train_args.extend(["--save_name", command_args.save_name])
if command_args.save_dir is not None:
Expand Down Expand Up @@ -111,7 +112,7 @@ def run_treebank(mode, paths, treebank, short_name,
except Exception as e:
raise FileNotFoundError(f"An exception occurred while trying to build the data for {treebank} At least one portion of the data was missing: {missing_file} Please correctly build these files and then try again.") from e

pretrain_args = build_pretrain_args(language, dataset, command_args.charlm, extra_args)
pretrain_args = build_pretrain_args(language, dataset, command_args.charlm, command_args, extra_args)

if mode == Mode.TRAIN:
# VI example arguments:
Expand All @@ -124,13 +125,12 @@ def run_treebank(mode, paths, treebank, short_name,
# --charlm --charlm_shorthand vi_conll17
# --dropout 0.6 --word_dropout 0.1 --locked_dropout 0.1 --char_dropout 0.1
dataset_args = DATASET_EXTRA_ARGS.get(short_name, [])
bert_args = common.choose_transformer(language, command_args, extra_args)

train_args = ['--train_file', train_file,
'--eval_file', dev_file,
'--shorthand', short_name,
'--mode', 'train']
train_args = train_args + pretrain_args + bert_args + dataset_args + extra_args
train_args = train_args + pretrain_args + dataset_args + extra_args
logger.info("Running train step with args: {}".format(train_args))
ner_tagger.main(train_args)

Expand Down

0 comments on commit 51e00b1

Please sign in to comment.