[Example] fix pretrain_bert patch for tf2.16 (#2698)
Showing 3 changed files with 79 additions and 27 deletions.
@@ -1,7 +1,7 @@
From 01e3540197163d508969d8eb80bef5bfd50f771c Mon Sep 17 00:00:00 2001
From 964757dcf90a89b2546229aff95ca2ce956f0e1e Mon Sep 17 00:00:00 2001
From: yitingw1 <[email protected]>
Date: Thu, 14 Mar 2024 01:45:15 -0700
Subject: [PATCH] 0314 modify for itex bf16 & HVD
Date: Thu, 11 Apr 2024 00:44:31 -0700
Subject: [PATCH] 0411 modify for itex bf16 & HVD

---
.../LanguageModeling/BERT/common_flags.py | 5 +

@@ -17,16 +17,17 @@ Subject: [PATCH] 0314 modify for itex bf16 & HVD
.../networks/albert_transformer_encoder.py | 4 +-
.../nlp/modeling/networks/masked_lm.py | 14 ++-
.../modeling/networks/transformer_encoder.py | 24 +++--
.../BERT/official/utils/flags/_base.py | 30 +++---
.../BERT/official/utils/flags/flags_test.py | 3 +
.../LanguageModeling/BERT/optimization.py | 27 ++++--
.../LanguageModeling/BERT/optimization.py | 22 +++--
.../LanguageModeling/BERT/run_pretraining.py | 37 +++++---
.../LanguageModeling/BERT/run_squad.py | 66 +++++++++----
.../BERT/scripts/run_pretraining_adam.sh | 6 +-
.../BERT/scripts/run_pretraining_lamb.sh | 6 +-
.../scripts/run_pretraining_lamb_phase1.sh | 20 ++--
.../scripts/run_pretraining_lamb_phase2.sh | 20 ++--
.../BERT/scripts/run_squad.sh | 28 ++--
22 files changed, 339 insertions(+), 139 deletions(-)
23 files changed, 348 insertions(+), 155 deletions(-)

diff --git a/TensorFlow2/LanguageModeling/BERT/common_flags.py b/TensorFlow2/LanguageModeling/BERT/common_flags.py
index 0c471089..b7c4780e 100644

@@ -643,6 +644,54 @@ index 6b8d3f77..b10ee381 100644

data = embeddings
attention_mask = layers.SelfAttentionMask()([data, mask])
diff --git a/TensorFlow2/LanguageModeling/BERT/official/utils/flags/_base.py b/TensorFlow2/LanguageModeling/BERT/official/utils/flags/_base.py
index a12bdb79..de9ad384 100644
--- a/TensorFlow2/LanguageModeling/BERT/official/utils/flags/_base.py
+++ b/TensorFlow2/LanguageModeling/BERT/official/utils/flags/_base.py
@@ -22,7 +22,7 @@ from absl import flags
import tensorflow as tf

from official.utils.flags._conventions import help_wrap
-from official.utils.logs import hooks_helper
+# from official.utils.logs import hooks_helper


def define_base(data_dir=True, model_dir=True, clean=False, train_epochs=False,
@@ -113,20 +113,20 @@ def define_base(data_dir=True, model_dir=True, clean=False, train_epochs=False,
name="run_eagerly", default=False,
help="Run the model op by op without building a model function.")

- if hooks:
- # Construct a pretty summary of hooks.
- hook_list_str = (
- u"\ufeff Hook:\n" + u"\n".join([u"\ufeff {}".format(key) for key
- in hooks_helper.HOOKS]))
- flags.DEFINE_list(
- name="hooks", short_name="hk", default="LoggingTensorHook",
- help=help_wrap(
- u"A list of (case insensitive) strings to specify the names of "
- u"training hooks.\n{}\n\ufeff Example: `--hooks ProfilerHook,"
- u"ExamplesPerSecondHook`\n See official.utils.logs.hooks_helper "
- u"for details.".format(hook_list_str))
- )
- key_flags.append("hooks")
+ # if hooks:
+ # # Construct a pretty summary of hooks.
+ # hook_list_str = (
+ # u"\ufeff Hook:\n" + u"\n".join([u"\ufeff {}".format(key) for key
+ # in hooks_helper.HOOKS]))
+ # flags.DEFINE_list(
+ # name="hooks", short_name="hk", default="LoggingTensorHook",
+ # help=help_wrap(
+ # u"A list of (case insensitive) strings to specify the names of "
+ # u"training hooks.\n{}\n\ufeff Example: `--hooks ProfilerHook,"
+ # u"ExamplesPerSecondHook`\n See official.utils.logs.hooks_helper "
+ # u"for details.".format(hook_list_str))
+ # )
+ # key_flags.append("hooks")

if export_dir:
flags.DEFINE_string(
diff --git a/TensorFlow2/LanguageModeling/BERT/official/utils/flags/flags_test.py b/TensorFlow2/LanguageModeling/BERT/official/utils/flags/flags_test.py
index e11a1642..0ecb5921 100644
--- a/TensorFlow2/LanguageModeling/BERT/official/utils/flags/flags_test.py

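The _base.py hunk above simply comments out the hooks_helper import and the --hooks flag, presumably because the estimator-era training hooks it enumerates are not usable on a TF 2.16 stack. A purely hypothetical alternative, in the spirit of the try/except guard the earlier version of this patch used for ITEX's LAMB, would be to define the flag only when the helper still imports; the helper name define_hooks_flag below is made up for illustration:

    # Hypothetical sketch, not what the patch does: keep --hooks available only
    # when official.utils.logs.hooks_helper can still be imported.
    from absl import flags

    try:
        from official.utils.logs import hooks_helper
        _HAS_HOOKS_HELPER = True
    except ImportError:
        _HAS_HOOKS_HELPER = False

    def define_hooks_flag(key_flags):
        """Define --hooks only on setups where hooks_helper imports cleanly."""
        if not _HAS_HOOKS_HELPER:
            return
        flags.DEFINE_list(
            name="hooks", short_name="hk", default="LoggingTensorHook",
            help="A list of (case insensitive) training hook names; see "
                 "official.utils.logs.hooks_helper for details.")
        key_flags.append("hooks")
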
@@ -658,29 +707,25 @@ index e11a1642..0ecb5921 100644
[__file__, "--dtype", "fp16", "--loss_scale", "5"])
self.assertEqual(flags_core.get_loss_scale(flags.FLAGS,
diff --git a/TensorFlow2/LanguageModeling/BERT/optimization.py b/TensorFlow2/LanguageModeling/BERT/optimization.py
index e2b75b08..b74d871b 100644
index e2b75b08..53262d84 100644
--- a/TensorFlow2/LanguageModeling/BERT/optimization.py
+++ b/TensorFlow2/LanguageModeling/BERT/optimization.py
@@ -22,6 +22,12 @@ import re
@@ -21,7 +21,8 @@ from __future__ import print_function
import re

import tensorflow as tf
import tensorflow_addons.optimizers as tfa_optimizers
+Has_ITEXLAMB=False
+try:
+ from intel_extension_for_tensorflow.python.ops import LAMBOptimizer as itex_LAMB
+ Has_ITEXLAMB=True
+except:
+ Has_ITEXLAMB=False
-import tensorflow_addons.optimizers as tfa_optimizers
+# import tensorflow_addons.optimizers as tfa_optimizers
+from intel_extension_for_tensorflow.python.ops import LAMBOptimizer as itex_LAMB

class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
"""Applys a warmup schedule on a given learning rate decay schedule."""
@@ -96,19 +102,28 @@ def create_optimizer(init_lr, num_train_steps, num_warmup_steps, optimizer_type=
@@ -96,19 +97,26 @@ def create_optimizer(init_lr, num_train_steps, num_warmup_steps, optimizer_type=
epsilon=1e-6,
exclude_from_weight_decay=['LayerNorm', 'layer_norm', 'bias'])
else:
- skip_list = ['None'] # to avoid exclude_from_layer_adaptation set to exclude_from_weight_decay if the arg is None
- optimizer = tfa_optimizers.LAMB(
+ if(Has_ITEXLAMB):
+ optimizer = itex_LAMB(
learning_rate=learning_rate_fn,
weight_decay_rate=0.01,

@@ -690,16 +735,15 @@ index e2b75b08..b74d871b 100644
- exclude_from_weight_decay=['LayerNorm', 'layer_norm', 'bias'],
- exclude_from_layer_adaptation=skip_list)
+ epsilon=1e-6)
+ else:
+ skip_list = ['None'] # to avoid exclude_from_layer_adaptation set to exclude_from_weight_decay if the arg is None
+ optimizer = tfa_optimizers.LAMB(
+ learning_rate=learning_rate_fn,
+ weight_decay_rate=0.01,
+ beta_1=0.9,
+ beta_2=0.999,
+ epsilon=1e-6,
+ exclude_from_weight_decay=['LayerNorm', 'layer_norm', 'bias'],
+ exclude_from_layer_adaptation=skip_list)
+ # skip_list = ['None'] # to avoid exclude_from_layer_adaptation set to exclude_from_weight_decay if the arg is None
+ # optimizer = tfa_optimizers.LAMB(
+ # learning_rate=learning_rate_fn,
+ # weight_decay_rate=0.01,
+ # beta_1=0.9,
+ # beta_2=0.999,
+ # epsilon=1e-6,
+ # exclude_from_weight_decay=['LayerNorm', 'layer_norm', 'bias'],
+ # exclude_from_layer_adaptation=skip_list)
+
return optimizer

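Taken together, the optimization.py hunks mean the updated patch drops the tensorflow_addons fallback entirely: LAMBOptimizer is imported from intel_extension_for_tensorflow unconditionally and constructed where tfa_optimizers.LAMB used to be, which is presumably what makes the example work on TF 2.16, where tensorflow_addons is no longer supported. A minimal sketch of the resulting construction follows; the stand-in learning-rate schedule is an assumption, and the itex_LAMB arguments elided between weight_decay_rate and epsilon in the hunk are left out here:

    # Minimal sketch of the LAMB path in the patched optimization.py, assuming
    # intel_extension_for_tensorflow is installed.  Only the keyword arguments
    # visible in the hunks above are passed.
    import tensorflow as tf
    from intel_extension_for_tensorflow.python.ops import LAMBOptimizer as itex_LAMB

    # Stand-in for the real schedule, which wraps a PolynomialDecay in the
    # WarmUp class defined in optimization.py.
    learning_rate_fn = tf.keras.optimizers.schedules.PolynomialDecay(
        initial_learning_rate=1e-4, decay_steps=10000, end_learning_rate=0.0)

    optimizer = itex_LAMB(
        learning_rate=learning_rate_fn,
        weight_decay_rate=0.01,
        epsilon=1e-6)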