diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
new file mode 100644
index 0000000..22b241c
--- /dev/null
+++ b/CONTRIBUTING.md
@@ -0,0 +1,29 @@
+# How to Contribute
+
+We'd love to accept your patches and contributions to this project. There are
+just a few small guidelines you need to follow.
+
+## Contributor License Agreement
+
+Contributions to this project must be accompanied by a Contributor License
+Agreement (CLA). You (or your employer) retain the copyright to your
+contribution; this simply gives us permission to use and redistribute your
+contributions as part of the project. Head over to
+ to see your current agreements on file or
+to sign a new one.
+
+You generally only need to submit a CLA once, so if you've already submitted one
+(even if it was for a different project), you probably don't need to do it
+again.
+
+## Code reviews
+
+All submissions, including submissions by project members, require review. We
+use GitHub pull requests for this purpose. Consult
+[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
+information on using pull requests.
+
+## Community Guidelines
+
+This project follows
+[Google's Open Source Community Guidelines](https://opensource.google/conduct/).
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..d645695
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,202 @@
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/README.md b/README.md
index 33c58de..c8618d5 100644
--- a/README.md
+++ b/README.md
@@ -1,12 +1,156 @@
# Big Bird: Transformers for Longer Sequences
-We propose, BigBird, a sparse attention mechanism that reduces this quadratic
-dependency to linear. We show that BigBird is a universal approximator of
-sequence functions and is Turing complete, thereby preserving these properties
-of the quadratic, full attention model. The proposed sparse attention can
-handle sequences of length up to 8x of what was previously possible using
-similar hardware. As a consequence of the capability to handle longer context,
-BigBird drastically improves performance on various NLP tasks such as question
-answering and summarization.
-
-Code release in progress.
+Not an official Google product.
+
+# What is BigBird?
+BigBird, is a sparse-attention based transformer which extends Transformer based models, such as BERT to much longer sequences. Moreover, BigBird comes along with a theoretical understanding of the capabilities of a complete transformer that the sparse model can handle.
+
+As a consequence of the capability to handle longer context,
+BigBird drastically improves performance on various NLP tasks such as question answering and summarization.
+
+More details and comparisons can be found in our [presentation](https://docs.google.com/presentation/d/1FdMNqG2b8XYc89_v7-_2sba7Iz6YAlXXWuMxUbrKFK0/preview).
+
+
+# Citation
+If you find this useful, please cite our [NeurIPS 2020 paper](https://papers.nips.cc/paper/2020/hash/c8512d142a2d849725f31a9a7a361ab9-Abstract.html):
+```
+@article{zaheer2020bigbird,
+ title={Big bird: Transformers for longer sequences},
+ author={Zaheer, Manzil and Guruganesh, Guru and Dubey, Kumar Avinava and Ainslie, Joshua and Alberti, Chris and Ontanon, Santiago and Pham, Philip and Ravula, Anirudh and Wang, Qifan and Yang, Li and others},
+ journal={Advances in Neural Information Processing Systems},
+ volume={33},
+ year={2020}
+}
+```
+
+
+# Code
+
+The most important directory is `core`.
+There are three main files in `core`.
+
+* [attention.py](bigbird/core/attention.py):
+ Contains BigBird linear attention mechanism
+* [encoder.py](bigbird/core/encoder.py):
+ Contains the main long sequence encoder stack
+* [modeling.py](bigbird/core/modeling):
+ Contains packaged BERT and seq2seq transformer models with BigBird attention
+
+
+### Colab/IPython Notebook
+
+A quick fine-tuning demonstration for text classification is provided in
+[imdb.ipynb](bigbird/classifier/imdb.ipynb)
+
+
+### Create GCP Instance
+Please create a project first and create an instance in a zone which has quota as follows
+
+```bash
+gcloud compute instances create \
+ bigbird \
+ --zone=europe-west4-a \
+ --machine-type=n1-standard-16 \
+ --boot-disk-size=50GB \
+ --image-project=ml-images \
+ --image-family=tf-2-3-1 \
+ --maintenance-policy TERMINATE \
+ --restart-on-failure \
+ --scopes=cloud-platform
+
+gcloud compute tpus create \
+ bigbird \
+ --zone=europe-west4-a \
+ --accelerator-type=v3-32 \
+ --version=2.3.1
+
+gcloud compute ssh --zone "europe-west4-a" "bigbird"
+
+```
+
+For illustration we used instance name `bigbird` and zone `europe-west4-a`, but feel free to change them.
+More details about creating Google Cloud TPU can be found in [online documentations](https://cloud.google.com/tpu/docs/creating-deleting-tpus#setup_TPU_only).
+
+
+### Install and download dependencies
+```bash
+git clone https://github.com/google-research/bigbird.git
+cd bigbird
+pip3 install -e .
+```
+You can find pretrained and fine-tuned checkpoints in our [Google Cloud Storage Bucket](https://console.cloud.google.com/storage/browser/bigbird-transformer).
+
+Optionally, you can download them using `gsutil` as
+```bash
+mkdir -p bigbird/ckpt
+gsutil cp -r gs://bigbird-transformer/ bigbird/ckpt/
+```
+
+The storage bucket contains:
+* pretrained BERT model for base and large size
+* pretrained Pegasus Encoder-Decoder Transformer in large size
+* fine-tuned `tf.SavedModel` for long document summarization
+
+
+### Running Classification
+
+For quickly starting with BigBird, one can start by running the classification experiment code in `classifier` directory.
+To run the code simply execute
+
+```shell
+export GCP_PROJECT_NAME=bigbird-project # Replace by your project name
+export GCP_EXP_BUCKET=gs://bigbird-transformer-training/ # Replace
+sh -x bigbird/classifier/base_size.sh
+```
+
+
+## Using BigBird Encoder instead BERT/RoBERTa
+
+To directly use the encoder instead of say BERT model, we can use the following
+code.
+
+```python
+from bigbird.core import modeling
+
+bigb_encoder = modeling.BertModel(...)
+```
+
+It can easily replace [BERT's](https://arxiv.org/abs/1810.04805) encoder.
+
+
+Alternatively, one can also try playing with layers of BigBird encoder
+
+```python
+from bigbird.core import encoder
+
+only_layers = encoder.EncoderStack(...)
+```
+
+
+## Understanding Flags & Config
+
+All the flags and config are explained in
+`core/flags.py`. Here we explain
+some of the important config paramaters.
+
+`attention_type` is used to select the type of attention we would use. Setting
+it to `block_sparse` runs the BigBird attention module.
+
+```python
+flags.DEFINE_enum(
+ "attention_type", "block_sparse",
+ ["original_full", "simulated_sparse", "block_sparse"],
+ "Selecting attention implementation. "
+ "'original_full': full attention from original bert. "
+ "'simulated_sparse': simulated sparse attention. "
+ "'block_sparse': blocked implementation of sparse attention.")
+```
+
+`block_size` is used to define the size of blocks, whereas `num_rand_blocks` is
+used to set the number of random blocks. The code currently uses window size of
+3 blocks and 2 global blocks. The current code only supports static tensors.
+
+Important points to note:
+* Hidden dimension should be divisible by the number of heads.
+* For sequene length less than 1024, using `original_full` is advised as there
+is no benefit in using sparse BigBird attention.
diff --git a/bigbird/classifier/__init__.py b/bigbird/classifier/__init__.py
new file mode 100644
index 0000000..f6cd7c8
--- /dev/null
+++ b/bigbird/classifier/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2020 The BigBird Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
diff --git a/bigbird/classifier/base_size.sh b/bigbird/classifier/base_size.sh
new file mode 100644
index 0000000..7a54ecb
--- /dev/null
+++ b/bigbird/classifier/base_size.sh
@@ -0,0 +1,23 @@
+#!/bin/bash
+
+# TF_XLA_FLAGS=--tf_xla_auto_jit=2
+python3 bigbird/classifier/run_classifier.py \
+ --data_dir="tfds://imdb_reviews/plain_text" \
+ --output_dir="$GCP_EXP_BUCKET"classifier/imdb \
+ --attention_type=block_sparse \
+ --max_encoder_length=4096 \
+ --num_attention_heads=12 \
+ --num_hidden_layers=12 \
+ --hidden_size=768 \
+ --intermediate_size=3072 \
+ --block_size=64 \
+ --train_batch_size=2 \
+ --eval_batch_size=2 \
+ --do_train=True \
+ --do_eval=True \
+ --use_tpu=True \
+ --tpu_name=bigbird \
+ --tpu_zone=europe-west4-a \
+ --gcp_project="$GCP_PROJECT_NAME" \
+ --num_tpu_cores=32 \
+ --init_checkpoint=gs://bigbird-transformer/pretrain/bigbr_base/model.ckpt-0
diff --git a/bigbird/classifier/imdb.ipynb b/bigbird/classifier/imdb.ipynb
new file mode 100644
index 0000000..1b740c6
--- /dev/null
+++ b/bigbird/classifier/imdb.ipynb
@@ -0,0 +1,652 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "YONnGjpAYUdU"
+ },
+ "source": [
+ "\n",
+ "\u003ca href=\"https://colab.research.google.com/github/google-research/bigbird/blob/master/bigbird/classifier/imdb.ipynb\" target=\"_parent\"\u003e\u003cimg src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/\u003e\u003c/a\u003e"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "zrtR2urJV3ST"
+ },
+ "source": [
+ "##### Copyright 2020 The BigBird Authors\n",
+ "\n",
+ "Licensed under the Apache License, Version 2.0 (the \"License\");"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "xyasTfa-LVLe"
+ },
+ "outputs": [],
+ "source": [
+ "# Copyright 2020 The BigBird Authors. All Rights Reserved.\n",
+ "#\n",
+ "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
+ "# you may not use this file except in compliance with the License.\n",
+ "# You may obtain a copy of the License at\n",
+ "#\n",
+ "# http://www.apache.org/licenses/LICENSE-2.0\n",
+ "#\n",
+ "# Unless required by applicable law or agreed to in writing, software\n",
+ "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
+ "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
+ "# See the License for the specific language governing permissions and\n",
+ "# limitations under the License.\n",
+ "# =============================================================================="
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "fcZZRDx505hq"
+ },
+ "source": [
+ "## Set Up"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "N94UyOdA0mCO"
+ },
+ "outputs": [],
+ "source": [
+ "!pip install tensorflow-gpu --upgrade\n",
+ "!pip install git+https://github.com/google-research/bigbird.git"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "0irPwcbBYvDV"
+ },
+ "outputs": [],
+ "source": [
+ "from bigbird.core import flags\n",
+ "from bigbird.core import modeling\n",
+ "from bigbird.core import utils\n",
+ "from bigbird.classifier import run_classifier\n",
+ "import tensorflow.compat.v2 as tf\n",
+ "import tensorflow_datasets as tfds\n",
+ "from tqdm import tqdm\n",
+ "import sys\n",
+ "\n",
+ "FLAGS = flags.FLAGS\n",
+ "if not hasattr(FLAGS, \"f\"): flags.DEFINE_string(\"f\", \"\", \"\")\n",
+ "FLAGS(sys.argv)\n",
+ "\n",
+ "tf.enable_v2_behavior()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "AJexg2zsxfHo"
+ },
+ "source": [
+ "## Set options"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "rph2sJ75kBNA"
+ },
+ "outputs": [],
+ "source": [
+ "FLAGS.data_dir = \"tfds://imdb_reviews/plain_text\"\n",
+ "FLAGS.attention_type = \"block_sparse\"\n",
+ "FLAGS.max_encoder_length = 3072 # 4096 on 16GB GPUs like V100, on free colab only lower memory GPU like T4 is available\n",
+ "FLAGS.learning_rate = 1e-5\n",
+ "FLAGS.attention_probs_dropout_prob = 0.0\n",
+ "FLAGS.hidden_dropout_prob = 0.0\n",
+ "FLAGS.vocab_model_file = \"gpt2\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "zuxI3V_3j57Y"
+ },
+ "outputs": [],
+ "source": [
+ "bert_config = flags.as_dictionary()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "kRF4TUEQxjXJ"
+ },
+ "source": [
+ "## Define classification model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "J3yNdo5toQwq"
+ },
+ "outputs": [],
+ "source": [
+ "model = modeling.BertModel(bert_config, train=True)\n",
+ "headl = run_classifier.ClassifierLossLayer(\n",
+ " bert_config[\"num_labels\"], bert_config[\"hidden_dropout_prob\"],\n",
+ " utils.create_initializer(bert_config[\"initializer_range\"]),\n",
+ " name=bert_config[\"scope\"]+\"/classifier\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "DXOY78vbqHX9"
+ },
+ "outputs": [],
+ "source": [
+ "@tf.function(jit_compile=True)\n",
+ "def fwd_bwd(features, labels):\n",
+ " with tf.GradientTape() as g:\n",
+ " _, pooled_output = model(features, training=True)\n",
+ " loss, log_probs = headl(pooled_output, labels, True)\n",
+ " grads = g.gradient(loss, model.trainable_weights+headl.trainable_weights)\n",
+ " return loss, log_probs, grads"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "DzoTMyQlxsRo"
+ },
+ "source": [
+ "## Dataset pipeline"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "executionInfo": {
+ "elapsed": 5770,
+ "status": "ok",
+ "timestamp": 1607595313569,
+ "user": {
+ "displayName": "Manzil Zaheer",
+ "photoUrl": "",
+ "userId": "06259716656099187509"
+ },
+ "user_tz": 480
+ },
+ "id": "Oo-NQTDZZs51",
+ "outputId": "ed2a0713-e06a-442f-a188-191d1fdc494d"
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "train_input_fn = run_classifier.input_fn_builder(\n",
+ " data_dir=FLAGS.data_dir,\n",
+ " vocab_model_file=FLAGS.vocab_model_file,\n",
+ " max_encoder_length=FLAGS.max_encoder_length,\n",
+ " substitute_newline=FLAGS.substitute_newline,\n",
+ " is_training=True)\n",
+ "dataset = train_input_fn({'batch_size': 2})"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "executionInfo": {
+ "elapsed": 1527,
+ "status": "ok",
+ "timestamp": 1607595315103,
+ "user": {
+ "displayName": "Manzil Zaheer",
+ "photoUrl": "",
+ "userId": "06259716656099187509"
+ },
+ "user_tz": 480
+ },
+ "id": "hRvmfaNUi-V5",
+ "outputId": "18578022-0344-4d01-cb2d-048f0c4f0d78"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "(\u003ctf.Tensor: shape=(2, 4096), dtype=int32, numpy=\n",
+ "array([[ 65, 733, 474, ..., 0, 0, 0],\n",
+ " [ 65, 415, 26500, ..., 0, 0, 0]], dtype=int32)\u003e, \u003ctf.Tensor: shape=(2,), dtype=int32, numpy=array([0, 1], dtype=int32)\u003e)\n",
+ "(\u003ctf.Tensor: shape=(2, 4096), dtype=int32, numpy=\n",
+ "array([[ 65, 484, 20677, ..., 0, 0, 0],\n",
+ " [ 65, 871, 3908, ..., 0, 0, 0]], dtype=int32)\u003e, \u003ctf.Tensor: shape=(2,), dtype=int32, numpy=array([1, 1], dtype=int32)\u003e)\n",
+ "(\u003ctf.Tensor: shape=(2, 4096), dtype=int32, numpy=\n",
+ "array([[ 65, 415, 6506, ..., 0, 0, 0],\n",
+ " [ 65, 418, 1150, ..., 0, 0, 0]], dtype=int32)\u003e, \u003ctf.Tensor: shape=(2,), dtype=int32, numpy=array([1, 0], dtype=int32)\u003e)\n"
+ ]
+ }
+ ],
+ "source": [
+ "# inspect at a few examples\n",
+ "for ex in dataset.take(3):\n",
+ " print(ex)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "lYCyGH56zOOU"
+ },
+ "source": [
+ "## Check outputs"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "executionInfo": {
+ "elapsed": 458,
+ "status": "ok",
+ "timestamp": 1607595411541,
+ "user": {
+ "displayName": "Manzil Zaheer",
+ "photoUrl": "",
+ "userId": "06259716656099187509"
+ },
+ "user_tz": 480
+ },
+ "id": "5uQOwyGQzRKt",
+ "outputId": "6db22a02-3689-4b86-e6ed-b67eabbfc743"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "Loss: 0.6977416\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "loss, log_probs, grads = fwd_bwd(ex[0], ex[1])\n",
+ "print('Loss: ', loss.numpy())"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Qz_LdCCdyDCR"
+ },
+ "source": [
+ "## (Optionally) Load pretrained model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "executionInfo": {
+ "elapsed": 36637,
+ "status": "ok",
+ "timestamp": 1607595448644,
+ "user": {
+ "displayName": "Manzil Zaheer",
+ "photoUrl": "",
+ "userId": "06259716656099187509"
+ },
+ "user_tz": 480
+ },
+ "id": "rRa2dD1RzLN4",
+ "outputId": "225e476b-2314-428a-b4ee-d267fb934a70"
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "100%|██████████| 199/199 [00:34\u003c00:00, 4.94it/s]\n"
+ ]
+ }
+ ],
+ "source": [
+ "ckpt_path = 'gs://bigbird-transformer/pretrain/bigbr_base/model.ckpt-0'\n",
+ "ckpt_reader = tf.compat.v1.train.NewCheckpointReader(ckpt_path)\n",
+ "model.set_weights([ckpt_reader.get_tensor(v.name[:-2]) for v in tqdm(model.trainable_weights, position=0)])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "r6-BziYxzL3U"
+ },
+ "source": [
+ "## Train"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "executionInfo": {
+ "elapsed": 5080359,
+ "status": "ok",
+ "timestamp": 1607600529015,
+ "user": {
+ "displayName": "Manzil Zaheer",
+ "photoUrl": "",
+ "userId": "06259716656099187509"
+ },
+ "user_tz": 480
+ },
+ "id": "IWjkDvu9k7ie",
+ "outputId": "67dcf3e1-c126-4291-90bc-da71b8c07c52"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Loss = 0.7094929218292236 Accuracy = 0.5"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " 0%| | 0/10000 [00:06\u003c1:32:59, 1.79it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "Loss = 0.4131925702095032 Accuracy = 0.8123108148574829"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " 10%|█ | 1000/10000 [08:26\u003c1:16:08, 1.97it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "Loss = 0.32566359639167786 Accuracy = 0.8608739376068115"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " 20%|██ | 2000/10000 [16:52\u003c1:08:17, 1.95it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "Loss = 0.28784531354904175 Accuracy = 0.882480800151825"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " 30%|███ | 3000/10000 [25:18\u003c58:58, 1.98it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "Loss = 0.2657429575920105 Accuracy = 0.8936356902122498"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " 40%|████ | 4000/10000 [33:44\u003c50:41, 1.97it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "Loss = 0.24971100687980652 Accuracy = 0.9020236134529114"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " 50%|█████ | 5000/10000 [42:10\u003c42:03, 1.98it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "Loss = 0.23958759009838104 Accuracy = 0.9069437384605408"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " 60%|██████ | 6000/10000 [50:36\u003c33:43, 1.98it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "Loss = 0.2304597944021225 Accuracy = 0.9108854532241821"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " 70%|███████ | 7000/10000 [59:02\u003c25:20, 1.97it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "Loss = 0.2243848443031311 Accuracy = 0.9135903120040894"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " 80%|████████ | 8000/10000 [1:07:30\u003c17:23, 1.92it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "Loss = 0.21911397576332092 Accuracy = 0.9155822396278381"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " 90%|█████████ | 9000/10000 [1:16:05\u003c08:34, 1.94it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "Loss = 0.21378542482852936 Accuracy = 0.9180262088775635"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "100%|██████████| 10000/10000 [1:24:39\u003c00:00, 1.94it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "opt = tf.keras.optimizers.Adam(FLAGS.learning_rate)\n",
+ "train_loss = tf.keras.metrics.Mean(name='train_loss')\n",
+ "train_accuracy = tf.keras.metrics.CategoricalAccuracy(name='train_accuracy')\n",
+ "\n",
+ "for i, ex in enumerate(tqdm(dataset.take(FLAGS.num_train_steps), position=0)):\n",
+ " loss, log_probs, grads = fwd_bwd(ex[0], ex[1])\n",
+ " opt.apply_gradients(zip(grads, model.trainable_weights+headl.trainable_weights))\n",
+ " train_loss(loss)\n",
+ " train_accuracy(tf.one_hot(ex[1], 2), log_probs)\n",
+ " if i% 1000 == 0:\n",
+ " print('Loss = {} Accuracy = {}'.format(train_loss.result().numpy(), train_accuracy.result().numpy()))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "-Mq_xhMzef42"
+ },
+ "outputs": [],
+ "source": [
+ "eval_input_fn = run_classifier.input_fn_builder(\n",
+ " data_dir=FLAGS.data_dir,\n",
+ " vocab_model_file=FLAGS.vocab_model_file,\n",
+ " max_encoder_length=FLAGS.max_encoder_length,\n",
+ " substitute_newline=FLAGS.substitute_newline,\n",
+ " is_training=False)\n",
+ "eval_dataset = eval_input_fn({'batch_size': 2})"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "executionInfo": {
+ "elapsed": 3263,
+ "status": "ok",
+ "timestamp": 1607617729500,
+ "user": {
+ "displayName": "Manzil Zaheer",
+ "photoUrl": "",
+ "userId": "06259716656099187509"
+ },
+ "user_tz": 480
+ },
+ "id": "rqPN4R8kerUG",
+ "outputId": "194f8765-f13d-46f9-f7fc-0b4b54c9e9d5"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "Loss = 0.16173037886619568 Accuracy = 0.9459513425827026100"
+ ]
+ }
+ ],
+ "source": [
+ "eval_loss = tf.keras.metrics.Mean(name='eval_loss')\n",
+ "eval_accuracy = tf.keras.metrics.CategoricalAccuracy(name='eval_accuracy')\n",
+ "\n",
+ "for ex in eval_dataset.take(FLAGS.num_train_steps):\n",
+ " loss, log_probs, grads = fwd_bwd(ex[0], ex[1])\n",
+ " eval_loss(loss)\n",
+ " eval_accuracy(tf.one_hot(ex[1], 2), log_probs)\n",
+ "print('Loss = {} Accuracy = {}'.format(eval_loss.result().numpy(), eval_accuracy.result().numpy()))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "BvEFgoXJxQOa"
+ },
+ "outputs": [],
+ "source": [
+ ""
+ ]
+ }
+ ],
+ "metadata": {
+ "colab": {
+ "collapsed_sections": [],
+ "last_runtime": {},
+ "name": "BigBirdGPU.ipynb",
+ "provenance": []
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
diff --git a/bigbird/classifier/large_size.sh b/bigbird/classifier/large_size.sh
new file mode 100644
index 0000000..f232ab8
--- /dev/null
+++ b/bigbird/classifier/large_size.sh
@@ -0,0 +1,23 @@
+#!/bin/bash
+
+# TF_XLA_FLAGS=--tf_xla_auto_jit=2
+bigbird/classifier/run_classifier.py \
+ --data_dir="tfds://imdb_reviews/plain_text" \
+ --output_dir="$GCP_EXP_BUCKET"classifier/imdb \
+ --attention_type=block_sparse \
+ --max_encoder_length=4096 \
+ --num_attention_heads=16 \
+ --num_hidden_layers=24 \
+ --hidden_size=1024 \
+ --intermediate_size=4096 \
+ --block_size=64 \
+ --train_batch_size=1 \
+ --eval_batch_size=1 \
+ --do_train=True \
+ --do_eval=True \
+ --use_tpu=True \
+ --tpu_name=bigbird \
+ --tpu_zone=europe-west4-a \
+ --gcp_project="$GCP_PROJECT_NAME" \
+ --num_tpu_cores=32 \
+ --init_checkpoint=gs://bigbird-transformer/pretrain/bigbr_base/model.ckpt-0
diff --git a/bigbird/classifier/run_classifier.py b/bigbird/classifier/run_classifier.py
new file mode 100644
index 0000000..19641ca
--- /dev/null
+++ b/bigbird/classifier/run_classifier.py
@@ -0,0 +1,454 @@
+# Copyright 2020 The BigBird Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Run classification fine-tuning for BigBird."""
+
+import os
+
+from absl import app
+from absl import logging
+from bigbird.core import flags
+from bigbird.core import modeling
+from bigbird.core import optimization
+from bigbird.core import utils
+from natsort import natsorted
+import tensorflow.compat.v2 as tf
+import tensorflow_datasets as tfds
+import tensorflow_text as tft
+
+
+FLAGS = flags.FLAGS
+
+## Required parameters
+
+flags.DEFINE_string(
+ "data_dir", "tfds://imdb_reviews/plain_text",
+ "The input data dir. Should contain the TFRecord files. "
+ "Can be TF Dataset with prefix tfds://")
+
+flags.DEFINE_string(
+ "output_dir", "/tmp/bigb",
+ "The output directory where the model checkpoints will be written.")
+
+## Other parameters
+
+flags.DEFINE_string(
+ "init_checkpoint", None,
+ "Initial checkpoint (usually from a pre-trained BigBird model).")
+
+flags.DEFINE_integer(
+ "max_encoder_length", 512,
+ "The maximum total input sequence length after SentencePiece tokenization. "
+ "Sequences longer than this will be truncated, and sequences shorter "
+ "than this will be padded.")
+
+flags.DEFINE_string(
+ "substitute_newline", None,
+ "Replace newline charachter from text with supplied string.")
+
+flags.DEFINE_bool(
+ "do_train", True,
+ "Whether to run training.")
+
+flags.DEFINE_bool(
+ "do_eval", False,
+ "Whether to run eval on the dev set.")
+
+flags.DEFINE_bool(
+ "do_export", False,
+ "Whether to export the model as TF SavedModel.")
+
+flags.DEFINE_integer(
+ "train_batch_size", 8,
+ "Local batch size for training. "
+ "Total batch size will be multiplied by number gpu/tpu cores available.")
+
+flags.DEFINE_integer(
+ "eval_batch_size", 8,
+ "Local batch size for eval. "
+ "Total batch size will be multiplied by number gpu/tpu cores available.")
+
+flags.DEFINE_string(
+ "optimizer", "AdamWeightDecay",
+ "Optimizer to use. Can be Adafactor, Adam, and AdamWeightDecay.")
+
+flags.DEFINE_float(
+ "learning_rate", 1e-5,
+ "The initial learning rate for Adam.")
+
+flags.DEFINE_integer(
+ "num_train_steps", 16000,
+ "Total number of training steps to perform.")
+
+flags.DEFINE_integer(
+ "num_warmup_steps", 1000,
+ "Number of steps to perform linear warmup.")
+
+flags.DEFINE_integer(
+ "save_checkpoints_steps", 1000,
+ "How often to save the model checkpoint.")
+
+flags.DEFINE_integer(
+ "num_labels", 2,
+ "Number of ways to classify.")
+
+
+def input_fn_builder(data_dir, vocab_model_file, max_encoder_length,
+ substitute_newline, is_training, tmp_dir=None):
+ """Creates an `input_fn` closure to be passed to TPUEstimator."""
+
+ def _decode_record(record):
+ """Decodes a record to a TensorFlow example."""
+ name_to_features = {
+ "text": tf.io.FixedLenFeature([], tf.string),
+ "label": tf.io.FixedLenFeature([], tf.int64),
+ }
+ example = tf.io.parse_single_example(record, name_to_features)
+ return example
+
+ def _tokenize_example(example):
+ text, label = example["text"], example["label"]
+ tokenizer = tft.SentencepieceTokenizer(
+ model=tf.io.gfile.GFile(vocab_model_file, "rb").read())
+ if substitute_newline:
+ text = tf.strings.regex_replace(text, "\n", substitute_newline)
+ ids = tokenizer.tokenize(text)
+ ids = ids[:max_encoder_length - 2]
+ # Add [CLS] (65) and [SEP] (66) special tokens.
+ prefix = tf.constant([65])
+ suffix = tf.constant([66])
+ ids = tf.concat([prefix, ids, suffix], axis=0)
+ if isinstance(ids, tf.RaggedTensor):
+ ids = ids.to_tensor(0)
+
+ # tf.Example only supports tf.int64, but the TPU is better with tf.int32.
+ label = tf.cast(label, tf.int32)
+
+ return ids, label
+
+ def input_fn(params):
+ """The actual input function."""
+ batch_size = params["batch_size"]
+ tpu_context = params.get("context", None)
+ seed = 0
+
+ # Load dataset and handle tfds separately
+ split = "train" if is_training else "test"
+ if "tfds://" == data_dir[:7]:
+ d = tfds.load(data_dir[7:], split=split,
+ shuffle_files=is_training,
+ data_dir=tmp_dir)
+ else:
+ input_files = tf.io.gfile.glob(
+ os.path.join(data_dir, "{}.tfrecord*".format(split)))
+
+ # Classification datasets are small so parallel interleaved reading
+ # won't buy us much.
+ d = tf.data.TFRecordDataset(input_files)
+ d = d.map(_decode_record,
+ num_parallel_calls=tf.data.experimental.AUTOTUNE,
+ deterministic=is_training)
+
+ d = d.map(_tokenize_example,
+ num_parallel_calls=tf.data.experimental.AUTOTUNE,
+ deterministic=is_training)
+
+ # Tokenize and batch dataset by sentencepiece
+ if is_training:
+ # Classification datasets are usually small
+ # and interleaving files may not be effective.
+ # So to ensure different data in a multi-host setup
+ # we explicitly shard the dataset by host id.
+ if tpu_context: # ensuring different data in multi-host setup
+ d = d.shard(tpu_context.num_hosts, tpu_context.current_host)
+ seed = tpu_context.current_host
+ d = d.shuffle(buffer_size=10000, seed=seed,
+ reshuffle_each_iteration=True)
+ d = d.repeat()
+ d = d.padded_batch(batch_size, ([max_encoder_length], []),
+ drop_remainder=True) # For static shape
+ return d
+
+ return input_fn
+
+
+def serving_input_fn_builder(batch_size, max_encoder_length,
+ vocab_model_file, substitute_newline):
+ """Creates an `input_fn` closure for exported SavedModel."""
+ def dynamic_padding(inp, min_size):
+ pad_size = tf.maximum(min_size - tf.shape(inp)[1], 0)
+ paddings = [[0, 0], [0, pad_size]]
+ return tf.pad(inp, paddings)
+
+ def input_fn():
+ # text input
+ text = tf.compat.v1.placeholder(tf.string, [batch_size], name="input_text")
+
+ # text tokenize
+ tokenizer = tft.SentencepieceTokenizer(
+ model=tf.io.gfile.GFile(vocab_model_file, "rb").read())
+ if substitute_newline:
+ text = tf.strings.regex_replace(text, "\n", substitute_newline)
+ ids = tokenizer.tokenize(text)
+ ids = ids[:, :max_encoder_length - 2]
+
+ # Add [CLS] and [SEP] special tokens.
+ prefix = tf.repeat(tf.constant([[65]]), batch_size, axis=0)
+ suffix = tf.repeat(tf.constant([[66]]), batch_size, axis=0)
+ ids = tf.concat([prefix, ids, suffix], axis=1)
+ if isinstance(ids, tf.RaggedTensor):
+ ids = ids.to_tensor(0)
+
+ # text padding: Pad only if necessary and reshape properly
+ padded_ids = dynamic_padding(ids, max_encoder_length)
+ ids = tf.slice(padded_ids, [0, 0], [batch_size, max_encoder_length])
+
+ receiver_tensors = {"input": text}
+ features = {"input_ids": tf.cast(ids, tf.int32, name="input_ids")}
+
+ return tf.estimator.export.ServingInputReceiver(
+ features=features, receiver_tensors=receiver_tensors)
+
+ return input_fn
+
+
+def model_fn_builder(bert_config):
+ """Returns `model_fn` closure for TPUEstimator."""
+
+ def model_fn(features, labels, mode, params): # pylint: disable=unused-argument
+ """The `model_fn` for TPUEstimator."""
+
+ if isinstance(features, dict):
+ if not labels and "labels" in features:
+ labels = features["labels"]
+ features = features["input_ids"]
+
+ is_training = (mode == tf.estimator.ModeKeys.TRAIN)
+
+ model = modeling.BertModel(bert_config)
+ headl = ClassifierLossLayer(
+ bert_config["num_labels"], bert_config["hidden_dropout_prob"],
+ utils.create_initializer(bert_config["initializer_range"]),
+ name=bert_config["scope"]+"/classifier")
+
+ _, pooled_output = model(features, training=is_training)
+ total_loss, log_probs = headl(pooled_output, labels, is_training)
+
+ tvars = tf.compat.v1.trainable_variables()
+ utils.log_variables(tvars, bert_config["ckpt_var_list"])
+
+ output_spec = None
+ if mode == tf.estimator.ModeKeys.TRAIN:
+
+ learning_rate = optimization.get_linear_warmup_linear_decay_lr(
+ init_lr=bert_config["learning_rate"],
+ num_train_steps=bert_config["num_train_steps"],
+ num_warmup_steps=bert_config["num_warmup_steps"])
+
+ optimizer = optimization.get_optimizer(bert_config, learning_rate)
+
+ global_step = tf.compat.v1.train.get_or_create_global_step()
+
+ gradients = optimizer.compute_gradients(total_loss, tvars)
+ train_op = optimizer.apply_gradients(gradients, global_step=global_step)
+
+ output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec(
+ mode=mode,
+ loss=total_loss,
+ train_op=train_op,
+ host_call=utils.add_scalars_to_summary(
+ bert_config["output_dir"], {"learning_rate": learning_rate}))
+
+ elif mode == tf.estimator.ModeKeys.EVAL:
+
+ def metric_fn(loss_value, label_ids, log_probs):
+ loss = tf.compat.v1.metrics.mean(values=loss_value)
+
+ predictions = tf.argmax(log_probs, axis=-1, output_type=tf.int32)
+ accuracy = tf.compat.v1.metrics.accuracy(
+ labels=label_ids, predictions=predictions)
+ p1, p1_op = tf.compat.v1.metrics.precision_at_k(
+ labels=tf.cast(label_ids, tf.int64), predictions=log_probs, k=1)
+ r1, r1_op = tf.compat.v1.metrics.recall_at_k(
+ labels=tf.cast(label_ids, tf.int64), predictions=log_probs, k=1)
+ f11 = tf.math.divide_no_nan(2*p1*r1, p1+r1)
+
+ metric_dict = {
+ "P@1": (p1, p1_op),
+ "R@1": (r1, r1_op),
+ "f1@1": (f11, tf.no_op()),
+ "classification_accuracy": accuracy,
+ "classification_loss": loss,
+ }
+
+ return metric_dict
+
+ eval_metrics = (metric_fn,
+ [tf.expand_dims(total_loss, 0), labels, log_probs])
+ output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec(
+ mode=mode,
+ loss=total_loss,
+ eval_metrics=eval_metrics)
+ else:
+ output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec(
+ mode=mode,
+ predictions={"log-probabilities": log_probs})
+
+ return output_spec
+
+ return model_fn
+
+
+class ClassifierLossLayer(tf.compat.v1.layers.Layer):
+ """Final classifier layer with loss."""
+
+ def __init__(self,
+ num_labels,
+ dropout_prob=0.0,
+ initializer=None,
+ use_bias=True,
+ name="classifier"):
+ super(ClassifierLossLayer, self).__init__(name=name)
+ self.num_labels = num_labels
+ self.initializer = initializer
+ self.dropout_prob = dropout_prob
+ self.use_bias = use_bias
+
+ self.w = None
+ self.b = None
+
+ def call(self, input_tensor, labels=None, training=None):
+ last_dim = utils.get_shape_list(input_tensor)[-1]
+ input_tensor = utils.dropout(input_tensor, self.dropout_prob, training)
+
+ if self.w is None:
+ self.w = tf.compat.v1.get_variable(
+ name="kernel",
+ shape=[last_dim, self.num_labels],
+ initializer=self.initializer)
+ self.initializer = None
+ self._trainable_weights.append(self.w)
+ logits = tf.matmul(input_tensor, self.w)
+
+ if self.use_bias:
+ if self.b is None:
+ self.b = tf.compat.v1.get_variable(
+ name="bias",
+ shape=[self.num_labels],
+ initializer=tf.zeros_initializer)
+ self._trainable_weights.append(self.b)
+ logits = tf.nn.bias_add(logits, self.b)
+
+ log_probs = tf.nn.log_softmax(logits, axis=-1)
+ if labels is not None:
+ one_hot_labels = tf.one_hot(labels, depth=self.num_labels,
+ dtype=tf.float32)
+ per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
+ loss = tf.reduce_mean(per_example_loss)
+ else:
+ loss = tf.constant(0.0)
+
+ return loss, log_probs
+
+
+def main(_):
+
+ if not FLAGS.do_train and not FLAGS.do_eval and not FLAGS.do_export:
+ raise ValueError(
+ "At least one of `do_train`, `do_eval` must be True.")
+
+ bert_config = flags.as_dictionary()
+
+ if FLAGS.max_encoder_length > bert_config["max_position_embeddings"]:
+ raise ValueError(
+ "Cannot use sequence length %d because the BERT model "
+ "was only trained up to sequence length %d" %
+ (FLAGS.max_encoder_length, bert_config["max_position_embeddings"]))
+
+ tf.io.gfile.makedirs(FLAGS.output_dir)
+ if FLAGS.do_train:
+ flags.save(os.path.join(FLAGS.output_dir, "classifier.config"))
+
+ model_fn = model_fn_builder(bert_config)
+ estimator = utils.get_estimator(bert_config, model_fn)
+
+ if FLAGS.do_train:
+ logging.info("***** Running training *****")
+ logging.info(" Batch size = %d", estimator.train_batch_size)
+ logging.info(" Num steps = %d", FLAGS.num_train_steps)
+ train_input_fn = input_fn_builder(
+ data_dir=FLAGS.data_dir,
+ vocab_model_file=FLAGS.vocab_model_file,
+ max_encoder_length=FLAGS.max_encoder_length,
+ substitute_newline=FLAGS.substitute_newline,
+ tmp_dir=os.path.join(FLAGS.output_dir, "tfds"),
+ is_training=True)
+ estimator.train(input_fn=train_input_fn, max_steps=FLAGS.num_train_steps)
+
+ if FLAGS.do_eval:
+ logging.info("***** Running evaluation *****")
+ logging.info(" Batch size = %d", estimator.eval_batch_size)
+
+ eval_input_fn = input_fn_builder(
+ data_dir=FLAGS.data_dir,
+ vocab_model_file=FLAGS.vocab_model_file,
+ max_encoder_length=FLAGS.max_encoder_length,
+ substitute_newline=FLAGS.substitute_newline,
+ tmp_dir=os.path.join(FLAGS.output_dir, "tfds"),
+ is_training=False)
+
+ if FLAGS.use_tpu:
+ with tf.compat.v1.Session() as sess:
+ eval_steps = eval_input_fn({
+ "batch_size": estimator.eval_batch_size
+ }).cardinality().eval(session=sess)
+ else:
+ eval_steps = None
+
+ # Run evaluation for each new checkpoint.
+ all_ckpts = [
+ v.split(".meta")[0] for v in tf.io.gfile.glob(
+ os.path.join(FLAGS.output_dir, "model.ckpt*.meta"))
+ ]
+ all_ckpts = natsorted(all_ckpts)
+ for ckpt in all_ckpts:
+ current_step = int(os.path.basename(ckpt).split("-")[1])
+ output_eval_file = os.path.join(
+ FLAGS.output_dir, "eval_results_{}.txt".format(current_step))
+ result = estimator.evaluate(input_fn=eval_input_fn,
+ checkpoint_path=ckpt,
+ steps=eval_steps)
+
+ with tf.io.gfile.GFile(output_eval_file, "w") as writer:
+ logging.info("***** Eval results *****")
+ for key in sorted(result.keys()):
+ logging.info(" %s = %s", key, str(result[key]))
+ writer.write("%s = %s\n" % (key, str(result[key])))
+
+ if FLAGS.do_export:
+ logging.info("***** Running export *****")
+
+ serving_input_fn = serving_input_fn_builder(
+ batch_size=FLAGS.eval_batch_size,
+ vocab_model_file=FLAGS.vocab_model_file,
+ max_encoder_length=FLAGS.max_encoder_length,
+ substitute_newline=FLAGS.substitute_newline)
+
+ estimator.export_saved_model(
+ os.path.join(FLAGS.output_dir, "export"), serving_input_fn)
+
+
+if __name__ == "__main__":
+ tf.compat.v1.disable_v2_behavior()
+ app.run(main)
diff --git a/bigbird/core/__init__.py b/bigbird/core/__init__.py
new file mode 100644
index 0000000..f6cd7c8
--- /dev/null
+++ b/bigbird/core/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2020 The BigBird Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
diff --git a/bigbird/core/attention.py b/bigbird/core/attention.py
new file mode 100644
index 0000000..59c9eac
--- /dev/null
+++ b/bigbird/core/attention.py
@@ -0,0 +1,1033 @@
+# Copyright 2020 The BigBird Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""BigBird Attention Layers."""
+
+from absl import logging
+from bigbird.core import utils
+import numpy as np
+import tensorflow.compat.v2 as tf
+
+
+MAX_SEQ_LEN = 4096
+
+
+def get_single_block_row_attention(block_id,
+ to_start_block_id,
+ to_end_block_id,
+ num_rand_blocks,
+ window_block_left=1,
+ window_block_right=1,
+ global_block_left=1,
+ global_block_right=1):
+ """For a single row block get random row attention.
+
+ Args:
+ block_id: int. block id of row.
+ to_start_block_id: int. random attention coloum start id.
+ to_end_block_id: int. random attention coloum end id.
+ num_rand_blocks: int. number of random blocks to be selected.
+ window_block_left: int. number of blocks of window to left of a block.
+ window_block_right: int. number of blocks of window to right of a block.
+ global_block_left: int. Number of blocks globally used to the left.
+ global_block_right: int. Number of blocks globally used to the right.
+
+ Returns:
+ row containing the random attention vector of size num_rand_blocks.
+ """
+
+ # list of to_blocks from which to choose random attention
+ to_block_list = np.arange(to_start_block_id, to_end_block_id,
+ dtype=np.int32)
+ # permute the blocks
+ perm_block = np.random.permutation(to_block_list)
+ # print(perm_block)
+
+ # illegal blocks for the current block id, using window
+ illegal_blocks = list(
+ range(block_id - window_block_left, block_id + window_block_right + 1))
+
+ # Add blocks at the start and at the end
+ illegal_blocks.extend(list(range(global_block_left)))
+ illegal_blocks.extend(
+ list(range(to_end_block_id - global_block_right, to_end_block_id)))
+
+ # The second from_block cannot choose random attention on second last to_block
+ if block_id == 1:
+ illegal_blocks.append(to_end_block_id-2)
+
+ # The second last from_block cannot choose random attention on second to_block
+ if block_id == to_end_block_id - 2:
+ illegal_blocks.append(1)
+
+ selected_random_blokcs = []
+
+ for i in range(to_end_block_id - to_start_block_id):
+ if perm_block[i] not in illegal_blocks:
+ selected_random_blokcs.append(perm_block[i])
+ if len(selected_random_blokcs) == num_rand_blocks:
+ break
+ return np.array(selected_random_blokcs, dtype=np.int32)
+
+
+def bigbird_block_rand_mask_with_head(from_seq_length,
+ to_seq_length,
+ from_block_size,
+ to_block_size,
+ num_heads,
+ plan_from_length,
+ plan_num_rand_blocks,
+ window_block_left=1,
+ window_block_right=1,
+ global_block_top=1,
+ global_block_bottom=1,
+ global_block_left=1,
+ global_block_right=1):
+ """Create adjacency list of random attention.
+
+ Args:
+ from_seq_length: int. length of from sequence.
+ to_seq_length: int. length of to sequence.
+ from_block_size: int. size of block in from sequence.
+ to_block_size: int. size of block in to sequence.
+ num_heads: int. total number of heads.
+ plan_from_length: list. plan from lenght where num_rand are choosen from.
+ plan_num_rand_blocks: list. number of rand blocks within the plan.
+ window_block_left: int. number of blocks of window to left of a block.
+ window_block_right: int. number of blocks of window to right of a block.
+ global_block_top: int. number of blocks at the top.
+ global_block_bottom: int. number of blocks at the bottom.
+ global_block_left: int. Number of blocks globally used to the left.
+ global_block_right: int. Number of blocks globally used to the right.
+
+ Returns:
+ adjacency list of size num_head where each element is of size
+ from_seq_length//from_block_size-2 by num_rand_blocks
+ """
+ assert from_seq_length//from_block_size == to_seq_length//to_block_size, \
+ "Error the number of blocks needs to be same!"
+
+ assert from_seq_length in plan_from_length, \
+ "Error from sequence length not in plan!"
+
+ # Total number of blocks in the mmask
+ num_blocks = from_seq_length//from_block_size
+ # Number of blocks per plan
+ plan_block_length = np.array(plan_from_length) // from_block_size
+ # till when to follow plan
+ max_plan_idx = plan_from_length.index(from_seq_length)
+ # Random Attention adjajency list
+ rand_attn = [np.zeros((num_blocks,
+ np.sum(plan_num_rand_blocks[:max_plan_idx+1])),
+ dtype=np.int32) for i in range(num_heads)]
+
+ # We will go iteratively over the plan blocks and pick random number of
+ # Attention blocks from the legally allowed blocks
+ for plan_idx in range(max_plan_idx+1):
+ rnd_r_cnt = 0
+ if plan_idx > 0:
+ # set the row for all from_blocks starting from 0 to
+ # plan_block_length[plan_idx-1]
+ # column indx start fromm plan_block_length[plan_idx-1] and ends at
+ # plan_block_length[plan_idx]
+ if plan_num_rand_blocks[plan_idx] > 0:
+ rnd_r_cnt = int(np.sum(plan_num_rand_blocks[:plan_idx]))
+ curr_r_cnt = int(np.sum(plan_num_rand_blocks[:plan_idx+1]))
+ for blk_rw_idx in range(global_block_top,
+ plan_block_length[plan_idx-1]):
+ for h in range(num_heads):
+ # print("head", h, "blk_rw_idx", blk_rw_idx)
+ rand_attn[h][blk_rw_idx,
+ rnd_r_cnt:curr_r_cnt] = get_single_block_row_attention(
+ block_id=blk_rw_idx,
+ to_start_block_id=plan_block_length[plan_idx - 1],
+ to_end_block_id=plan_block_length[plan_idx],
+ num_rand_blocks=plan_num_rand_blocks[plan_idx],
+ window_block_left=window_block_left,
+ window_block_right=window_block_right,
+ global_block_left=global_block_left,
+ global_block_right=global_block_right)
+
+ for pl_id in range(plan_idx):
+ if plan_num_rand_blocks[pl_id] == 0:
+ continue
+ for blk_rw_idx in range(plan_block_length[plan_idx-1],
+ plan_block_length[plan_idx]):
+ rnd_r_cnt = 0
+ to_start_block_id = 0
+ if pl_id > 0:
+ rnd_r_cnt = int(np.sum(plan_num_rand_blocks[:pl_id]))
+ to_start_block_id = plan_block_length[pl_id-1]
+ curr_r_cnt = int(np.sum(plan_num_rand_blocks[:pl_id+1]))
+ for h in range(num_heads):
+ # print("head", h, "blk_rw_idx", blk_rw_idx)
+ rand_attn[h][blk_rw_idx,
+ rnd_r_cnt:curr_r_cnt] = get_single_block_row_attention(
+ block_id=blk_rw_idx,
+ to_start_block_id=to_start_block_id,
+ to_end_block_id=plan_block_length[pl_id],
+ num_rand_blocks=plan_num_rand_blocks[pl_id],
+ window_block_left=window_block_left,
+ window_block_right=window_block_right,
+ global_block_left=global_block_left,
+ global_block_right=global_block_right)
+
+ if plan_num_rand_blocks[plan_idx] == 0:
+ continue
+ # print("Start from here")
+ curr_r_cnt = int(np.sum(plan_num_rand_blocks[:plan_idx+1]))
+ from_start_block_id = global_block_top
+ to_start_block_id = 0
+ if plan_idx > 0:
+ rnd_r_cnt = int(np.sum(plan_num_rand_blocks[:plan_idx]))
+ from_start_block_id = plan_block_length[plan_idx-1]
+ to_start_block_id = plan_block_length[plan_idx-1]
+
+ for blk_rw_idx in range(from_start_block_id, plan_block_length[plan_idx]):
+ for h in range(num_heads):
+ # print("head", h, "blk_rw_idx", blk_rw_idx)
+ rand_attn[h][blk_rw_idx,
+ rnd_r_cnt:curr_r_cnt] = get_single_block_row_attention(
+ block_id=blk_rw_idx,
+ to_start_block_id=to_start_block_id,
+ to_end_block_id=plan_block_length[plan_idx],
+ num_rand_blocks=plan_num_rand_blocks[plan_idx],
+ window_block_left=window_block_left,
+ window_block_right=window_block_right,
+ global_block_left=global_block_left,
+ global_block_right=global_block_right)
+
+ for nh in range(num_heads):
+ rand_attn[nh] = rand_attn[nh][global_block_top:num_blocks -
+ global_block_bottom, :]
+ return rand_attn
+
+
+def get_rand_attn_plan(from_seq_length, from_block_size, num_rand_blocks):
+ """Gives the plan of where to put random attention.
+
+ Args:
+ from_seq_length: int. length of from sequence.
+ from_block_size: int. size of block in from sequence.
+ num_rand_blocks: int. Number of random chunks per row.
+
+ Returns:
+ plan_from_length: ending location of from block
+ plan_num_rand_blocks: number of random ending location for each block
+ """
+ # general plan
+ plan_from_length = []
+ plan_num_rand_blocks = []
+ if (2*num_rand_blocks + 5) < (from_seq_length // from_block_size):
+ plan_from_length.append(int((2*num_rand_blocks + 5)*from_block_size))
+ plan_num_rand_blocks.append(num_rand_blocks)
+ plan_from_length.append(from_seq_length)
+ plan_num_rand_blocks.append(0)
+ elif (num_rand_blocks + 5) < (from_seq_length // from_block_size):
+ plan_from_length.append(int((num_rand_blocks + 5)*from_block_size))
+ plan_num_rand_blocks.append(num_rand_blocks//2)
+ plan_from_length.append(from_seq_length)
+ plan_num_rand_blocks.append(num_rand_blocks - (num_rand_blocks//2))
+ else:
+ plan_from_length.append(from_seq_length)
+ plan_num_rand_blocks.append(num_rand_blocks)
+
+ return plan_from_length, plan_num_rand_blocks
+
+
+def bigbird_block_rand_mask(from_seq_length,
+ to_seq_length,
+ from_block_size,
+ to_block_size,
+ num_rand_blocks,
+ last_idx=-1):
+ """Create adjacency list of random attention.
+
+ Args:
+ from_seq_length: int. length of from sequence.
+ to_seq_length: int. length of to sequence.
+ from_block_size: int. size of block in from sequence.
+ to_block_size: int. size of block in to sequence.
+ num_rand_blocks: int. Number of random chunks per row.
+ last_idx: if -1 then num_rand_blocks blocks chosen anywhere in to sequence,
+ if positive then num_rand_blocks blocks choosen only upto last_idx.
+
+ Returns:
+ adjacency list of size from_seq_length//from_block_size-2 by num_rand_blocks
+ """
+ assert from_seq_length//from_block_size == to_seq_length//to_block_size, \
+ "Error the number of blocks needs to be same!"
+
+ rand_attn = np.zeros(
+ (from_seq_length // from_block_size - 2, num_rand_blocks), dtype=np.int32)
+ middle_seq = np.arange(1, to_seq_length // to_block_size - 1, dtype=np.int32)
+ last = to_seq_length // to_block_size - 1
+ if last_idx > (2 * to_block_size):
+ last = (last_idx // to_block_size) - 1
+
+ r = num_rand_blocks # shorthand
+ for i in range(1, from_seq_length // from_block_size-1):
+ start = i-2
+ end = i
+ if i == 1:
+ rand_attn[i-1, :] = np.random.permutation(middle_seq[2:last])[:r]
+ elif i == 2:
+ rand_attn[i-1, :] = np.random.permutation(middle_seq[3:last])[:r]
+ elif i == from_seq_length // from_block_size - 3:
+ rand_attn[i-1, :] = np.random.permutation(middle_seq[:last])[:r]
+ # Missing -3: should have been sliced till last-3
+ elif i == from_seq_length // from_block_size - 2:
+ rand_attn[i-1, :] = np.random.permutation(middle_seq[:last])[:r]
+ # Missing -4: should have been sliced till last-4
+ else:
+ if start > last:
+ start = last
+ rand_attn[i-1, :] = np.random.permutation(middle_seq[:start])[:r]
+ elif (end+1) == last:
+ rand_attn[i-1, :] = np.random.permutation(middle_seq[:start])[:r]
+ else:
+ rand_attn[i-1, :] = np.random.permutation(
+ np.concatenate((middle_seq[:start], middle_seq[end+1:last])))[:r]
+ return rand_attn
+
+
+def full_bigbird_mask(from_seq_length,
+ to_seq_length,
+ from_block_size,
+ to_block_size,
+ num_rand_blocks,
+ rand_attn=None,
+ focus=1024):
+ """Calculate BigBird attention pattern as a full dense matrix.
+
+ Args:
+ from_seq_length: int. length of from sequence.
+ to_seq_length: int. length of to sequence.
+ from_block_size: int. size of block in from sequence.
+ to_block_size: int. size of block in to sequence.
+ num_rand_blocks: int. Number of random chunks per row.
+ rand_attn: adjajency matrix for random attention.
+ focus: pick random mask within focus
+
+ Returns:
+ attention mask matrix of shape [from_seq_length, to_seq_length]
+ """
+ if rand_attn is None:
+ rand_attn = bigbird_block_rand_mask(MAX_SEQ_LEN, MAX_SEQ_LEN,
+ from_block_size, to_block_size,
+ num_rand_blocks, focus)
+
+ attn_mask = np.zeros((MAX_SEQ_LEN, MAX_SEQ_LEN), dtype=np.int32)
+ for i in range(1, (MAX_SEQ_LEN // from_block_size) - 1):
+ attn_mask[(i) * from_block_size:(i + 1) * from_block_size,
+ (i - 1) * to_block_size:(i + 2) * to_block_size] = 1
+ for j in rand_attn[i - 1, :]:
+ attn_mask[i * from_block_size:(i + 1) * from_block_size,
+ j * to_block_size:(j + 1) * to_block_size] = 1
+
+ attn_mask[:from_block_size, :] = 1
+ attn_mask[:, :to_block_size] = 1
+ attn_mask[:, -to_block_size:] = 1
+ attn_mask[-from_block_size:, :] = 1
+ clipped_attn_mask = attn_mask[:from_seq_length, :to_seq_length]
+ return np.array(clipped_attn_mask, dtype=bool)
+
+
+def create_rand_mask_from_inputs(from_blocked_mask,
+ to_blocked_mask,
+ rand_attn,
+ num_attention_heads,
+ num_rand_blocks,
+ batch_size,
+ from_seq_length,
+ from_block_size):
+ """Create 3D attention mask from a 2D tensor mask.
+
+ Args:
+ from_blocked_mask: 2D Tensor of shape [batch_size,
+ from_seq_length//from_block_size, from_block_size].
+ to_blocked_mask: int32 Tensor of shape [batch_size,
+ to_seq_length//to_block_size, to_block_size].
+ rand_attn: [batch_size, num_attention_heads,
+ from_seq_length//from_block_size-2, num_rand_blocks]
+ num_attention_heads: int. Number of attention heads.
+ num_rand_blocks: int. Number of random chunks per row.
+ batch_size: int. Batch size for computation.
+ from_seq_length: int. length of from sequence.
+ from_block_size: int. size of block in from sequence.
+
+ Returns:
+ float Tensor of shape [batch_size, num_attention_heads,
+ from_seq_length//from_block_size-2,
+ from_block_size, num_rand_blocks*to_block_size].
+ """
+ num_windows = from_seq_length // from_block_size - 2
+ rand_mask = tf.reshape(
+ tf.gather(to_blocked_mask, rand_attn, batch_dims=1), [
+ batch_size, num_attention_heads, num_windows,
+ num_rand_blocks * from_block_size
+ ])
+ rand_mask = tf.einsum("BLQ,BHLK->BHLQK", from_blocked_mask[:, 1:-1],
+ rand_mask)
+ return rand_mask
+
+
+def create_band_mask_from_inputs(from_blocked_mask, to_blocked_mask):
+ """Create 3D attention mask from a 2D tensor mask.
+
+ Args:
+ from_blocked_mask: 2D Tensor of shape [batch_size,
+ from_seq_length//from_block_size, from_block_size].
+ to_blocked_mask: int32 Tensor of shape [batch_size,
+ to_seq_length//to_block_size, to_block_size].
+
+ Returns:
+ float Tensor of shape [batch_size, 1, from_seq_length//from_block_size-4,
+ from_block_size, 3*to_block_size].
+ """
+ exp_blocked_to_pad = tf.concat(
+ [to_blocked_mask[:, 1:-3], to_blocked_mask[:, 2:-2],
+ to_blocked_mask[:, 3:-1]], 2)
+ band_mask = tf.einsum("BLQ,BLK->BLQK",
+ tf.cast(from_blocked_mask[:, 2:-2], tf.float32),
+ tf.cast(exp_blocked_to_pad, tf.float32))
+ band_mask = tf.expand_dims(band_mask, 1)
+ return band_mask
+
+
+def create_attention_mask_from_input_mask(from_mask, to_mask):
+ """Create attention mask from a 2D tensor mask.
+
+ Args:
+ from_mask: int32 Tensor of shape [batch_size, from_seq_length].
+ to_mask: int32 Tensor of shape [batch_size, to_seq_length].
+
+ Returns:
+ int32 Tensor of shape [batch_size, 1, from_seq_length, to_seq_length].
+ """
+ mask = tf.einsum("BF,BT->BFT", from_mask, to_mask)
+
+ # expand to create a slot for heads.
+ mask = tf.expand_dims(mask, 1)
+
+ return mask
+
+
+def original_full_attention(query_layer,
+ key_layer,
+ value_layer,
+ attention_mask,
+ size_per_head,
+ attention_probs_dropout_prob):
+ """Full quadratic attention calculation.
+
+ Args:
+ query_layer: float Tensor of shape [batch_size, num_attention_heads,
+ from_seq_length, size_per_head]
+ key_layer: float Tensor of shape [batch_size, num_attention_heads,
+ to_seq_length, size_per_head]
+ value_layer: float Tensor of shape [batch_size, num_attention_heads,
+ to_seq_length, size_per_head]
+ attention_mask: (optional) int32 Tensor of shape [batch_size,
+ from_seq_length, to_seq_length]. The values should be 1 or 0. The
+ attention scores will effectively be set to -infinity for any positions in
+ the mask that are 0, and will be unchanged for positions that are 1.
+ size_per_head: (optional) int. Size of each attention head.
+ attention_probs_dropout_prob: (optional) float. Dropout probability of the
+ attention probabilities.
+
+ Returns:
+ float Tensor of shape [batch_size, from_seq_length, num_attention_heads,
+ size_per_head].
+ """
+
+ # Directly take n^2 dot product between "query" and "key".
+ attention_scores = tf.einsum("BNFH,BNTH->BNFT", query_layer, key_layer)
+ attention_scores = tf.multiply(attention_scores,
+ 1.0 / np.sqrt(float(size_per_head)))
+
+ if attention_mask is not None:
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+ # masked positions, this operation will create a tensor which is 0.0 for
+ # positions we want to attend and -10000.0 for masked positions.
+ adder = (1.0 - tf.cast(attention_mask, tf.float32)) * -10000.0
+
+ # Since we are adding it to the raw scores before the softmax, this is
+ # effectively the same as removing these entirely.
+ attention_scores += adder
+
+ # Normalize the attention scores to probabilities.
+ # `attention_probs` = [B, N, F, T]
+ attention_probs = tf.nn.softmax(attention_scores)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_probs = utils.dropout(attention_probs, attention_probs_dropout_prob)
+
+ # `context_layer` = [B, F, N, H]
+ context_layer = tf.einsum("BNFT,BNTH->BFNH", attention_probs, value_layer)
+ return context_layer
+
+
+def bigbird_simulated_attention(query_layer,
+ key_layer,
+ value_layer,
+ attention_mask,
+ num_attention_heads,
+ num_rand_blocks,
+ size_per_head,
+ from_seq_length,
+ to_seq_length,
+ from_block_size,
+ to_block_size,
+ seed=None):
+ """BigBird attention calculation using masks in quadratic time.
+
+ Args:
+ query_layer: float Tensor of shape [batch_size, num_attention_heads,
+ from_seq_length, size_per_head]
+ key_layer: float Tensor of shape [batch_size, num_attention_heads,
+ to_seq_length, size_per_head]
+ value_layer: float Tensor of shape [batch_size, num_attention_heads,
+ to_seq_length, size_per_head]
+ attention_mask: int32 Tensor of shape [batch_size,
+ from_seq_length, to_seq_length]. The values should be 1 or 0. The
+ attention scores will effectively be set to -infinity for any positions in
+ the mask that are 0, and will be unchanged for positions that are 1.
+ num_attention_heads: int. Number of attention heads.
+ num_rand_blocks: int. Number of random chunks per row.
+ size_per_head: int. Size of each attention head.
+ from_seq_length: int. length of from sequence.
+ to_seq_length: int. length of to sequence.
+ from_block_size: int. size of block in from sequence.
+ to_block_size: int. size of block in to sequence.
+ seed: (Optional) int. Reandom seed for generating random mask.
+
+ Returns:
+ float Tensor of shape [batch_size, from_seq_length, num_attention_heads,
+ size_per_head].
+ """
+
+ if seed:
+ np.random.seed(seed)
+
+ plan_from_length, plan_num_rand_blocks = get_rand_attn_plan(
+ from_seq_length, from_block_size, num_rand_blocks)
+
+ rand_attn = bigbird_block_rand_mask_with_head(
+ from_seq_length=from_seq_length,
+ to_seq_length=to_seq_length,
+ from_block_size=from_block_size,
+ to_block_size=to_block_size,
+ num_heads=num_attention_heads,
+ plan_from_length=plan_from_length,
+ plan_num_rand_blocks=plan_num_rand_blocks)
+ temp_mask = [
+ full_bigbird_mask( # pylint: disable=g-complex-comprehension
+ from_seq_length, to_seq_length, from_block_size, to_block_size,
+ num_rand_blocks, rand_attn=rand_attn[i], focus=1024)
+ for i in range(num_attention_heads)
+ ]
+ temp_mask = np.stack(temp_mask, axis=0)
+ temp_mask = np.array(temp_mask, dtype=bool)
+
+ rand_block_mask = tf.constant(temp_mask, dtype=tf.bool) # [N, F, T]
+ rand_block_mask = tf.cast(rand_block_mask, tf.int32)
+ rand_block_mask = tf.expand_dims(rand_block_mask, 0) # [1, N, F, T]
+ if attention_mask is not None:
+ attention_mask = tf.minimum(attention_mask, rand_block_mask)
+ else:
+ attention_mask = rand_block_mask
+ return original_full_attention(query_layer,
+ key_layer,
+ value_layer,
+ attention_mask,
+ size_per_head,
+ attention_probs_dropout_prob=0.0)
+
+
+def bigbird_block_sparse_attention(query_layer,
+ key_layer,
+ value_layer,
+ band_mask,
+ from_mask,
+ to_mask,
+ from_blocked_mask,
+ to_blocked_mask,
+ num_attention_heads,
+ num_rand_blocks,
+ size_per_head,
+ batch_size,
+ from_seq_length,
+ to_seq_length,
+ from_block_size,
+ to_block_size,
+ seed=None,
+ plan_from_length=None,
+ plan_num_rand_blocks=None):
+ """BigBird attention sparse calculation using blocks in linear time.
+
+ Assumes from_seq_length//from_block_size == to_seq_length//to_block_size.
+
+
+ Args:
+ query_layer: float Tensor of shape [batch_size, num_attention_heads,
+ from_seq_length, size_per_head]
+ key_layer: float Tensor of shape [batch_size, num_attention_heads,
+ to_seq_length, size_per_head]
+ value_layer: float Tensor of shape [batch_size, num_attention_heads,
+ to_seq_length, size_per_head]
+ band_mask: (optional) int32 Tensor of shape [batch_size, 1,
+ from_seq_length//from_block_size-4, from_block_size, 3*to_block_size].
+ The values should be 1 or 0. The attention scores will effectively be
+ set to -infinity for any positions in the mask that are 0, and will be
+ unchanged for positions that are 1.
+ from_mask: (optional) int32 Tensor of shape [batch_size, 1,
+ from_seq_length, 1]. The values should be 1 or 0. The
+ attention scores will effectively be set to -infinity for any positions in
+ the mask that are 0, and will be unchanged for positions that are 1.
+ to_mask: (optional) int32 Tensor of shape [batch_size, 1, 1,
+ to_seq_length]. The values should be 1 or 0. The
+ attention scores will effectively be set to -infinity for any positions in
+ the mask that are 0, and will be unchanged for positions that are 1.
+ from_blocked_mask: (optional) int32 Tensor of shape [batch_size,
+ from_seq_length//from_block_size, from_block_size].
+ Same as from_mask, just reshaped.
+ to_blocked_mask: (optional) int32 Tensor of shape [batch_size,
+ to_seq_length//to_block_size, to_block_size].
+ Same as to_mask, just reshaped.
+ num_attention_heads: int. Number of attention heads.
+ num_rand_blocks: int. Number of random chunks per row.
+ size_per_head: int. Size of each attention head.
+ batch_size: int. Batch size for computation.
+ from_seq_length: int. length of from sequence.
+ to_seq_length: int. length of to sequence.
+ from_block_size: int. size of block in from sequence.
+ to_block_size: int. size of block in to sequence.
+ seed: (Optional) int. Reandom seed for generating random mask.
+ plan_from_length: (Optional) list. Plan of where to put random attn. It
+ divides the block matrix into chuncks, where each chunck will have
+ some randomm attn.
+ plan_num_rand_blocks: (Optional) list. Number of random per block given by
+ plan_from_length.
+
+ Returns:
+ float Tensor of shape [batch_size, from_seq_length, num_attention_heads,
+ size_per_head].
+ """
+ assert from_seq_length//from_block_size == to_seq_length//to_block_size
+
+ # cast masks to float
+ from_mask = tf.cast(from_mask, tf.float32)
+ to_mask = tf.cast(to_mask, tf.float32)
+ band_mask = tf.cast(band_mask, tf.float32)
+ from_blocked_mask = tf.cast(from_blocked_mask, tf.float32)
+ to_blocked_mask = tf.cast(to_blocked_mask, tf.float32)
+
+ # generate random attention and corresponding masks
+ np.random.seed(seed)
+ if from_seq_length in [1024, 3072, 4096]: # old plans used in paper
+ rand_attn = [
+ bigbird_block_rand_mask( # pylint: disable=g-complex-comprehension
+ MAX_SEQ_LEN, MAX_SEQ_LEN,
+ from_block_size, to_block_size, num_rand_blocks,
+ last_idx=1024)[:(from_seq_length // from_block_size - 2)]
+ for _ in range(num_attention_heads)
+ ]
+ else:
+ if plan_from_length is None:
+ plan_from_length, plan_num_rand_blocks = get_rand_attn_plan(
+ from_seq_length, from_block_size, num_rand_blocks)
+
+ rand_attn = bigbird_block_rand_mask_with_head(
+ from_seq_length=from_seq_length,
+ to_seq_length=to_seq_length,
+ from_block_size=from_block_size,
+ to_block_size=to_block_size,
+ num_heads=num_attention_heads,
+ plan_from_length=plan_from_length,
+ plan_num_rand_blocks=plan_num_rand_blocks)
+ rand_attn = np.stack(rand_attn, axis=0)
+ rand_attn = tf.constant(rand_attn, dtype=tf.int32)
+ rand_attn = tf.expand_dims(rand_attn, 0)
+ rand_attn = tf.repeat(rand_attn, batch_size, 0)
+
+ rand_mask = create_rand_mask_from_inputs(
+ from_blocked_mask, to_blocked_mask, rand_attn,
+ num_attention_heads, num_rand_blocks,
+ batch_size, from_seq_length, from_block_size,)
+
+ # Define shorthands
+ h = num_attention_heads
+ r = num_rand_blocks
+ d = size_per_head
+ b = batch_size
+ m = from_seq_length
+ n = to_seq_length
+ wm = from_block_size
+ wn = to_block_size
+
+ blocked_query_matrix = tf.reshape(query_layer, (b, h, m // wm, wm, -1))
+ blocked_key_matrix = tf.reshape(key_layer, (b, h, n // wn, wn, -1))
+ blocked_value_matrix = tf.reshape(value_layer, (b, h, n // wn, wn, -1))
+ gathered_key = tf.reshape(
+ tf.gather(blocked_key_matrix, rand_attn, batch_dims=2, name="gather_key"),
+ (b, h, m // wm - 2, r * wn, -1)) # [b, h, n//wn-2, r, wn, -1]
+ gathered_value = tf.reshape(
+ tf.gather(
+ blocked_value_matrix, rand_attn, batch_dims=2, name="gather_value"),
+ (b, h, m // wm - 2, r * wn, -1)) # [b, h, n//wn-2, r, wn, -1]
+
+ first_product = tf.einsum(
+ "BHQD,BHKD->BHQK", blocked_query_matrix[:, :, 0],
+ key_layer) # [b, h, wm, -1] x [b, h, n, -1] ==> [b, h, wm, n]
+ first_product = tf.multiply(first_product, 1.0 / np.sqrt(d))
+ first_product += (1.0 - to_mask) * -10000.0
+ first_attn_weights = tf.nn.softmax(first_product) # [b, h, wm, n]
+ first_context_layer = tf.einsum(
+ "BHQK,BHKD->BHQD", first_attn_weights,
+ value_layer) # [b, h, wm, n] x [b, h, n, -1] ==> [b, h, wm, -1]
+ first_context_layer = tf.expand_dims(first_context_layer, 2)
+
+ second_key_mat = tf.concat([
+ blocked_key_matrix[:, :, 0], blocked_key_matrix[:, :, 1],
+ blocked_key_matrix[:, :, 2], blocked_key_matrix[:, :, -1],
+ gathered_key[:, :, 0]], 2) # [b, h, (4+r)*wn, -1]
+ second_value_mat = tf.concat([
+ blocked_value_matrix[:, :, 0], blocked_value_matrix[:, :, 1],
+ blocked_value_matrix[:, :, 2], blocked_value_matrix[:, :, -1],
+ gathered_value[:, :, 0]], 2) # [b, h, (4+r)*wn, -1]
+ second_product = tf.einsum(
+ "BHQD,BHKD->BHQK", blocked_query_matrix[:, :, 1], second_key_mat
+ ) # [b, h, wm, -1] x [b, h, (4+r)*wn, -1] ==> [b, h, wm, (4+r)*wn]
+ second_seq_pad = tf.concat([
+ to_mask[:, :, :, :3 * wn], to_mask[:, :, :, -wn:],
+ tf.ones([b, 1, 1, r * wn], dtype=tf.float32)], 3)
+ second_rand_pad = tf.concat(
+ [tf.ones([b, h, wm, 4 * wn], dtype=tf.float32), rand_mask[:, :, 0]], 3)
+ second_product = tf.multiply(second_product, 1.0 / np.sqrt(d))
+ second_product += (1.0 -
+ tf.minimum(second_seq_pad, second_rand_pad)) * -10000.0
+ second_attn_weights = tf.nn.softmax(second_product) # [b , h, wm, (4+r)*wn]
+ second_context_layer = tf.einsum(
+ "BHQK,BHKD->BHQD", second_attn_weights, second_value_mat
+ ) # [b, h, wm, (4+r)*wn] x [b, h, (4+r)*wn, -1] ==> [b, h, wm, -1]
+ second_context_layer = tf.expand_dims(second_context_layer, 2)
+
+ exp_blocked_key_matrix = tf.concat([
+ blocked_key_matrix[:, :, 1:-3], blocked_key_matrix[:, :, 2:-2],
+ blocked_key_matrix[:, :, 3:-1]], 3) # [b, h, m//wm-4, 3*wn, -1]
+ exp_blocked_value_matrix = tf.concat([
+ blocked_value_matrix[:, :, 1:-3], blocked_value_matrix[:, :, 2:-2],
+ blocked_value_matrix[:, :, 3:-1]], 3) # [b, h, m//wm-4, 3*wn, -1]
+ middle_query_matrix = blocked_query_matrix[:, :, 2:-2]
+ inner_band_product = tf.einsum(
+ "BHLQD,BHLKD->BHLQK", middle_query_matrix, exp_blocked_key_matrix
+ ) # [b, h, m//wm-4, wm, -1] x [b, h, m//wm-4, 3*wn, -1]
+ # ==> [b, h, m//wm-4, wm, 3*wn]
+ inner_band_product = tf.multiply(inner_band_product, 1.0 / np.sqrt(d))
+ rand_band_product = tf.einsum(
+ "BHLQD,BHLKD->BHLQK", middle_query_matrix, gathered_key[:, :, 1:-1]
+ ) # [b, h, m//wm-4, wm, -1] x [b, h, m//wm-4, r*wn, -1]
+ # ==> [b, h, m//wm-4, wm, r*wn]
+ rand_band_product = tf.multiply(rand_band_product, 1.0 / np.sqrt(d))
+ first_band_product = tf.einsum(
+ "BHLQD,BHKD->BHLQK", middle_query_matrix, blocked_key_matrix[:, :, 0]
+ ) # [b, h, m//wm-4, wm, -1] x [b, h, wn, -1] ==> [b, h, m//wm-4, wm, wn]
+ first_band_product = tf.multiply(first_band_product, 1.0 / np.sqrt(d))
+ last_band_product = tf.einsum(
+ "BHLQD,BHKD->BHLQK", middle_query_matrix, blocked_key_matrix[:, :, -1]
+ ) # [b, h, m//wm-4, wm, -1] x [b, h, wn, -1] ==> [b, h, m//wm-4, wm, wn]
+ last_band_product = tf.multiply(last_band_product, 1.0 / np.sqrt(d))
+ inner_band_product += (1.0 - band_mask) * -10000.0
+ first_band_product += (
+ 1.0 - tf.expand_dims(to_mask[:, :, :, :wn], 3)) * -10000.0
+ last_band_product += (
+ 1.0 - tf.expand_dims(to_mask[:, :, :, -wn:], 3)) * -10000.0
+ rand_band_product += (1.0 - rand_mask[:, :, 1:-1]) * -10000.0
+ band_product = tf.concat([
+ first_band_product, inner_band_product, rand_band_product,
+ last_band_product], -1) # [b, h, m//wm-4, wm, (5+r)*wn]
+ attn_weights = tf.nn.softmax(band_product) # [b, h, m//wm-4, wm, (5+r)*wn]
+ context_layer = tf.einsum(
+ "BHLQK,BHLKD->BHLQD", attn_weights[:, :, :, :, wn:4 * wn],
+ exp_blocked_value_matrix
+ ) # [b, h, m//wm-4, wm, 3*wn] x [b, h, m//wm-4, 3*wn, -1]
+ # ==> [b, h, m//wm-4, wm, -1]
+ context_layer += tf.einsum(
+ "BHLQK,BHLKD->BHLQD", attn_weights[:, :, :, :, 4 * wn:-wn],
+ gathered_value[:, :, 1:-1]
+ ) # [b, h, m//wm-4, wm, r*wn] x [b, h, m//wm-4, r*wn, -1]
+ # ==> [b, h, m//wm-4, wm, -1]
+ context_layer += tf.einsum(
+ "BHLQK,BHKD->BHLQD", attn_weights[:, :, :, :, :wn],
+ blocked_value_matrix[:, :, 0]
+ ) # [b, h, m//wm-4, wm, wn] x [b, h, wn, -1] ==> [b, h, m//wm-4, wm, -1]
+ context_layer += tf.einsum(
+ "BHLQK,BHKD->BHLQD", attn_weights[:, :, :, :, -wn:],
+ blocked_value_matrix[:, :, -1]
+ ) # [b, h, m//wm-4, wm, wn] x [b, h, wn, -1] ==> [b, h, m//wm-4, wm, -1]
+
+ second_last_key_mat = tf.concat([
+ blocked_key_matrix[:, :, 0], blocked_key_matrix[:, :, -3],
+ blocked_key_matrix[:, :, -2], blocked_key_matrix[:, :, -1],
+ gathered_key[:, :, -1]], 2) # [b, h, (4+r)*wn, -1]
+ second_last_value_mat = tf.concat([
+ blocked_value_matrix[:, :, 0], blocked_value_matrix[:, :, -3],
+ blocked_value_matrix[:, :, -2], blocked_value_matrix[:, :, -1],
+ gathered_value[:, :, -1]], 2) # [b, h, (4+r)*wn, -1]
+ second_last_product = tf.einsum(
+ "BHQD,BHKD->BHQK", blocked_query_matrix[:, :, -2], second_last_key_mat
+ ) # [b, h, wm, -1] x [b, h, (4+r)*wn, -1] ==> [b, h, wm, (4+r)*wn]
+ second_last_seq_pad = tf.concat([
+ to_mask[:, :, :, :wn], to_mask[:, :, :, -3 * wn:],
+ tf.ones([b, 1, 1, r * wn], dtype=tf.float32)], 3)
+ second_last_rand_pad = tf.concat(
+ [tf.ones([b, h, wm, 4 * wn], dtype=tf.float32), rand_mask[:, :, -1]], 3)
+ second_last_product = tf.multiply(second_last_product, 1.0 / np.sqrt(d))
+ second_last_product += (
+ 1.0 - tf.minimum(second_last_seq_pad, second_last_rand_pad)) * -10000.0
+ second_last_attn_weights = tf.nn.softmax(
+ second_last_product) # [b, h, wm, (4+r)*wn]
+ second_last_context_layer = tf.einsum(
+ "BHQK,BHKD->BHQD", second_last_attn_weights, second_last_value_mat
+ ) # [b, h, wm, (4+r)*wn] x [b, h, (4+r)*wn, -1] ==> [b, h, wm, -1]
+ second_last_context_layer = tf.expand_dims(second_last_context_layer, 2)
+
+ last_product = tf.einsum(
+ "BHQD,BHKD->BHQK", blocked_query_matrix[:, :, -1],
+ key_layer) # [b, h, wm, -1] x [b, h, n, -1] ==> [b, h, wm, n]
+ last_product = tf.multiply(last_product, 1.0 / np.sqrt(d))
+ last_product += (1.0 - to_mask) * -10000.0
+ last_attn_weights = tf.nn.softmax(last_product) # [b, h, wm, n]
+ last_context_layer = tf.einsum(
+ "BHQK,BHKD->BHQD", last_attn_weights,
+ value_layer) # [b, h, wm, n] x [b, h, n, -1] ==> [b, h, wm, -1]
+ last_context_layer = tf.expand_dims(last_context_layer, 2)
+
+ context_layer = tf.concat([
+ first_context_layer, second_context_layer, context_layer,
+ second_last_context_layer, last_context_layer
+ ], 2)
+ context_layer = tf.reshape(context_layer, (b, h, m, -1)) * from_mask
+ context_layer = tf.transpose(context_layer, (0, 2, 1, 3))
+ return context_layer
+
+
+class MultiHeadedAttentionLayer(tf.compat.v1.layers.Layer):
+ """A multi-headed attention layer.
+
+ It implements following types of multi-headed attention:
+ - original_full attention from "Attention is all you Need".
+ - simulated_sparse attention from BigBird with full quadratic implemention.
+ - block_sparse attention from BigBird with memory efficient linear impl.
+ """
+
+ def __init__(self,
+ attention_type,
+ num_attention_heads=1,
+ num_rand_blocks=3,
+ size_per_head=512,
+ initializer_range=0.02,
+ from_block_size=64,
+ to_block_size=64,
+ attention_probs_dropout_prob=0.0,
+ use_bias=True,
+ seed=None,
+ query_act=None,
+ key_act=None,
+ value_act=None,
+ name=None,
+ **kwargs):
+ """Constructor for a multi-headed attention layer.
+
+ Args:
+ attention_type: Type of attention, needs to be one of ['original_full',
+ 'simulated_sparse', 'block_sparse'].
+ num_attention_heads: (optional) int. Number of attention heads.
+ num_rand_blocks: (optional) int. Number of random chunks per row.
+ size_per_head: (optional) int. Size of each attention head.
+ initializer_range: (optional) float. Range of the weight initializer.
+ from_block_size: (optional) int. size of block in from sequence.
+ to_block_size: (optional) int. size of block in to sequence.
+ attention_probs_dropout_prob: (optional) float. Dropout probability of the
+ attention probabilities.
+ use_bias: Whether the layer uses a bias vector.
+ seed: (Optional) int. Reandom seed for generating random mask.
+ query_act: (optional) Activation function for the query transform.
+ key_act: (optional) Activation function for the key transform.
+ value_act: (optional) Activation function for the value transform.
+ name: The name scope of this layer.
+ **kwargs: others
+ """
+ super(MultiHeadedAttentionLayer, self).__init__(name=name, **kwargs)
+ self.query_layer = utils.Dense3dLayer(
+ num_attention_heads, size_per_head,
+ utils.create_initializer(initializer_range), query_act,
+ "query", head_first=True, use_bias=use_bias)
+
+ self.key_layer = utils.Dense3dLayer(
+ num_attention_heads, size_per_head,
+ utils.create_initializer(initializer_range), key_act,
+ "key", head_first=True, use_bias=use_bias)
+
+ self.value_layer = utils.Dense3dLayer(
+ num_attention_heads, size_per_head,
+ utils.create_initializer(initializer_range), value_act,
+ "value", head_first=True, use_bias=use_bias)
+
+ def attn_impl(
+ query, key, value, attention_mask,
+ band_mask, from_mask, to_mask, from_blocked_mask, to_blocked_mask,
+ batch_size, from_seq_length, to_seq_length, training):
+ if attention_type == "original_full":
+ logging.info("**** Using original full attention ****")
+ attn_fn = original_full_attention(
+ query, key, value,
+ attention_mask, size_per_head,
+ attention_probs_dropout_prob if training else 0.0)
+ elif attention_type == "simulated_sparse":
+ logging.info("**** Using simulated sparse attention ****")
+ attn_fn = bigbird_simulated_attention(
+ query, key, value,
+ attention_mask, num_attention_heads, num_rand_blocks, size_per_head,
+ from_seq_length, to_seq_length, from_block_size, to_block_size,
+ seed)
+ elif attention_type == "block_sparse":
+ logging.info("**** Using block sparse attention ****")
+ attn_fn = bigbird_block_sparse_attention(
+ query, key, value,
+ band_mask, from_mask, to_mask, from_blocked_mask, to_blocked_mask,
+ num_attention_heads, num_rand_blocks, size_per_head, batch_size,
+ from_seq_length, to_seq_length, from_block_size, to_block_size,
+ seed)
+ else:
+ raise NotImplementedError(
+ "Attention type {} is not implemented".format(attention_type))
+ return attn_fn
+
+ self.attn_impl = attn_impl
+
+ @property
+ def trainable_weights(self):
+ self._trainable_weights = (self.query_layer.trainable_weights +
+ self.key_layer.trainable_weights +
+ self.value_layer.trainable_weights)
+ return self._trainable_weights
+
+ def call(self,
+ from_tensor,
+ to_tensor,
+ attention_mask=None,
+ band_mask=None,
+ from_mask=None,
+ to_mask=None,
+ from_blocked_mask=None,
+ to_blocked_mask=None,
+ cache=None,
+ decode_i=None,
+ training=None):
+ """Implements a multi-headed attention layer from from_tensor to to_tensor.
+
+ Args:
+ from_tensor: float Tensor of shape [batch_size, from_seq_length,
+ from_width]
+ to_tensor: float Tensor of shape [batch_size, to_seq_length, to_width].
+ attention_mask: (optional) int32 Tensor of shape [batch_size,
+ from_seq_length, to_seq_length]. The values should be 1 or 0. The
+ attention scores will effectively be set to -infinity for any positions
+ in the mask that are 0, and will be unchanged for positions that are 1.
+ band_mask: (optional) int32 Tensor of shape [batch_size, 1,
+ from_seq_length//from_block_size-4, from_block_size, 3*to_block_size].
+ The values should be 1 or 0. The attention scores will effectively be
+ set to -infinity for any positions in the mask that are 0, and will be
+ unchanged for positions that are 1.
+ from_mask: (optional) int32 Tensor of shape [batch_size, 1,
+ from_seq_length, 1]. The values should be 1 or 0. The
+ attention scores will effectively be set to -infinity for any positions
+ in the mask that are 0, and will be unchanged for positions that are 1.
+ to_mask: (optional) int32 Tensor of shape [batch_size, 1, 1,
+ to_seq_length]. The values should be 1 or 0. The
+ attention scores will effectively be set to -infinity for any positions
+ in the mask that are 0, and will be unchanged for positions that are 1.
+ from_blocked_mask: (optional) int32 Tensor of shape [batch_size,
+ from_seq_length//from_block_size, from_block_size].
+ Same as from_mask, just reshaped.
+ to_blocked_mask: (optional) int32 Tensor of shape [batch_size,
+ to_seq_length//to_block_size, to_block_size].
+ Same as to_mask, just reshaped.
+ cache: (Used during prediction) A dictionary with tensors containing
+ results of previous attentions. The dictionary must have the items:
+ {"k": tensor with shape
+ [batch_size, max_len, num_attention_heads, size_per_head],
+ "v": tensor with shape
+ [batch_size, max_len, num_attention_heads, size_per_head]}
+ decode_i: (Used during prediction) current location of decoding
+ training: Boolean indicating whether the call is training or inference.
+
+ Returns:
+ float Tensor of shape [batch_size, from_seq_length, num_attention_heads,
+ size_per_head].
+
+ Raises:
+ ValueError: Any of the arguments or tensor shapes are invalid.
+ NotImplementedError: For unknown attention type.
+ """
+ from_shape = utils.get_shape_list(from_tensor, expected_rank=3)
+ to_shape = utils.get_shape_list(to_tensor, expected_rank=3)
+
+ if len(from_shape) != len(to_shape):
+ raise ValueError(
+ "The rank of `from_tensor` must match the rank of `to_tensor`.")
+
+ if len(from_shape) == 3:
+ batch_size = from_shape[0]
+ from_seq_length = from_shape[1]
+ to_seq_length = to_shape[1]
+ else:
+ raise ValueError(
+ "Need rank 3 tensors to attention_layer.")
+
+ # Scalar dimensions referenced here:
+ # b = batch size (number of sequences)
+ # m = `from_tensor` sequence length
+ # n = `to_tensor` sequence length
+ # h = `num_attention_heads`
+ # d = `size_per_head`
+
+ # `query` = [b, h, m, d]
+ query = self.query_layer(from_tensor)
+
+ # `key` = [b, h, n, d]
+ key = self.key_layer(to_tensor)
+
+ # `value_layer` = [b, h, n, d]
+ value = self.value_layer(to_tensor)
+
+ if cache is not None and decode_i is not None:
+ max_len = utils.get_shape_list(cache["k"])[2]
+ indices_select = tf.reshape(
+ tf.one_hot(decode_i, max_len, dtype=to_tensor.dtype),
+ [1, 1, max_len, 1])
+ key = cache["k"] + key * indices_select
+ value = cache["v"] + value * indices_select
+ cache["k"] = key
+ cache["v"] = value
+
+ contextual_output = self.attn_impl(
+ query, key, value, attention_mask,
+ band_mask, from_mask, to_mask, from_blocked_mask, to_blocked_mask,
+ batch_size, from_seq_length, to_seq_length, training)
+
+ return contextual_output
diff --git a/bigbird/core/beam_search.py b/bigbird/core/beam_search.py
new file mode 100644
index 0000000..e21c937
--- /dev/null
+++ b/bigbird/core/beam_search.py
@@ -0,0 +1,224 @@
+# Copyright 2020 The BigBird Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Beam search branched from Pegasus.
+
+Original source:
+https://github.com/google-research/pegasus/blob/master/pegasus/layers/beam_search.py
+
+This beam search implementation is designed for TPU usage only and prefers
+flexibility over efficiency. Transformer attention caching is not enabled yet.
+
+Mostly follows implementation in T2T. Several difference to pure beamsearch:
+1. has finished and alive seqs, use 2 * beam_size to grow alive seqs,
+ which makes beam_size=1 doesn't equal greedy.
+2. prefers finished seq over alive seqs.
+3. prefers lower indices when equal probability (though unlikely).
+4. with custom length normalization and constraint.
+
+Notations:
+ B: batch_size, M: beam_size, T: max_decode_len, V: vocab_size, U: undefined
+"""
+# pylint: disable=invalid-name
+
+import tensorflow.compat.v2 as tf
+
+
+def length_normalization(start, alpha, min_len, max_len, out_of_range_penalty):
+ r"""Create length normalization function.
+
+ Combines length penalty from https://arxiv.org/abs/1609.08144,
+ and length constraint from https://www.aclweb.org/anthology/W18-2706.pdf.
+
+ scores = \sum_j log(P_j) / ((start + lengths)/(1 + start))**alpha
+ + out_of_range_penalty * (length > max_len or length < min_len)
+
+ Args:
+ start: int, length normalization start offset.
+ alpha: float, [0, 1.0], length normalization power.
+ min_len: int, minimum decode length.
+ max_len: int, maximum decode lengths.
+ out_of_range_penalty: float, penalty for lengths outside min len and max
+ len. Use a negative number that penalize out of range decodes, does hard
+ constraint if set to -inf.
+
+ Returns:
+ fn(log_probs_BxM, length)->scores_BxM: a function to normalize sum log
+ probabilities of sequence with current decoding lengths.
+ """
+
+ def length_norm_fn(log_probs_BxM, length_int):
+ """Normalize sum log probabilities given a sequence length."""
+ dtype = log_probs_BxM.dtype
+ norm_flt = tf.pow(((start + tf.cast(length_int, dtype)) / (1. + start)),
+ alpha)
+ log_probs_BxM /= norm_flt
+ too_short_bool = tf.less(length_int, min_len)
+ too_long_bool = tf.logical_and(tf.greater(length_int, max_len), max_len > 0)
+ out_of_range_bool = tf.logical_or(too_long_bool, too_short_bool)
+ log_probs_BxM += out_of_range_penalty * tf.cast(out_of_range_bool, dtype)
+ return log_probs_BxM
+
+ return length_norm_fn
+
+
+def beam_search(symbols_to_logits_fn,
+ init_seq_BxT,
+ initial_cache_BxU,
+ vocab_size,
+ beam_size,
+ length_norm_fn,
+ eos_id=1):
+ """Beam search.
+
+ Args:
+ symbols_to_logits_fn: fn(seq_BxT, cache_BxU, i) -> (logits_BxV, cache_BxU)
+ init_seq_BxT: initial sequence ids.
+ initial_cache_BxU: dictionary of tensors with shape BxU.
+ vocab_size: vocabulary size.
+ beam_size: beam size.
+ length_norm_fn: length normalization function.
+ eos_id: end of sequence.
+
+ Returns:
+ Tuple of (beams_BxMxT, scores_BxM). Beam searched sequences and scores.
+ """
+ B, T = init_seq_BxT.shape
+ M, V = beam_size, vocab_size
+ dtype = tf.float32
+ int_dtype = init_seq_BxT.dtype
+
+ def _loop_body(i, alive_seq_BxMxT, alive_log_probs_BxM, alive_cache_BxMxU,
+ finished_seq_BxMxT, finished_scores_BxM):
+ """Beam search loop body."""
+ # Decode one step with beam
+ logits_BMxV, cache_BMxU = symbols_to_logits_fn(
+ _flatten_beam_dim(alive_seq_BxMxT),
+ tf.nest.map_structure(_flatten_beam_dim, alive_cache_BxMxU), i)
+ logits_BxMxV = _unflatten_beam_dim(logits_BMxV, M)
+ new_cache_BxMxU = tf.nest.map_structure(lambda t: _unflatten_beam_dim(t, M),
+ cache_BMxU)
+
+ # select top 2 * beam_size and fill alive and finished.
+ log_probs_BxMxV = logits_BxMxV - tf.reduce_logsumexp(
+ logits_BxMxV, axis=2, keepdims=True)
+ log_probs_BxMxV += tf.expand_dims(alive_log_probs_BxM, axis=2)
+ log_probs_BxMV = tf.reshape(log_probs_BxMxV, [B, -1])
+ new_log_probs_Bx2M, topk_indices_Bx2M = tf.nn.top_k(log_probs_BxMV, k=2 * M)
+ topk_beam_Bx2M = topk_indices_Bx2M // V
+ topk_seq_Bx2MxT, new_cache_Bx2MxU = _gather_nested(
+ [alive_seq_BxMxT, new_cache_BxMxU], topk_beam_Bx2M)
+ topk_ids_Bx2M = topk_indices_Bx2M % V
+ new_seq_Bx2MxT = _update_i(topk_seq_Bx2MxT, topk_ids_Bx2M, i)
+ new_finished_flags_Bx2M = tf.cast(
+ tf.reduce_any(tf.equal(new_seq_Bx2MxT, eos_id), axis=-1), dtype)
+
+ # get new alive
+ _, topk_alive_indices_BxM = tf.nn.top_k(
+ new_log_probs_Bx2M + new_finished_flags_Bx2M * dtype.min, k=M)
+ (alive_seq_BxMxT, alive_log_probs_BxM, alive_cache_BxMxU) = _gather_nested(
+ [new_seq_Bx2MxT, new_log_probs_Bx2M, new_cache_Bx2MxU],
+ topk_alive_indices_BxM)
+
+ # get new finished
+ new_scores_Bx2M = length_norm_fn(new_log_probs_Bx2M, i + 1)
+ new_scores_Bx2M += (1 - new_finished_flags_Bx2M) * dtype.min
+ finished_seq_Bx3MxT = tf.concat([finished_seq_BxMxT, new_seq_Bx2MxT],
+ axis=1)
+ finished_scores_Bx3M = tf.concat([finished_scores_BxM, new_scores_Bx2M],
+ axis=1)
+ _, topk_finished_indices_BxM = tf.nn.top_k(finished_scores_Bx3M, k=M)
+ (finished_seq_BxMxT, finished_scores_BxM) = _gather_nested(
+ [finished_seq_Bx3MxT, finished_scores_Bx3M], topk_finished_indices_BxM)
+
+ return [
+ i + 1, alive_seq_BxMxT, alive_log_probs_BxM, alive_cache_BxMxU,
+ finished_seq_BxMxT, finished_scores_BxM
+ ]
+
+ # initialize.
+ init_i = tf.constant(0, dtype=int_dtype)
+ init_alive_seq_BxMxT = _expand_to_beam_size(init_seq_BxT, M)
+ log_probs_1xM = tf.constant([[0.] + [dtype.min] * (M - 1)], dtype=dtype)
+ init_alive_log_probs_BxM = tf.tile(log_probs_1xM, [B, 1])
+ init_alive_cache_BxMxU = tf.nest.map_structure(
+ lambda t: _expand_to_beam_size(t, M), initial_cache_BxU)
+ init_finished_seq_BxMxT = tf.zeros(tf.shape(init_alive_seq_BxMxT), int_dtype)
+ init_finished_scores_BxM = tf.zeros([B, M], dtype=dtype) + dtype.min
+
+ # run loop.
+ (_, final_alive_seq_BxMxT, final_alive_scores_BxM, _,
+ final_finished_seq_BxMxT, final_finished_scores_BxM) = tf.while_loop(
+ lambda *args: True, # Always do T iterations
+ _loop_body,
+ loop_vars=[
+ init_i, init_alive_seq_BxMxT, init_alive_log_probs_BxM,
+ init_alive_cache_BxMxU, init_finished_seq_BxMxT,
+ init_finished_scores_BxM
+ ],
+ parallel_iterations=1,
+ back_prop=False,
+ maximum_iterations=T,
+ )
+
+ # process finished.
+ final_finished_flag_BxMx1 = tf.reduce_any(
+ tf.equal(final_finished_seq_BxMxT, eos_id), axis=-1, keepdims=True)
+ final_seq_BxMxT = tf.where(
+ tf.tile(final_finished_flag_BxMx1, [1, 1, T]), final_finished_seq_BxMxT,
+ final_alive_seq_BxMxT)
+ final_scores_BxM = tf.where(
+ tf.squeeze(final_finished_flag_BxMx1, axis=-1), final_finished_scores_BxM,
+ final_alive_scores_BxM)
+ return final_seq_BxMxT, final_scores_BxM
+
+
+def _update_i(tensor_BxNxT, updates_BxN, i):
+ B, N, T = tensor_BxNxT.shape
+ tensor_BNxT = tf.reshape(tensor_BxNxT, [-1, T])
+ updates_BN = tf.reshape(updates_BxN, [-1])
+ batch_BN = tf.range(B * N, dtype=tf.int32)
+ i_BN = tf.fill([B * N], i)
+ ind_BNx2 = tf.stack([batch_BN, i_BN], axis=-1)
+ tensor_BNxT = tf.tensor_scatter_nd_update(tensor_BNxT, ind_BNx2, updates_BN)
+ return tf.reshape(tensor_BNxT, [B, N, T])
+
+
+def _expand_to_beam_size(tensor_BxU, beam_size):
+ tensor_Bx1xU = tf.expand_dims(tensor_BxU, axis=1)
+ tile_dims = [1] * tensor_Bx1xU.shape.ndims
+ tile_dims[1] = beam_size
+ tensor_BxMxU = tf.tile(tensor_Bx1xU, tile_dims)
+ return tensor_BxMxU
+
+
+def _flatten_beam_dim(tensor_BxMxU):
+ shape = tensor_BxMxU.shape.as_list()
+ tensor_BMxU = tf.reshape(tensor_BxMxU, [shape[0] * shape[1]] + shape[2:])
+ return tensor_BMxU
+
+
+def _unflatten_beam_dim(tensor_BMxU, M):
+ shape = tensor_BMxU.shape.as_list()
+ tensor_BxMxU = tf.reshape(tensor_BMxU, [shape[0] // M, M] + shape[1:])
+ return tensor_BxMxU
+
+
+def _gather_nested(nested_BxMxU, indices_BxN):
+
+ def _gather_beam(tensor_BxMxU):
+ tensor_BxNxU = tf.gather(tensor_BxMxU, indices_BxN, batch_dims=1, axis=1)
+ return tensor_BxNxU
+
+ return tf.nest.map_structure(_gather_beam, nested_BxMxU)
diff --git a/bigbird/core/decoder.py b/bigbird/core/decoder.py
new file mode 100644
index 0000000..19e0dd5
--- /dev/null
+++ b/bigbird/core/decoder.py
@@ -0,0 +1,554 @@
+# Copyright 2020 The BigBird Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""BigBird Decoder Layers."""
+
+from bigbird.core import attention
+from bigbird.core import beam_search
+from bigbird.core import utils
+import tensorflow.compat.v2 as tf
+
+
+class PrenormDecoderLayer(tf.compat.v1.layers.Layer):
+ """Decoder layer of a transformer in Pegasus style.
+
+ The layer_norm is taken before self-attention.
+ """
+
+ def __init__(self,
+ hidden_size=768,
+ intermediate_size=3072,
+ intermediate_act_fn=utils.gelu,
+ attention_probs_dropout_prob=0.0,
+ hidden_dropout_prob=0.1,
+ initializer_range=0.02,
+ num_attention_heads=12,
+ use_bias=True,
+ name=None):
+ """Constructor of a decoder layer of a transformer in Pegasus style.
+
+ Args:
+ hidden_size: (optional) int. Size of hidden dimension.
+ intermediate_size: (optional) int. Size of intermediate dimension.
+ intermediate_act_fn: optional) Activation function for intermediate layer.
+ attention_probs_dropout_prob: (optional) float. Dropout probability of the
+ attention probabilities.
+ hidden_dropout_prob: (optional) float. Dropout probability of the
+ attention.
+ initializer_range: (optional) float. Range of the weight initializer.
+ num_attention_heads: (optional) int. Number of attention heads.
+ use_bias: (optional) bool. Whether key/query/value uses a bias vector.
+ name: The name scope of this layer.
+ """
+ super(PrenormDecoderLayer, self).__init__(name=name)
+ self.hidden_dropout_prob = hidden_dropout_prob
+
+ # Attention layers
+ attention_head_size = hidden_size // num_attention_heads
+ self.self_attn_layer = attention.MultiHeadedAttentionLayer(
+ "original_full", use_bias=use_bias, name="self",
+ num_attention_heads=num_attention_heads,
+ size_per_head=attention_head_size,
+ initializer_range=initializer_range,
+ attention_probs_dropout_prob=attention_probs_dropout_prob)
+ self.cross_attn_layer = attention.MultiHeadedAttentionLayer(
+ "original_full", use_bias=use_bias, name="encdec",
+ num_attention_heads=num_attention_heads,
+ size_per_head=attention_head_size,
+ initializer_range=initializer_range,
+ attention_probs_dropout_prob=attention_probs_dropout_prob)
+
+ # Dense layers
+ self.self_proj_layer = utils.Dense3dProjLayer(
+ num_attention_heads, attention_head_size,
+ utils.create_initializer(initializer_range), None, "dense", use_bias)
+ self.cross_proj_layer = utils.Dense3dProjLayer(
+ num_attention_heads, attention_head_size,
+ utils.create_initializer(initializer_range), None, "dense", use_bias)
+ self.expand_layer = utils.Dense2dLayer(
+ intermediate_size, utils.create_initializer(initializer_range),
+ intermediate_act_fn, "dense")
+ self.contract_layer = utils.Dense2dLayer(
+ hidden_size, utils.create_initializer(initializer_range),
+ None, "dense")
+
+ # Normalization layer
+ self.first_layer_norm = utils.NormLayer()
+ self.second_layer_norm = utils.NormLayer()
+ self.third_layer_norm = utils.NormLayer()
+
+ @property
+ def trainable_weights(self):
+ self._trainable_weights = (self.self_attn_layer.trainable_weights +
+ self.cross_attn_layer.trainable_weights +
+ self.self_proj_layer.trainable_weights +
+ self.cross_proj_layer.trainable_weights +
+ self.expand_layer.trainable_weights +
+ self.contract_layer.trainable_weights +
+ self.first_layer_norm.trainable_weights +
+ self.second_layer_norm.trainable_weights +
+ self.third_layer_norm.trainable_weights)
+ return self._trainable_weights
+
+ def call(self,
+ layer_input,
+ encoder_outputs,
+ self_attention_mask,
+ attention_mask,
+ cache=None,
+ decode_i=None,
+ training=None):
+ """Implements a decoder layer of a transformer in Pegasus style.
+
+ The layer_norm is taken after self-attention.
+
+ Args:
+ layer_input: float Tensor of shape [batch_size, seq_length, hidden_size].
+ encoder_outputs: tensors with shape [batch_size, input_length,
+ num_hidden_layers, hidden_size]
+ self_attention_mask: bias for decoder self-attention layer. [1, 1,
+ target_length, target_length]
+ attention_mask: bias for encoder-decoder attention layer. [batch_size, 1,
+ 1, input_length]
+ cache: (Used during prediction) A dictionary with tensors containing
+ results of previous attentions. The dictionary must have the items:
+ {"k": tensor with shape
+ [batch_size, max_len, num_attention_heads, size_per_head],
+ "v": tensor with shape
+ [batch_size, max_len, num_attention_heads, size_per_head]}
+ decode_i: (Used during prediction) current location of decoding
+ training: Boolean indicating whether the call is training or inference.
+
+ Returns:
+ float Tensor of shape [batch_size, seq_length, hidden_size].
+
+ Raises:
+ ValueError: Any of the arguments or tensor shapes are invalid.
+ NotImplementedError: For unknown attention type.
+ """
+ with tf.compat.v1.variable_scope("attention"):
+ with tf.compat.v1.variable_scope("self") as sc:
+ normalized_layer_input = self.first_layer_norm(layer_input)
+ self_attention_output = self.self_attn_layer(
+ normalized_layer_input, normalized_layer_input, self_attention_mask,
+ cache=cache, decode_i=decode_i, training=training, scope=sc)
+
+ # Run a linear projection of `hidden_size` then add a residual
+ # with `layer_input`.
+ with tf.compat.v1.variable_scope("output"):
+ self_attention_output = self.self_proj_layer(self_attention_output)
+ self_attention_output = utils.dropout(self_attention_output,
+ self.hidden_dropout_prob,
+ training)
+ self_attention_output = self_attention_output + layer_input
+
+ with tf.compat.v1.variable_scope("encdec") as sc:
+ normalized_self_attention_output = self.second_layer_norm(
+ self_attention_output)
+ attention_output = self.cross_attn_layer(
+ normalized_self_attention_output, encoder_outputs, attention_mask,
+ training=training, scope=sc)
+
+ # Run a linear projection of `hidden_size` then add a residual
+ # with `layer_input`.
+ with tf.compat.v1.variable_scope("encdec_output"):
+ attention_output = self.cross_proj_layer(attention_output)
+ attention_output = utils.dropout(attention_output,
+ self.hidden_dropout_prob,
+ training)
+ attention_output = attention_output + self_attention_output
+
+ # The activation is only applied to the "intermediate" hidden layer.
+ with tf.compat.v1.variable_scope("intermediate"):
+ normalized_attention_output = self.third_layer_norm(attention_output)
+ intermediate_output = self.expand_layer(normalized_attention_output)
+
+ # Down-project back to `hidden_size` then add the residual.
+ with tf.compat.v1.variable_scope("output"):
+ layer_output = self.contract_layer(intermediate_output)
+ layer_output = utils.dropout(layer_output,
+ self.hidden_dropout_prob,
+ training)
+ layer_output = layer_output + attention_output
+ return layer_output
+
+
+class PostnormDecoderLayer(tf.compat.v1.layers.Layer):
+ """Decoder layer of a transformer in BERT style.
+
+ The layer_norm is taken before self-attention.
+ """
+
+ def __init__(self,
+ hidden_size=768,
+ intermediate_size=3072,
+ intermediate_act_fn=utils.gelu,
+ attention_probs_dropout_prob=0.0,
+ hidden_dropout_prob=0.1,
+ initializer_range=0.02,
+ num_attention_heads=12,
+ use_bias=True,
+ name=None):
+ """Constructor of a decoder layer of a transformer in BERT style.
+
+ Args:
+ hidden_size: (optional) int. Size of hidden dimension.
+ intermediate_size: (optional) int. Size of intermediate dimension.
+ intermediate_act_fn: optional) Activation function for intermediate layer.
+ attention_probs_dropout_prob: (optional) float. Dropout probability of the
+ attention probabilities.
+ hidden_dropout_prob: (optional) float. Dropout probability of the
+ attention.
+ initializer_range: (optional) float. Range of the weight initializer.
+ num_attention_heads: (optional) int. Number of attention heads.
+ use_bias: (optional) bool. Whether key/query/value uses a bias vector.
+ name: The name scope of this layer.
+ """
+ super(PostnormDecoderLayer, self).__init__(name=name)
+ self.hidden_dropout_prob = hidden_dropout_prob
+
+ # Attention layers
+ attention_head_size = hidden_size // num_attention_heads
+ self.self_attn_layer = attention.MultiHeadedAttentionLayer(
+ "original_full", use_bias=use_bias, name="self",
+ num_attention_heads=num_attention_heads,
+ size_per_head=attention_head_size,
+ initializer_range=initializer_range,
+ attention_probs_dropout_prob=attention_probs_dropout_prob)
+ self.cross_attn_layer = attention.MultiHeadedAttentionLayer(
+ "original_full", use_bias=use_bias, name="encdec",
+ num_attention_heads=num_attention_heads,
+ size_per_head=attention_head_size,
+ initializer_range=initializer_range,
+ attention_probs_dropout_prob=attention_probs_dropout_prob)
+
+ # Dense layers
+ self.self_proj_layer = utils.Dense3dProjLayer(
+ num_attention_heads, attention_head_size,
+ utils.create_initializer(initializer_range), None, "dense", use_bias)
+ self.cross_proj_layer = utils.Dense3dProjLayer(
+ num_attention_heads, attention_head_size,
+ utils.create_initializer(initializer_range), None, "dense", use_bias)
+ self.expand_layer = utils.Dense2dLayer(
+ intermediate_size, utils.create_initializer(initializer_range),
+ intermediate_act_fn, "dense")
+ self.contract_layer = utils.Dense2dLayer(
+ hidden_size, utils.create_initializer(initializer_range),
+ None, "dense")
+
+ # Normalization layer
+ self.first_layer_norm = utils.NormLayer()
+ self.second_layer_norm = utils.NormLayer()
+ self.third_layer_norm = utils.NormLayer()
+
+ @property
+ def trainable_weights(self):
+ self._trainable_weights = (self.self_attn_layer.trainable_weights +
+ self.cross_attn_layer.trainable_weights +
+ self.self_proj_layer.trainable_weights +
+ self.cross_proj_layer.trainable_weights +
+ self.expand_layer.trainable_weights +
+ self.contract_layer.trainable_weights +
+ self.first_layer_norm.trainable_weights +
+ self.second_layer_norm.trainable_weights +
+ self.third_layer_norm.trainable_weights)
+ return self._trainable_weights
+
+ def call(self,
+ layer_input,
+ encoder_outputs,
+ self_attention_mask,
+ attention_mask,
+ cache=None,
+ decode_i=None,
+ training=None):
+ """Implements a decoder layer of a transformer in BERT style.
+
+ The layer_norm is taken after self-attention.
+
+ Args:
+ layer_input: float Tensor of shape [batch_size, seq_length, hidden_size].
+ encoder_outputs: tensors with shape [batch_size, input_length,
+ num_hidden_layers, hidden_size]
+ self_attention_mask: bias for decoder self-attention layer. [1, 1,
+ target_length, target_length]
+ attention_mask: bias for encoder-decoder attention layer. [batch_size, 1,
+ 1, input_length]
+ cache: (Used during prediction) A dictionary with tensors containing
+ results of previous attentions. The dictionary must have the items:
+ {"k": tensor with shape
+ [batch_size, max_len, num_attention_heads, size_per_head],
+ "v": tensor with shape
+ [batch_size, max_len, num_attention_heads, size_per_head]}
+ decode_i: (Used during prediction) current location of decoding
+ training: Boolean indicating whether the call is training or inference.
+
+ Returns:
+ float Tensor of shape [batch_size, seq_length, hidden_size].
+
+ Raises:
+ ValueError: Any of the arguments or tensor shapes are invalid.
+ NotImplementedError: For unknown attention type.
+ """
+ with tf.compat.v1.variable_scope("attention"):
+ with tf.compat.v1.variable_scope("self") as sc:
+ self_attention_output = self.self_attn_layer(
+ layer_input, layer_input, self_attention_mask,
+ cache=cache, decode_i=decode_i, training=training, scope=sc)
+
+ # Run a linear projection of `hidden_size` then add a residual
+ # with `layer_input`.
+ with tf.compat.v1.variable_scope("output"):
+ self_attention_output = self.self_proj_layer(self_attention_output)
+ self_attention_output = utils.dropout(self_attention_output,
+ self.hidden_dropout_prob,
+ training)
+ self_attention_output = self.first_layer_norm(
+ self_attention_output + layer_input)
+
+ with tf.compat.v1.variable_scope("encdec") as sc:
+ attention_output = self.cross_attn_layer(
+ self_attention_output, encoder_outputs, attention_mask,
+ training=training, scope=sc)
+
+ # Run a linear projection of `hidden_size` then add a residual
+ # with `layer_input`.
+ with tf.compat.v1.variable_scope("encdec_output"):
+ attention_output = self.cross_proj_layer(attention_output)
+ attention_output = utils.dropout(attention_output,
+ self.hidden_dropout_prob,
+ training)
+ attention_output = self.second_layer_norm(
+ attention_output + self_attention_output)
+
+ # The activation is only applied to the "intermediate" hidden layer.
+ with tf.compat.v1.variable_scope("intermediate"):
+ intermediate_output = self.expand_layer(attention_output)
+
+ # Down-project back to `hidden_size` then add the residual.
+ with tf.compat.v1.variable_scope("output"):
+ layer_output = self.contract_layer(intermediate_output)
+ layer_output = utils.dropout(layer_output,
+ self.hidden_dropout_prob,
+ training)
+ layer_output = self.third_layer_norm(layer_output + attention_output)
+ return layer_output
+
+
+class DecoderStack(tf.compat.v1.layers.Layer):
+ """Transformer decoder stack."""
+
+ def __init__(self, params):
+ if params["couple_encoder_decoder"]:
+ name = "encoder"
+ with tf.compat.v1.variable_scope(
+ name, reuse=tf.compat.v1.AUTO_REUSE) as scope:
+ super(DecoderStack, self).__init__(name=name, _scope=scope)
+ else:
+ name = "decoder"
+ super(DecoderStack, self).__init__(name=name)
+
+ self.params = params
+
+ if params["norm_type"] == "prenorm":
+ decoder_class = PrenormDecoderLayer
+ elif params["norm_type"] == "postnorm":
+ decoder_class = PostnormDecoderLayer
+ else:
+ raise NotImplementedError(
+ "Norm type {} is not implemented".format(params["norm_type"]))
+
+ if self.params.get("num_decoder_layers", None) is not None:
+ num_hidden_layers = self.params["num_decoder_layers"]
+ else:
+ num_hidden_layers = self.params["num_hidden_layers"]
+
+ # Decoder layers
+ self.decoder_layers = [
+ decoder_class( # pylint: disable=g-complex-comprehension
+ self.params["hidden_size"],
+ self.params["intermediate_size"],
+ utils.get_activation(self.params["hidden_act"]),
+ self.params["attention_probs_dropout_prob"],
+ self.params["hidden_dropout_prob"],
+ self.params["initializer_range"],
+ self.params["num_attention_heads"],
+ self.params["use_bias"],
+ name="layer_%d" % layer_idx)
+ for layer_idx in range(num_hidden_layers)
+ ]
+
+ # Normalization layer
+ self.layer_norm = utils.NormLayer()
+
+ @property
+ def trainable_weights(self):
+ self._trainable_weights = sum(
+ [layer.trainable_weights for layer in self.decoder_layers],
+ []) + self.layer_norm.trainable_weights
+ return self._trainable_weights
+
+ def call(self,
+ decoder_inputs,
+ self_attention_mask,
+ encoder_outputs,
+ encoder_mask,
+ cache=None,
+ decode_i=None,
+ training=None):
+ """Return the output of the decoder layer stacks.
+
+ Args:
+ decoder_inputs: tensor with shape
+ [batch_size, target_length, hidden_size]
+ self_attention_mask: bias for decoder self-attention layer. [1, 1,
+ target_length, target_length]
+ encoder_outputs: tensors with shape [batch_size, input_length,
+ hidden_size]
+ encoder_mask: bias for encoder-decoder attention layer. [batch_size,
+ input_length]
+ cache: (Used during prediction) A dictionary with tensors containing
+ results of previous attentions. The dictionary must have the items:
+ {"k": tensor with shape
+ [batch_size, max_len, num_attention_heads, size_per_head],
+ "v": tensor with shape
+ [batch_size, max_len, num_attention_heads, size_per_head]}
+ decode_i: (Used during prediction) current location of decoding.
+ training: Boolean indicating whether the call is training or inference.
+
+ Returns:
+ Output of decoder layer stack. A float32 tensor with shape [batch_size,
+ target_length, hidden_size]
+ """
+ # Expand encoder mask to broadcast over num heads and from_seq axis
+ attention_mask = tf.expand_dims(tf.expand_dims(encoder_mask, 1), 1)
+
+ # if self.params["use_gradient_checkpointing"]::
+ # decoder_layer = recompute_gradient(decoder_layer)
+
+ if self.params["norm_type"] == "postnorm":
+ decoder_inputs = self.layer_norm(decoder_inputs)
+
+ layer_output = decoder_inputs
+ for layer in self.decoder_layers:
+ layer_cache = cache[layer.name] if cache is not None else None
+ layer_output = layer(
+ layer_output, encoder_outputs, self_attention_mask, attention_mask,
+ layer_cache, decode_i, training)
+
+ if self.params["norm_type"] == "prenorm":
+ layer_output = self.layer_norm(layer_output)
+
+ return layer_output
+
+
+def create_self_attention_mask(length):
+ with tf.name_scope("decoder_self_attention_mask"):
+ valid_locs = tf.linalg.band_part(tf.ones([length, length]), -1, 0)
+ valid_locs = tf.reshape(valid_locs, [1, 1, length, length])
+ return valid_locs
+
+
+def inplace_update_i(inp_tensor, updates, i):
+ """Inplace update a tensor. B: batch_size, L: tensor length."""
+ batch_size = inp_tensor.shape[0]
+ indices = tf.stack([
+ tf.range(batch_size, dtype=tf.int32),
+ tf.fill([batch_size], tf.cast(i, tf.int32))
+ ], axis=-1)
+ return tf.tensor_scatter_nd_update(inp_tensor, indices, updates)
+
+
+# pylint: disable=invalid-name
+def left2right_decode(symbols_to_logits_fn,
+ start_symbols,
+ context_BxU_dict,
+ batch_size,
+ max_decode_len,
+ vocab_size,
+ beam_size=1,
+ beam_start=5,
+ beam_alpha=0.6,
+ beam_min=0,
+ beam_max=-1,
+ eos_id=1):
+ """left to right decode.
+
+ Notations:
+ B: batch_size, V: vocab_size, T: decode_len, U: undefined dimensions
+
+ Args:
+ symbols_to_logits_fn: logits = fn(decodes, context, i). Shoud take
+ [batch_size, decoded_ids] and return [batch_size, vocab_size].
+ start_symbols: starting ids [batch_size]
+ context_BxU_dict: dict of Tensors.
+ batch_size: int, decode batch size.
+ max_decode_len: int, maximum number of steps to decode.
+ vocab_size: int, output vocab size.
+ beam_size: Number of beams to decode.
+ beam_start: start length for scaling, default to 5.
+ beam_alpha: Length penalty for decoding. Should be between 0 (shorter) and 1
+ (longer), default to 0.6.
+ beam_min: Minimum beam search lengths.
+ beam_max: Maximum beam search lengths. Set -1 to use unlimited.
+ eos_id: end of token id, default to 1.
+
+ Returns:
+ decodes: Tensor[batch, decode_len]
+ """
+ dtype = tf.int32
+ start_symbols = tf.expand_dims(start_symbols, 1)
+ # When beam_size=1, beam_search does not behave exactly like greedy.
+ # This is due to using 2 * beam_size in grow_topk, and keep the top beam_size
+ # ones that haven't reached EOS into alive.
+ # In this case, alpha value for length penalty will take effect.
+ if beam_size == 1:
+
+ def decode_loop(i, decodes_BxT, cache_BxU_dict):
+ logits_BxV = symbols_to_logits_fn(decodes_BxT, cache_BxU_dict, i)
+ decodes_BxT = inplace_update_i(
+ decodes_BxT, tf.argmax(logits_BxV, -1, output_type=tf.int32), i)
+ return i + 1, decodes_BxT, cache_BxU_dict
+
+ def loop_cond(i, decodes_BxT, unused_cache_BxU_dict):
+ finished_B = tf.reduce_any(tf.equal(decodes_BxT, eos_id), axis=1)
+ return tf.logical_and(i < max_decode_len,
+ tf.logical_not(tf.reduce_all(finished_B)))
+
+ init_dec_BxT = tf.concat([tf.cast(start_symbols, dtype=dtype),
+ tf.zeros([batch_size, max_decode_len-1],
+ dtype=dtype)], axis=1)
+ _, decodes, _ = tf.while_loop(
+ loop_cond, decode_loop,
+ [tf.constant(0, dtype=dtype), init_dec_BxT, context_BxU_dict])
+ return decodes
+
+ else:
+
+ def symbols_to_logits_fn_with_sampling(decodes_BxT, states_BxU_dict, i):
+ logits_BxV = symbols_to_logits_fn(decodes_BxT, states_BxU_dict, i)
+ return logits_BxV, states_BxU_dict
+
+ length_norm_fn = beam_search.length_normalization(beam_start, beam_alpha,
+ beam_min, beam_max, -1e3)
+
+ init_dec_BxT = tf.concat([tf.cast(start_symbols, dtype=tf.int32),
+ tf.zeros([batch_size, max_decode_len-1],
+ dtype=tf.int32)], axis=1)
+
+ beams, _ = beam_search.beam_search(
+ symbols_to_logits_fn_with_sampling,
+ init_dec_BxT,
+ context_BxU_dict, vocab_size, beam_size, length_norm_fn, eos_id)
+ return beams[:, 0, :]
diff --git a/bigbird/core/encoder.py b/bigbird/core/encoder.py
new file mode 100644
index 0000000..3e6f0e8
--- /dev/null
+++ b/bigbird/core/encoder.py
@@ -0,0 +1,423 @@
+# Copyright 2020 The BigBird Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""BigBird Encoder Layers."""
+
+from bigbird.core import attention
+from bigbird.core import utils
+import tensorflow.compat.v2 as tf
+
+
+class PrenormEncoderLayer(tf.compat.v1.layers.Layer):
+ """Encoder layer of a transformer in Pegasus style.
+
+ The layer_norm is taken before self-attention.
+ """
+
+ def __init__(self,
+ attention_type,
+ hidden_size=768,
+ intermediate_size=3072,
+ intermediate_act_fn=utils.gelu,
+ attention_probs_dropout_prob=0.0,
+ hidden_dropout_prob=0.1,
+ initializer_range=0.02,
+ num_attention_heads=12,
+ num_rand_blocks=3,
+ block_size=64,
+ use_bias=True,
+ seed=None,
+ name=None):
+ """Constructor of an encoder layer of a transformer in Pegasus style.
+
+ Args:
+ attention_type: Type of attention, needs to be one of ['original_full',
+ 'simulated_sparse', 'block_sparse'].
+ hidden_size: (optional) int. Size of hidden dimension.
+ intermediate_size: (optional) int. Size of intermediate dimension.
+ intermediate_act_fn: optional) Activation function for intermediate layer.
+ attention_probs_dropout_prob: (optional) float. Dropout probability of the
+ attention probabilities.
+ hidden_dropout_prob: (optional) float. Dropout probability of the
+ attention.
+ initializer_range: (optional) float. Range of the weight initializer.
+ num_attention_heads: (optional) int. Number of attention heads.
+ num_rand_blocks: (optional) int. Number of random chunks per row.
+ block_size: (optional) int. size of block in sequence.
+ use_bias: (optional) bool. Whether key/query/value uses a bias vector.
+ seed: (Optional) int. Reandom seed for generating random mask.
+ name: The name scope of this layer.
+ """
+ super(PrenormEncoderLayer, self).__init__(name=name)
+ self.hidden_dropout_prob = hidden_dropout_prob
+
+ # Attention layer
+ attention_head_size = hidden_size // num_attention_heads
+ self.attn_layer = attention.MultiHeadedAttentionLayer(
+ attention_type, num_attention_heads, num_rand_blocks,
+ attention_head_size, initializer_range, block_size, block_size,
+ attention_probs_dropout_prob, use_bias, seed, name="self")
+
+ # Dense layers
+ self.projection_layer = utils.Dense3dProjLayer(
+ num_attention_heads, attention_head_size,
+ utils.create_initializer(initializer_range), None, "dense", use_bias)
+ self.expand_layer = utils.Dense2dLayer(
+ intermediate_size, utils.create_initializer(initializer_range),
+ intermediate_act_fn, "dense")
+ self.contract_layer = utils.Dense2dLayer(
+ hidden_size, utils.create_initializer(initializer_range),
+ None, "dense")
+
+ # Normalization layer
+ self.first_layer_norm = utils.NormLayer()
+ self.second_layer_norm = utils.NormLayer()
+
+ @property
+ def trainable_weights(self):
+ self._trainable_weights = (self.attn_layer.trainable_weights +
+ self.projection_layer.trainable_weights +
+ self.expand_layer.trainable_weights +
+ self.contract_layer.trainable_weights +
+ self.first_layer_norm.trainable_weights +
+ self.second_layer_norm.trainable_weights)
+ return self._trainable_weights
+
+ def call(self,
+ layer_input,
+ attention_mask=None,
+ band_mask=None,
+ from_mask=None,
+ to_mask=None,
+ input_blocked_mask=None,
+ training=None):
+ """Implements a encoder layer of a transformer in Pegasus style.
+
+ Args:
+ layer_input: float Tensor of shape [batch_size, seq_length, hidden_size].
+ attention_mask: (optional) int32 Tensor of shape [batch_size,
+ seq_length, seq_length]. The values should be 1 or 0. The
+ attention scores will effectively be set to -infinity for any positions
+ in the mask that are 0, and will be unchanged for positions that are 1.
+ band_mask: (optional) int32 Tensor of shape [batch_size, 1,
+ seq_length//block_size-4, block_size, 3*block_size].
+ The values should be 1 or 0. The attention scores will effectively be
+ set to -infinity for any positions in the mask that are 0, and will be
+ unchanged for positions that are 1.
+ from_mask: (optional) int32 Tensor of shape [batch_size, 1,
+ seq_length, 1]. The values should be 1 or 0. The
+ attention scores will effectively be set to -infinity for any positions
+ in the mask that are 0, and will be unchanged for positions that are 1.
+ to_mask: (optional) int32 Tensor of shape [batch_size, 1, 1,
+ seq_length]. The values should be 1 or 0. The
+ attention scores will effectively be set to -infinity for any positions
+ in the mask that are 0, and will be unchanged for positions that are 1.
+ input_blocked_mask: (optional) int32 Tensor of shape [batch_size,
+ seq_length//block_size, block_size]. Same as from/to_mask, just
+ reshaped.
+ training: Boolean indicating whether the call is training or inference.
+
+ Returns:
+ float Tensor of shape [batch_size, seq_length, hidden_size].
+
+ Raises:
+ ValueError: Any of the arguments or tensor shapes are invalid.
+ NotImplementedError: For unknown attention type.
+ """
+
+ with tf.compat.v1.variable_scope("attention"):
+ with tf.compat.v1.variable_scope("self") as sc:
+ normalized_layer_input = self.first_layer_norm(layer_input)
+ attention_output = self.attn_layer(
+ normalized_layer_input, normalized_layer_input,
+ attention_mask, band_mask, from_mask, to_mask,
+ input_blocked_mask, input_blocked_mask, training, scope=sc)
+
+ # Run a linear projection of `hidden_size` then add a residual
+ # with `layer_input`.
+ with tf.compat.v1.variable_scope("output"):
+ attention_output = self.projection_layer(attention_output)
+ attention_output = utils.dropout(attention_output,
+ self.hidden_dropout_prob,
+ training)
+ attention_output = attention_output + layer_input
+
+ # The activation is only applied to the "intermediate" hidden layer.
+ with tf.compat.v1.variable_scope("intermediate"):
+ normalized_attention_output = self.second_layer_norm(attention_output)
+ intermediate_output = self.expand_layer(normalized_attention_output)
+
+ # Down-project back to `hidden_size` then add the residual.
+ with tf.compat.v1.variable_scope("output"):
+ layer_output = self.contract_layer(intermediate_output)
+ layer_output = utils.dropout(layer_output,
+ self.hidden_dropout_prob,
+ training)
+ layer_output = layer_output + attention_output
+ return layer_output
+
+
+class PostnormEncoderLayer(tf.compat.v1.layers.Layer):
+ """Encoder layer of a transformer in BERT style.
+
+ The layer_norm is taken after self-attention.
+ """
+
+ def __init__(self,
+ attention_type,
+ hidden_size=768,
+ intermediate_size=3072,
+ intermediate_act_fn=utils.gelu,
+ attention_probs_dropout_prob=0.0,
+ hidden_dropout_prob=0.1,
+ initializer_range=0.02,
+ num_attention_heads=12,
+ num_rand_blocks=3,
+ block_size=64,
+ use_bias=True,
+ seed=None,
+ name=None):
+ """Constructor of an encoder layer of a transformer in BERT style.
+
+ Args:
+ attention_type: Type of attention, needs to be one of ['original_full',
+ 'simulated_sparse', 'block_sparse'].
+ hidden_size: (optional) int. Size of hidden dimension.
+ intermediate_size: (optional) int. Size of intermediate dimension.
+ intermediate_act_fn: optional) Activation function for intermediate layer.
+ attention_probs_dropout_prob: (optional) float. Dropout probability of the
+ attention probabilities.
+ hidden_dropout_prob: (optional) float. Dropout probability of the
+ attention.
+ initializer_range: (optional) float. Range of the weight initializer.
+ num_attention_heads: (optional) int. Number of attention heads.
+ num_rand_blocks: (optional) int. Number of random chunks per row.
+ block_size: (optional) int. size of block in sequence.
+ use_bias: (optional) bool. Whether key/query/value uses a bias vector.
+ seed: (Optional) int. Reandom seed for generating random mask.
+ name: The name scope of this layer.
+ """
+ super(PostnormEncoderLayer, self).__init__(name=name)
+ self.hidden_dropout_prob = hidden_dropout_prob
+
+ # Attention layer
+ attention_head_size = hidden_size // num_attention_heads
+ self.attn_layer = attention.MultiHeadedAttentionLayer(
+ attention_type, num_attention_heads, num_rand_blocks,
+ attention_head_size, initializer_range, block_size, block_size,
+ attention_probs_dropout_prob, use_bias, seed, name="self")
+
+ # Dense layers
+ self.projection_layer = utils.Dense3dProjLayer(
+ num_attention_heads, attention_head_size,
+ utils.create_initializer(initializer_range), None, "dense", use_bias)
+ self.expand_layer = utils.Dense2dLayer(
+ intermediate_size, utils.create_initializer(initializer_range),
+ intermediate_act_fn, "dense")
+ self.contract_layer = utils.Dense2dLayer(
+ hidden_size, utils.create_initializer(initializer_range),
+ None, "dense")
+
+ # Normalization layer
+ self.first_layer_norm = utils.NormLayer()
+ self.second_layer_norm = utils.NormLayer()
+
+ @property
+ def trainable_weights(self):
+ self._trainable_weights = (self.attn_layer.trainable_weights +
+ self.projection_layer.trainable_weights +
+ self.expand_layer.trainable_weights +
+ self.contract_layer.trainable_weights +
+ self.first_layer_norm.trainable_weights +
+ self.second_layer_norm.trainable_weights)
+ return self._trainable_weights
+
+ def call(self,
+ layer_input,
+ attention_mask=None,
+ band_mask=None,
+ from_mask=None,
+ to_mask=None,
+ input_blocked_mask=None,
+ training=None):
+ """Implements a encoder layer of a transformer in BERT style.
+
+ Args:
+ layer_input: float Tensor of shape [batch_size, seq_length, hidden_size].
+ attention_mask: (optional) int32 Tensor of shape [batch_size,
+ seq_length, seq_length]. The values should be 1 or 0. The
+ attention scores will effectively be set to -infinity for any positions
+ in the mask that are 0, and will be unchanged for positions that are 1.
+ band_mask: (optional) int32 Tensor of shape [batch_size, 1,
+ seq_length//block_size-4, block_size, 3*block_size].
+ The values should be 1 or 0. The attention scores will effectively be
+ set to -infinity for any positions in the mask that are 0, and will be
+ unchanged for positions that are 1.
+ from_mask: (optional) int32 Tensor of shape [batch_size, 1,
+ seq_length, 1]. The values should be 1 or 0. The
+ attention scores will effectively be set to -infinity for any positions
+ in the mask that are 0, and will be unchanged for positions that are 1.
+ to_mask: (optional) int32 Tensor of shape [batch_size, 1, 1,
+ seq_length]. The values should be 1 or 0. The
+ attention scores will effectively be set to -infinity for any positions
+ in the mask that are 0, and will be unchanged for positions that are 1.
+ input_blocked_mask: (optional) int32 Tensor of shape [batch_size,
+ seq_length//block_size, block_size]. Same as from/to_mask, just
+ reshaped.
+ training: Boolean indicating whether the call is training or inference.
+
+ Returns:
+ float Tensor of shape [batch_size, seq_length, hidden_size].
+
+ Raises:
+ ValueError: Any of the arguments or tensor shapes are invalid.
+ NotImplementedError: For unknown attention type.
+ """
+
+ with tf.compat.v1.variable_scope("attention"):
+ with tf.compat.v1.variable_scope("self") as sc:
+ attention_output = self.attn_layer(
+ layer_input, layer_input,
+ attention_mask, band_mask, from_mask, to_mask,
+ input_blocked_mask, input_blocked_mask, training, scope=sc)
+
+ # Run a linear projection of `hidden_size` then add a residual
+ # with `layer_input`.
+ with tf.compat.v1.variable_scope("output"):
+ attention_output = self.projection_layer(attention_output)
+ attention_output = utils.dropout(attention_output,
+ self.hidden_dropout_prob,
+ training)
+ attention_output = self.first_layer_norm(attention_output + layer_input)
+
+ # The activation is only applied to the "intermediate" hidden layer.
+ with tf.compat.v1.variable_scope("intermediate"):
+ intermediate_output = self.expand_layer(attention_output)
+
+ # Down-project back to `hidden_size` then add the residual.
+ with tf.compat.v1.variable_scope("output"):
+ layer_output = self.contract_layer(intermediate_output)
+ layer_output = utils.dropout(layer_output,
+ self.hidden_dropout_prob,
+ training)
+ layer_output = self.second_layer_norm(layer_output + attention_output)
+ return layer_output
+
+
+class EncoderStack(tf.compat.v1.layers.Layer):
+ """Transformer encoder stack."""
+
+ def __init__(self, params):
+ name = "encoder"
+ super(EncoderStack, self).__init__(name=name)
+ self.params = params
+
+ if params["norm_type"] == "prenorm":
+ encoder_class = PrenormEncoderLayer
+ elif params["norm_type"] == "postnorm":
+ encoder_class = PostnormEncoderLayer
+ else:
+ raise NotImplementedError(
+ "Norm type {} is not implemented".format(params["norm_type"]))
+
+ # Encoder layers
+ self.encoder_layers = [
+ encoder_class( # pylint: disable=g-complex-comprehension
+ self.params["attention_type"],
+ self.params["hidden_size"],
+ self.params["intermediate_size"],
+ utils.get_activation(self.params["hidden_act"]),
+ self.params["attention_probs_dropout_prob"],
+ self.params["hidden_dropout_prob"],
+ self.params["initializer_range"],
+ self.params["num_attention_heads"],
+ self.params["num_rand_blocks"],
+ self.params["block_size"],
+ self.params["use_bias"],
+ seed=layer_idx,
+ name="layer_%d" % layer_idx)
+ for layer_idx in range(self.params["num_hidden_layers"])
+ ]
+
+ # Normalization layer
+ self.layer_norm = utils.NormLayer()
+
+ @property
+ def trainable_weights(self):
+ self._trainable_weights = sum(
+ [layer.trainable_weights for layer in self.encoder_layers],
+ []) + self.layer_norm.trainable_weights
+ return self._trainable_weights
+
+ def call(self,
+ encoder_inputs,
+ encoder_inputs_mask,
+ training=None):
+ """Return the output of the decoder layer stacks.
+
+ Args:
+ encoder_inputs: tensor with shape
+ [batch_size, input_length, hidden_size]
+ encoder_inputs_mask: Mask for enccoder input. [batch_size, input_length]
+ training: Boolean indicating whether the call is training or inference.
+
+ Returns:
+ Finaly layer encoder output. float tensor with shape
+ [batch_size, input_length, hidden_size]
+ """
+ encoder_shape = utils.get_shape_list(encoder_inputs, expected_rank=3)
+ batch_size = encoder_shape[0]
+ encoder_length = encoder_shape[1]
+
+ if self.params["attention_type"] == "block_sparse":
+ # reshape and cast for blocking
+ encoder_block_size = self.params["block_size"]
+ blocked_encoder_mask = tf.reshape(
+ encoder_inputs_mask,
+ (batch_size, encoder_length//encoder_block_size, encoder_block_size))
+ encoder_from_mask = tf.reshape(encoder_inputs_mask,
+ (batch_size, 1, encoder_length, 1))
+ encoder_to_mask = tf.reshape(encoder_inputs_mask,
+ (batch_size, 1, 1, encoder_length))
+
+ # create band padding
+ attention_mask = None
+ band_mask = attention.create_band_mask_from_inputs(
+ blocked_encoder_mask, blocked_encoder_mask)
+
+ else:
+ blocked_encoder_mask = None
+ encoder_to_mask = None
+ encoder_from_mask = None
+
+ attention_mask = attention.create_attention_mask_from_input_mask(
+ encoder_inputs_mask, encoder_inputs_mask)
+ band_mask = None
+
+ # if self.params["use_gradient_checkpointing"]:
+ # encoder_layer = recompute_gradient(encoder_layer)
+
+ if self.params["norm_type"] == "postnorm":
+ encoder_inputs = self.layer_norm(encoder_inputs)
+
+ layer_output = encoder_inputs
+ for layer in self.encoder_layers:
+ layer_output = layer(
+ layer_output, attention_mask, band_mask,
+ encoder_from_mask, encoder_to_mask, blocked_encoder_mask, training)
+
+ if self.params["norm_type"] == "prenorm":
+ layer_output = self.layer_norm(layer_output)
+
+ return layer_output
diff --git a/bigbird/core/flags.py b/bigbird/core/flags.py
new file mode 100644
index 0000000..ca425a1
--- /dev/null
+++ b/bigbird/core/flags.py
@@ -0,0 +1,313 @@
+# Copyright 2020 The BigBird Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Common flag definitions."""
+
+import json
+import sys
+
+from absl import flags
+from absl import logging
+import bigbird
+import tensorflow.compat.v2 as tf
+
+import sentencepiece as spm
+
+# pylint: disable=g-import-not-at-top
+if sys.version_info >= (3, 9):
+ import importlib.resources as importlib_resources
+else:
+ import importlib_resources
+
+
+############################### FLAGS UTILS ####################################
+
+FLAGS = flags.FLAGS
+DEFINE_bool = flags.DEFINE_bool
+DEFINE_enum = flags.DEFINE_enum
+DEFINE_float = flags.DEFINE_float
+DEFINE_integer = flags.DEFINE_integer
+DEFINE_string = flags.DEFINE_string
+
+
+# Flag names are globally defined! So in general, we need to be
+# careful to pick names that are unlikely to be used by other libraries.
+# If there is a conflict, we'll get an error at import time.
+
+# Basic model config flags
+
+flags.DEFINE_float(
+ "attention_probs_dropout_prob", 0.1,
+ "The dropout probability for attention coefficients when using original.")
+flags.DEFINE_string(
+ "hidden_act", "gelu",
+ "The non-linear activation function (function or string) in the encoder "
+ "and pooler.")
+flags.DEFINE_float(
+ "hidden_dropout_prob", 0.1,
+ "The dropout probability for all fully connected layers in the embeddings, "
+ "encoder, decoder, and pooler.")
+flags.DEFINE_integer(
+ "hidden_size", 768,
+ "Size of the transformer layers and the pooler layer.")
+flags.DEFINE_float(
+ "initializer_range", 0.02,
+ "The stdev of the truncated_normal_initializer for initializing all "
+ "weight matrices.")
+flags.DEFINE_integer(
+ "intermediate_size", 3072,
+ "The size of intermediate (i.e. feed-forward) layer in the Transformer.")
+flags.DEFINE_integer(
+ "max_position_embeddings", 4096,
+ "The size position embeddings of matrix, which dictates the maximum"
+ "length for which the model can be run.")
+flags.DEFINE_integer(
+ "num_attention_heads", 12,
+ "Number of attention heads for each attention layer in the Transformer.")
+flags.DEFINE_integer(
+ "num_hidden_layers", 12,
+ "Number of hidden layers in the model (same for encoder and decoder).")
+flags.DEFINE_integer(
+ "type_vocab_size", 2,
+ "The vocabulary size of the `token_type_ids`.")
+flags.DEFINE_bool(
+ "use_bias", True,
+ "Whether to use bias for key/query/value.")
+flags.DEFINE_bool(
+ "rescale_embedding", False,
+ "Whether to rescale word embedding by hidden dimensions.")
+flags.DEFINE_string(
+ "scope", "bert",
+ "Variable scope name.")
+flags.DEFINE_string(
+ "vocab_model_file", "gpt2",
+ "The sentence piece model for vocabulary. Shortcuts for standard "
+ "gpt2 and pegasus vocabs are their name respectively.")
+
+# Simulated and Block attention settings
+
+flags.DEFINE_enum(
+ "attention_type", "block_sparse",
+ ["original_full", "simulated_sparse", "block_sparse"],
+ "Selecting attention implementation. "
+ "'original_full': full attention from original bert. "
+ "'simulated_sparse': simulated sparse attention. "
+ "'block_sparse': blocked implementation of sparse attention.")
+flags.DEFINE_enum(
+ "norm_type", "postnorm",
+ ["prenorm", "postnorm"],
+ "Selecting when to apply layer-norm. "
+ "'prenorm': Before attention layer, e.g. Pegasus. "
+ "'postnorm': After attention layer, e.g. Bert.")
+flags.DEFINE_integer(
+ "block_size", 16,
+ "The block size for the attention mask.")
+flags.DEFINE_integer(
+ "num_rand_blocks", 3,
+ "Number of random blocks per row.")
+
+# Adaptive optimizer configs
+
+flags.DEFINE_float(
+ "weight_decay_rate", 0.01,
+ "L2 penalty as weight decay to be used.")
+
+flags.DEFINE_float(
+ "optimizer_beta1", 0.9,
+ "The exponential decay rate for the 1st moment estimates.")
+
+flags.DEFINE_float(
+ "optimizer_beta2", 0.999,
+ "The exponential decay rate for the 2nd moment estimates.")
+
+flags.DEFINE_float(
+ "optimizer_epsilon", 1e-6,
+ "Adaptivty trade-off parameter.")
+
+# TPU settings
+
+flags.DEFINE_bool(
+ "use_tpu", False,
+ "Whether to use TPU or GPU/CPU.")
+
+flags.DEFINE_string(
+ "tpu_name", None,
+ "The Cloud TPU to use for training. This should be either the name "
+ "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 "
+ "url.")
+
+flags.DEFINE_string(
+ "tpu_zone", None,
+ "[Optional] GCE zone where the Cloud TPU is located in. If not "
+ "specified, we will attempt to automatically detect the GCE project from "
+ "metadata.")
+
+flags.DEFINE_string(
+ "tpu_job_name", None,
+ "Name of TPU worker, if anything other than 'tpu_worker'")
+
+flags.DEFINE_string(
+ "gcp_project", None,
+ "[Optional] Project name for the Cloud TPU-enabled project. If not "
+ "specified, we will attempt to automatically detect the GCE project from "
+ "metadata.")
+
+flags.DEFINE_string(
+ "master", None,
+ "[Optional] TensorFlow master URL.")
+
+flags.DEFINE_integer(
+ "num_tpu_cores", 8,
+ "Only used if `use_tpu` is True. Total number of TPU cores to use.")
+
+flags.DEFINE_string(
+ "iterations_per_loop", "1000",
+ "How many steps to make in each estimator call.")
+
+
+def as_dictionary():
+ """Get current config from flag."""
+
+ # Resolve vocab file location from hotword
+ if FLAGS.vocab_model_file == "gpt2":
+ FLAGS.vocab_model_file = str(importlib_resources.files(bigbird).joinpath(
+ "vocab/gpt2.model"))
+ elif FLAGS.vocab_model_file == "pegasus":
+ FLAGS.vocab_model_file = str(importlib_resources.files(bigbird).joinpath(
+ "vocab/pegasus.model"))
+
+ config = {
+ # transformer basic configs
+ "attention_probs_dropout_prob": FLAGS.attention_probs_dropout_prob,
+ "hidden_act": FLAGS.hidden_act,
+ "hidden_dropout_prob": FLAGS.hidden_dropout_prob,
+ "hidden_size": FLAGS.hidden_size,
+ "initializer_range": FLAGS.initializer_range,
+ "intermediate_size": FLAGS.intermediate_size,
+ "max_position_embeddings": FLAGS.max_position_embeddings,
+ "num_attention_heads": FLAGS.num_attention_heads,
+ "num_hidden_layers": FLAGS.num_hidden_layers,
+ "type_vocab_size": FLAGS.type_vocab_size,
+ "scope": FLAGS.scope,
+ "use_bias": FLAGS.use_bias,
+ "rescale_embedding": FLAGS.rescale_embedding,
+ "vocab_model_file": FLAGS.vocab_model_file,
+ # sparse mask configs
+ "attention_type": FLAGS.attention_type,
+ "norm_type": FLAGS.norm_type,
+ "block_size": FLAGS.block_size,
+ "num_rand_blocks": FLAGS.num_rand_blocks,
+ # common bert configs
+ "data_dir": FLAGS.data_dir,
+ "output_dir": FLAGS.output_dir,
+ "init_checkpoint": FLAGS.init_checkpoint,
+ "max_encoder_length": FLAGS.max_encoder_length,
+ "substitute_newline": FLAGS.substitute_newline,
+ "do_train": FLAGS.do_train,
+ "do_eval": FLAGS.do_eval,
+ "do_export": FLAGS.do_export,
+ "train_batch_size": FLAGS.train_batch_size,
+ "eval_batch_size": FLAGS.eval_batch_size,
+ "optimizer": FLAGS.optimizer,
+ "learning_rate": FLAGS.learning_rate,
+ "num_train_steps": FLAGS.num_train_steps,
+ "num_warmup_steps": FLAGS.num_warmup_steps,
+ "save_checkpoints_steps": FLAGS.save_checkpoints_steps,
+ "weight_decay_rate": FLAGS.weight_decay_rate,
+ "optimizer_beta1": FLAGS.optimizer_beta1,
+ "optimizer_beta2": FLAGS.optimizer_beta2,
+ "optimizer_epsilon": FLAGS.optimizer_epsilon,
+ # TPU settings
+ "use_tpu": FLAGS.use_tpu,
+ "tpu_name": FLAGS.tpu_name,
+ "tpu_zone": FLAGS.tpu_zone,
+ "tpu_job_name": FLAGS.tpu_job_name,
+ "gcp_project": FLAGS.gcp_project,
+ "master": FLAGS.master,
+ "num_tpu_cores": FLAGS.num_tpu_cores,
+ "iterations_per_loop": FLAGS.iterations_per_loop,
+ }
+
+ # pretraining dedicated flags
+ if hasattr(FLAGS, "max_predictions_per_seq"):
+ config["max_predictions_per_seq"] = FLAGS.max_predictions_per_seq
+ if hasattr(FLAGS, "masked_lm_prob"):
+ config["masked_lm_prob"] = FLAGS.masked_lm_prob
+ if hasattr(FLAGS, "max_eval_steps"):
+ config["max_eval_steps"] = FLAGS.max_eval_steps
+ if hasattr(FLAGS, "preprocessed_data"):
+ config["preprocessed_data"] = FLAGS.preprocessed_data
+ if hasattr(FLAGS, "use_nsp"):
+ config["use_nsp"] = FLAGS.use_nsp
+
+ # classifier dedicated flags
+ if hasattr(FLAGS, "num_labels"):
+ config["num_labels"] = FLAGS.num_labels
+
+ # summarization dedicated flags
+ if hasattr(FLAGS, "max_decoder_length"):
+ config["max_decoder_length"] = FLAGS.max_decoder_length
+ if hasattr(FLAGS, "trainable_bias"):
+ config["trainable_bias"] = FLAGS.trainable_bias
+ if hasattr(FLAGS, "couple_encoder_decoder"):
+ config["couple_encoder_decoder"] = FLAGS.couple_encoder_decoder
+ if hasattr(FLAGS, "beam_size"):
+ config["beam_size"] = FLAGS.beam_size
+ if hasattr(FLAGS, "alpha"):
+ config["alpha"] = FLAGS.alpha
+ if hasattr(FLAGS, "label_smoothing"):
+ config["label_smoothing"] = FLAGS.label_smoothing
+
+ # calculate vocab
+ sp_model = spm.SentencePieceProcessor()
+ sp_proto = tf.io.gfile.GFile(config["vocab_model_file"], "rb").read()
+ sp_model.LoadFromSerializedProto(sp_proto)
+ vocab_size = sp_model.GetPieceSize()
+ config["vocab_size"] = vocab_size
+
+ return config
+
+
+def save(path):
+ """Save current flag config."""
+ config = as_dictionary()
+ with tf.io.gfile.GFile(path, "w") as f:
+ json.dump(config, f, indent=4, sort_keys=True)
+
+ # log flags
+ max_len = max([len(ii) for ii in config.keys()])
+ fmt_string = "\t%" + str(max_len) + "s : %s"
+ logging.info("Arguments:")
+ for key, value in sorted(config.items()):
+ logging.info(fmt_string, key, value)
+
+ return config
+
+
+def load(path):
+ """Set flag from saved config."""
+
+ with tf.io.gfile.GFile(path) as f:
+ config = json.load(f)
+
+ # log and set flags
+ max_len = max([len(ii) for ii in config.keys()])
+ fmt_string = "\t%" + str(max_len) + "s : %s"
+ logging.info("Arguments:")
+ for key, value in config.items():
+ if hasattr(FLAGS, key):
+ logging.info(fmt_string, key, value)
+ setattr(FLAGS, key, value)
+
+ return config
diff --git a/bigbird/core/modeling.py b/bigbird/core/modeling.py
new file mode 100644
index 0000000..14e366b
--- /dev/null
+++ b/bigbird/core/modeling.py
@@ -0,0 +1,436 @@
+# Copyright 2020 The BigBird Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""The main BigBird model and related functions."""
+
+import copy
+
+from bigbird.core import decoder
+from bigbird.core import encoder
+from bigbird.core import utils
+import tensorflow.compat.v2 as tf
+
+
+class BertModel(tf.compat.v1.layers.Layer):
+ """BERT model ("Bidirectional Encoder Representations from Transformers").
+
+ Example usage:
+
+ ```python
+ # Already been converted into SentencePiece token ids
+ input_ids = tf.constant([[31, 51, 99], [15, 5, 0]])
+ token_type_ids = tf.constant([[0, 0, 1], [0, 2, 0]])
+
+ params = utils.BigBirdConfig(vocab_size=32000, hidden_size=512,
+ num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024)
+
+ model = modeling.BertModel(params, train=True)
+
+ _, pooled_output = model(input_ids=input_ids, token_type_ids=token_type_ids)
+
+ label_embeddings = tf.get_variable(...)
+ logits = tf.matmul(pooled_output, label_embeddings)
+ ...
+ ```
+ """
+
+ def __init__(self, params):
+ """Constructor for BertModel.
+
+ Args:
+ params: `BigBirdConfig` dictionary.
+ """
+ self.params = copy.deepcopy(params)
+ self.scope = params["scope"]
+
+ with tf.compat.v1.variable_scope(
+ self.scope, reuse=tf.compat.v1.AUTO_REUSE) as vs:
+ self.embeder = utils.EmbeddingLayer(
+ vocab_size=self.params["vocab_size"],
+ emb_dim=self.params["hidden_size"],
+ initializer=utils.create_initializer(
+ self.params["initializer_range"]),
+ scale_emb=self.params["rescale_embedding"],
+ use_token_type=True,
+ num_token_types=self.params["type_vocab_size"],
+ use_position_embeddings=True,
+ max_position_embeddings=self.params["max_position_embeddings"],
+ dropout_prob=self.params["hidden_dropout_prob"])
+ self.encoder = encoder.EncoderStack(self.params)
+ self.pooler = tf.compat.v1.layers.Dense(
+ units=self.params["hidden_size"],
+ activation=tf.tanh,
+ kernel_initializer=utils.create_initializer(
+ self.params["initializer_range"]),
+ name="pooler/dense")
+ super(BertModel, self).__init__(name=self.scope, _scope=vs)
+
+ @property
+ def trainable_weights(self):
+ self._trainable_weights = (self.embeder.trainable_weights +
+ self.encoder.trainable_weights +
+ self.pooler.trainable_weights)
+ return self._trainable_weights
+
+ def call(self,
+ input_ids,
+ token_type_ids=None,
+ training=None):
+ """Constructor for BertModel.
+
+ Args:
+ input_ids: int32 Tensor of shape [batch_size, seq_length].
+ token_type_ids: (optional) int32 Tensor of shape [batch_size, seq_length].
+ training: Boolean indicating whether the call is training or inference.
+
+ Returns:
+ sequence_output: Tensor of shape [batch_size, seq_length, hidden_size]
+ pooled_output: Tensor of shape [batch_size, hidden_size]
+
+ Raises:
+ ValueError: The config is invalid or one of the input tensor shapes
+ is invalid.
+ """
+ if token_type_ids is None:
+ token_type_ids = tf.zeros_like(input_ids, dtype=tf.int32)
+
+ # Perform embedding lookup on the word ids.
+ embedding_output = self.embeder(input_ids,
+ self.params["max_encoder_length"],
+ token_type_ids=token_type_ids,
+ training=training)
+
+ # Generate mask.
+ input_mask = tf.where(input_ids > 0,
+ tf.ones_like(input_ids), tf.zeros_like(input_ids))
+
+ # Run the stacked transformer.
+ sequence_output = self.encoder(embedding_output, input_mask, training)
+
+ # The "pooler" converts the encoded sequence tensor of shape
+ # [batch_size, seq_length, hidden_size] to a tensor of shape
+ # [batch_size, hidden_size]. This is necessary for segment-level
+ # (or segment-pair-level) classification tasks where we need a fixed
+ # dimensional representation of the segment.
+ first_token_tensor = sequence_output[:, 0, :]
+ # We "pool" the model by simply taking the hidden state corresponding
+ # to the first token. We assume that this has been pre-trained
+ pooled_output = self.pooler(first_token_tensor)
+
+ return sequence_output, pooled_output
+
+
+class TransformerModel(tf.compat.v1.layers.Layer):
+ """Encoder-Decoder transformer model.
+
+ Example usage:
+
+ ```python
+ # Already been converted into SentencePiece token ids
+ input_ids = tf.constant([[31, 51, 99], [15, 5, 0]])
+ target_ids = tf.constant([[43, 76, 38], [56, 8, 0]])
+
+ params = utils.BigBirdConfig(vocab_size=32000, hidden_size=512,
+ num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024)
+
+ model = modeling.TransformerModel(params, train=True)
+
+ predictions, _ = model(input_ids=input_ids, target_ids=target_ids)
+
+ log_probs, logits, pred_ids = predictions
+ ...
+ ```
+ """
+
+ def __init__(self, params):
+ """Constructor for TransformerModel.
+
+ Args:
+ params: `BigBirdConfig` dictionary.
+ """
+ self.params = copy.deepcopy(params)
+ self.scope = params["scope"]
+
+ with tf.compat.v1.variable_scope(
+ self.scope, reuse=tf.compat.v1.AUTO_REUSE) as vs:
+ self.embeder = utils.EmbeddingLayer(
+ vocab_size=self.params["vocab_size"],
+ emb_dim=self.params["hidden_size"],
+ initializer=utils.create_initializer(
+ self.params["initializer_range"]),
+ scale_emb=self.params["rescale_embedding"],
+ use_token_type=False,
+ num_token_types=None,
+ use_position_embeddings=True,
+ max_position_embeddings=self.params["max_position_embeddings"],
+ dropout_prob=self.params["hidden_dropout_prob"])
+ self.encoder = encoder.EncoderStack(self.params)
+ self.decoder = decoder.DecoderStack(self.params)
+ super(TransformerModel, self).__init__(name=self.scope, _scope=vs)
+
+ @property
+ def trainable_weights(self):
+ self._trainable_weights = (self.embeder.trainable_weights +
+ self.encoder.trainable_weights +
+ self.decoder.trainable_weights)
+ return self._trainable_weights
+
+ def _encode(self, input_ids, training=None):
+ """Generate continuous representation for ids.
+
+ Args:
+ input_ids: Int tensor with shape [batch_size, input_length].
+ training: Boolean indicating whether the call is training or inference.
+
+ Returns:
+ A float tensors of shape
+ [batch_size, input_length, hidden_size].
+ """
+ # Perform embedding lookup on the word ids.
+ input_embs = self.embeder(
+ input_ids, self.params["max_encoder_length"], training=training)
+
+ # Generate mask.
+ input_mask = tf.where(input_ids > 0,
+ tf.ones_like(input_ids), tf.zeros_like(input_ids))
+
+ # Run the stacked transformer.
+ encoder_output = self.encoder(input_embs, input_mask, training)
+
+ return encoder_output, input_mask
+
+ def _get_start_token_ids(self, tensor_for_shape):
+ start_token_id = 2
+ batch_size = utils.get_shape_list(tensor_for_shape)[0]
+ return tf.ones([batch_size], dtype=tf.int32) * start_token_id
+
+ def get_inputs_from_targets(self, targets, start_token_ids):
+ """Converts target ids to input ids, i.e. adds and removes last."""
+ length = tf.math.count_nonzero(targets, axis=1, dtype=tf.int32)
+ # Add start token ids.
+ inputs = tf.concat([tf.expand_dims(start_token_ids, axis=1), targets], 1)
+ # Remove from the input.
+ mask = tf.sequence_mask(length, self.params["max_decoder_length"]+1,
+ dtype=tf.int32)
+ inputs = (mask * inputs)[:, :-1]
+ return inputs
+
+ def _decode(self, target_ids, target_mask, start_token_ids,
+ encoder_output, encoder_mask, training):
+ """Compute likelihood of target tokens under the model.
+
+ Args:
+ target_ids: tensor with shape [batch_size, target_length, hidden_size]
+ target_mask: self-attention bias for decoder attention layer. [batch_size,
+ input_length]
+ start_token_ids: int32 tensor of shape [batch_size] for first decoder
+ input.
+ encoder_output: Continuous representation of input sequence. Float tensor
+ with shape [batch_size, input_length, hidden_size].
+ encoder_mask: Float tensor with shape [batch_size, input_length].
+ training: Boolean indicating whether the call is training or inference.
+
+ Returns:
+ A dict containing the output ids, the output log-probs, the output logits.
+ """
+
+ # Prepare inputs to decoder layers by shifting targets, embedding ids,
+ # adding positional encoding and applying dropout.
+ input_ids = self.get_inputs_from_targets(target_ids, start_token_ids)
+
+ input_embs = self.embeder(input_ids, self.params["max_decoder_length"],
+ training=training)
+
+ outputs = self.decoder(input_embs, target_mask,
+ encoder_output, encoder_mask, training=training)
+
+ logits = self.embeder.linear(outputs)
+ output_ids = tf.cast(tf.argmax(logits, axis=-1), tf.int32)
+
+ log_probs = -tf.nn.sparse_softmax_cross_entropy_with_logits(
+ labels=target_ids, logits=logits)
+ log_probs = tf.where(target_ids > 0, log_probs,
+ tf.zeros_like(log_probs, tf.float32))
+
+ return (tf.identity(log_probs, name="log_probs"),
+ tf.identity(logits, name="logits"),
+ tf.cast(output_ids, tf.int32, name="pred_ids"),)
+
+ def _init_cache(self, batch_size):
+ """Initialize cache for decoding."""
+
+ max_decode_len = self.params["max_decoder_length"]
+ num_heads = self.params["num_attention_heads"]
+ head_size = int(self.params["hidden_size"] / num_heads)
+
+ cache = {}
+ for layer in range(self.params["num_hidden_layers"]):
+ cache["layer_%d" % layer] = {
+ "k": tf.zeros([batch_size, num_heads, max_decode_len, head_size]),
+ "v": tf.zeros([batch_size, num_heads, max_decode_len, head_size]),
+ }
+ return cache
+
+ def _get_symbols_to_logits_fn(self, decoder_self_attention_mask):
+ """Returns a decoding function that calculates logits of the next tokens."""
+
+ max_decode_len = self.params["max_decoder_length"]
+
+ def _symbols_to_logits_fn(target_ids, cache, i):
+ """Generate logits for next candidate IDs.
+
+ Args:
+ target_ids: Current decoded sequences. int tensor with shape
+ [batch_size, i + 1]
+ cache: dictionary of values storing the encoder output, encoder-decoder
+ attention bias, and previous decoder attention values.
+ i: Loop index
+
+ Returns:
+ Tuple of
+ (logits with shape [batch_size * beam_size, vocab_size],
+ updated cache values)
+ """
+ decoder_input = tf.slice(target_ids,
+ [0, tf.maximum(tf.cast(0, i.dtype), i - 1)],
+ [target_ids.shape[0], 1])
+ self_attention_mask = tf.slice(decoder_self_attention_mask, [0, 0, i, 0],
+ [1, 1, 1, max_decode_len])
+
+ # Preprocess decoder input by getting embeddings and adding timing signal.
+ decoder_input = self.embeder(
+ decoder_input, 1, start_pos=i, training=False)
+
+ decoder_output = self.decoder(
+ decoder_input, self_attention_mask,
+ cache.get("encoder_output"), cache.get("encoder_mask"),
+ cache=cache, decode_i=i, training=False)
+
+ logits = self.embeder.linear(decoder_output)
+ logits = tf.squeeze(logits, axis=[1])
+
+ return logits
+
+ return _symbols_to_logits_fn
+
+ def _predict(self, target_ids, target_mask, start_token_ids,
+ encoder_output, encoder_mask):
+ """Beam decode output tokens and probabilities.
+
+ Args:
+ target_ids: tensor with shape [batch_size, target_length, hidden_size]
+ target_mask: self-attention bias for decoder attention layer. [batch_size,
+ input_length]
+ start_token_ids: int32 tensor of shape [batch_size] for first decoder
+ input.
+ encoder_output: Continuous representation of input sequence. Float
+ tensor with shape [batch_size, target_length, num_hidden_layers,
+ hidden_size]
+ encoder_mask: bias for encoder-decoder attention layer. [batch_size,
+ input_length]
+
+ Returns:
+ A tuple of:
+ `log_probs`: Log-probs of output tokens.
+ `logits`: Logits of output tokens.
+ `pred_ids`: Predicted output sequence.
+ """
+ batch_size = utils.get_shape_list(start_token_ids)[0]
+ end_token_id = 1
+
+ # One step logit function.
+ symbols_to_logits_fn = self._get_symbols_to_logits_fn(target_mask)
+
+ # Create cache storing decoder attention values for each layer.
+ cache = self._init_cache(batch_size)
+
+ if encoder_output is not None:
+ # Add encoder output and attention bias to the cache.
+ cache["encoder_output"] = encoder_output
+ cache["encoder_mask"] = encoder_mask
+
+ decoded_ids = decoder.left2right_decode(
+ symbols_to_logits_fn,
+ start_token_ids,
+ cache,
+ batch_size,
+ self.params["max_decoder_length"],
+ vocab_size=self.params["vocab_size"],
+ beam_size=self.params["beam_size"],
+ beam_start=5,
+ beam_alpha=self.params["alpha"],
+ beam_min=0,
+ beam_max=-1,
+ eos_id=end_token_id)
+
+ # Get the top sequence for each batch element
+ output_ids = tf.cast(decoded_ids, tf.int32, name="pred_ids")
+
+ # Calculate log probs for given sequence if available.
+ calc_ids = output_ids if target_ids is None else target_ids
+ output_log_probs, output_logits, _ = self._decode(
+ calc_ids, target_mask, start_token_ids,
+ encoder_output, encoder_mask, training=False)
+
+ return (output_log_probs, output_logits, output_ids)
+
+ def _decode_and_predict(self, target_ids, encoder_output, encoder_mask,
+ training):
+ """Decodes a sequence given the input and the encoder.
+
+ Args:
+ target_ids: tensor with shape [batch_size, target_length, hidden_size]
+ encoder_output: Continuous representation of input sequence. Float
+ tensor with shape [batch_size, target_length, num_hidden_layers,
+ hidden_size]
+ encoder_mask: bias for encoder-decoder attention layer. [batch_size,
+ input_length]
+ training: Boolean indicating whether the call is training or inference.
+
+ Returns:
+ A tuple of:
+ `log_probs`: Log-probs of output tokens.
+ `logits`: Logits of output tokens.
+ `pred_ids`: Predicted output sequence.
+ """
+ # Create initial set of IDs that will be passed into symbols_to_logits_fn.
+ start_token_ids = self._get_start_token_ids(encoder_output)
+
+ # Create causal self-attention mask for decoder.
+ target_mask = decoder.create_self_attention_mask(
+ self.params["max_decoder_length"])
+
+ predictions = {}
+ if training:
+ predictions = self._decode(target_ids, target_mask, start_token_ids,
+ encoder_output, encoder_mask, training=True)
+ else:
+ predictions = self._predict(target_ids, target_mask, start_token_ids,
+ encoder_output, encoder_mask)
+
+ return predictions
+
+ def call(self,
+ input_ids,
+ target_ids=None,
+ training=None):
+ # Run the inputs through the encoder layer to map the symbol
+ # representations to continuous representations.
+ encoder_output, encoder_mask = self._encode(input_ids, training)
+
+ # Decode.
+ predictions = self._decode_and_predict(target_ids, encoder_output,
+ encoder_mask, training)
+
+ return predictions, encoder_output
diff --git a/bigbird/core/optimization.py b/bigbird/core/optimization.py
new file mode 100644
index 0000000..9688d48
--- /dev/null
+++ b/bigbird/core/optimization.py
@@ -0,0 +1,275 @@
+# Copyright 2020 The BigBird Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Functions and classes related to optimization (weight updates)."""
+
+import re
+
+from absl import logging
+import tensorflow.compat.v2 as tf
+
+# pylint: disable=g-direct-tensorflow-import
+from tensorflow.python.ops import resource_variable_ops
+
+
+def get_linear_warmup_linear_decay_lr(init_lr, num_train_steps,
+ num_warmup_steps):
+ """Calculate learning rate with linear warmup and linear decay."""
+ global_step = tf.compat.v1.train.get_or_create_global_step()
+
+ learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32)
+
+ # Implements linear decay of the learning rate.
+ learning_rate = tf.compat.v1.train.polynomial_decay(
+ learning_rate,
+ global_step,
+ num_train_steps,
+ end_learning_rate=0.0,
+ power=1.0,
+ cycle=False)
+
+ # Implements linear warmup. I.e., if global_step < num_warmup_steps, the
+ # learning rate will be `global_step/num_warmup_steps * init_lr`.
+ if num_warmup_steps:
+ global_steps_int = tf.cast(global_step, tf.int32)
+ warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32)
+
+ global_steps_float = tf.cast(global_step, tf.float32)
+ warmup_steps_float = tf.cast(num_warmup_steps, tf.float32)
+
+ warmup_percent_done = global_steps_float / warmup_steps_float
+ warmup_learning_rate = init_lr * warmup_percent_done
+
+ is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32)
+ learning_rate = (
+ (1.0 - is_warmup) * learning_rate + is_warmup * warmup_learning_rate)
+
+ return learning_rate
+
+
+def get_linear_warmup_rsqrt_decay_lr(init_lr, hidden_size,
+ num_warmup_steps):
+ """Calculate learning rate with linear warmup and rsqrt decay."""
+ num_warmup_steps = tf.cast(num_warmup_steps, tf.float32)
+ global_step = tf.compat.v1.train.get_or_create_global_step()
+ global_step = tf.cast(global_step, tf.float32)
+
+ learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32)
+ learning_rate *= tf.math.rsqrt(tf.cast(hidden_size, tf.float32))
+ # Apply linear warmup
+ learning_rate *= tf.minimum(1.0, global_step / num_warmup_steps)
+ # Apply rsqrt decay
+ learning_rate *= tf.math.rsqrt(tf.maximum(global_step, num_warmup_steps))
+
+ return learning_rate
+
+
+def get_optimizer(params, learning_rate):
+ """Gets the optimzer based on the hparams and current mode (TPU vs. CPU/GPU).
+
+ Args:
+ params: A dictionary containing training hyperparameters.
+ learning_rate: A float32 scalar.
+
+ Returns:
+ A string or an optimizer instance.
+ """
+ optimizer = None
+
+ if params["optimizer"] == "Adafactor":
+ try:
+ from tensor2tensor.utils import adafactor # pylint: disable=g-import-not-at-top
+ optimizer = adafactor.AdafactorOptimizer(learning_rate=learning_rate)
+ except ImportError:
+ logging.error("tensor2tensor not installed. Cannot use Adafactor."
+ "Defaulting to Adam.")
+ params["optimizer"] = "Adam"
+
+ if params["optimizer"] == "Adam":
+ optimizer = tf.compat.v1.train.AdamOptimizer(
+ learning_rate,
+ beta1=params["optimizer_beta1"],
+ beta2=params["optimizer_beta2"],
+ epsilon=params["optimizer_epsilon"])
+
+ if params["optimizer"] == "AdamWeightDecay":
+ optimizer = AdamWeightDecayOptimizer(
+ learning_rate,
+ weight_decay_rate=params["weight_decay_rate"],
+ beta_1=params["optimizer_beta1"],
+ beta_2=params["optimizer_beta2"],
+ epsilon=params["optimizer_epsilon"],
+ exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"])
+
+ if params["optimizer"] == "SGD":
+ optimizer = tf.compat.v1.train.GradientDescentOptimizer(learning_rate)
+
+ if optimizer is None:
+ raise ValueError("Unknown optimizer: {}.".format(params["optimizer"]))
+
+ if params["use_tpu"]:
+ # Average the gradients across TPU cores.
+ optimizer = tf.compat.v1.tpu.CrossShardOptimizer(optimizer)
+
+ return optimizer
+
+
+class AdamWeightDecayOptimizer(tf.compat.v1.train.Optimizer):
+ """A basic Adam optimizer that includes "correct" L2 weight decay."""
+
+ def __init__(self,
+ learning_rate,
+ weight_decay_rate=0.0,
+ beta_1=0.9,
+ beta_2=0.999,
+ epsilon=1e-6,
+ exclude_from_weight_decay=None,
+ name="AdamWeightDecayOptimizer"):
+ """Constructs a AdamWeightDecayOptimizer."""
+ super(AdamWeightDecayOptimizer, self).__init__(False, name)
+
+ self.learning_rate = learning_rate
+ self.weight_decay_rate = weight_decay_rate
+ self.beta_1 = beta_1
+ self.beta_2 = beta_2
+ self.epsilon = epsilon
+ self.exclude_from_weight_decay = exclude_from_weight_decay
+
+ def _create_slots(self, var_list):
+ # Create slots for the first and second moments.
+ for v in var_list:
+ self._zeros_slot(v, "m", self._name)
+ self._zeros_slot(v, "v", self._name)
+
+ def _apply_dense(self, grad, var):
+ param_name = self._get_variable_name(var.name)
+ m = self.get_slot(var, "m")
+ v = self.get_slot(var, "v")
+
+ # Standard Adam update.
+ next_m = (
+ tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad))
+ next_v = (
+ tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2,
+ tf.square(grad)))
+
+ update = next_m / (tf.sqrt(next_v) + self.epsilon)
+
+ # Just adding the square of the weights to the loss function is *not*
+ # the correct way of using L2 regularization/weight decay with Adam,
+ # since that will interact with the m and v parameters in strange ways.
+ #
+ # Instead we want ot decay the weights in a manner that doesn't interact
+ # with the m/v parameters. This is equivalent to adding the square
+ # of the weights to the loss with plain (non-momentum) SGD.
+ if self._do_use_weight_decay(param_name):
+ update += self.weight_decay_rate * var
+
+ update_with_lr = self.learning_rate * update
+
+ next_param = var - update_with_lr
+
+ return tf.group(
+ [var.assign(next_param),
+ m.assign(next_m),
+ v.assign(next_v)])
+
+ def _resource_apply_dense(self, grad, var):
+ """See `tf.train.Optimizer._resource_apply_dense()`."""
+ return self._apply_dense(grad, var)
+
+ def _apply_sparse(self, grad, var):
+ """See `tf.train.Optimizer._apply_sparse()`."""
+ def scatter_update_fn(x, i, v):
+ return tf.compat.v1.scatter_update(x, i, v, use_locking=self._use_locking)
+ return self._apply_sparse_shared(
+ grad.values, grad.indices, var, scatter_update_fn)
+
+ def _resource_apply_sparse(self, grad, var, indices):
+ """See `tf.train.Optimizer._resource_apply_spase()`."""
+ def scatter_update_fn(x, i, v):
+ with tf.control_dependencies(
+ [resource_variable_ops.resource_scatter_update(x.handle, i, v)]):
+ return x.value()
+ return self._apply_sparse_shared(grad, indices, var, scatter_update_fn)
+
+ def _apply_sparse_shared(self, grad, indices, var, scatter_update_fn):
+ """Applies sparse gradients to a variable.
+
+ Args:
+ grad: A tensor for the `values` of `tf.IndexedSlices`.
+ indices: A tensor for the `indices` of `tf.IndexedSlices`.
+ var: A `tf.Variable` object.
+ scatter_update_fn: A function which performs scattered update to
+ a `tf.Variable` object. It takes tuple of (x, i, v) where:
+ * x: A `tf.Variable` object which is updated by `i` and `v`,
+ * i: A tensor for the `indices` of `tf.IndexedSlices`,
+ * v: A tensor for the `values` of `tf.IndexedSlices`,
+ and returns a tensor after updating `x`.
+
+ Returns:
+ An op which updates `var` with `grad` and `indices`.
+ """
+ param_name = self._get_variable_name(var.name)
+ m = self.get_slot(var, "m")
+ v = self.get_slot(var, "v")
+
+ # m_t = beta1 * m + (1 - beta1) * g_t
+ m_scaled_g_values = tf.multiply(1.0 - self.beta_1, grad)
+ m_t = m.assign(m * self.beta_1)
+ with tf.control_dependencies([m_t]):
+ m_slice = tf.gather(m, indices) + m_scaled_g_values
+ m_t = scatter_update_fn(m, indices, m_slice)
+
+ # v_t = beta2 * v + (1 - beta2) * g_t^2
+ v_scaled_g_values = tf.multiply(1.0 - self.beta_2, tf.square(grad))
+ v_t = v.assign(v * self.beta_2)
+ with tf.control_dependencies([v_t]):
+ v_slice = tf.gather(v, indices) + v_scaled_g_values
+ v_t = scatter_update_fn(v, indices, v_slice)
+
+ update = m_t / (tf.sqrt(v_t) + self.epsilon)
+
+ # Just adding the square of the weights to the loss function is *not*
+ # the correct way of using L2 regularization/weight decay with Adam,
+ # since that will interact with the m and v parameters in strange ways.
+ #
+ # Instead we want ot decay the weights in a manner that doesn't interact
+ # with the m/v parameters. This is equivalent to adding the square
+ # of the weights to the loss with plain (non-momentum) SGD.
+ if self._do_use_weight_decay(param_name):
+ update += self.weight_decay_rate * var
+
+ update_with_lr = self.learning_rate * update
+
+ next_param = var - update_with_lr
+
+ return tf.group([var.assign(next_param), m_t, v_t])
+
+ def _do_use_weight_decay(self, param_name):
+ """Whether to use L2 weight decay for `param_name`."""
+ if not self.weight_decay_rate:
+ return False
+ if self.exclude_from_weight_decay:
+ for r in self.exclude_from_weight_decay:
+ if re.search(r, param_name) is not None:
+ return False
+ return True
+
+ def _get_variable_name(self, param_name):
+ """Get the variable name from the tensor name."""
+ m = re.match("^(.*):\\d+$", param_name)
+ if m is not None:
+ param_name = m.group(1)
+ return param_name
diff --git a/bigbird/core/utils.py b/bigbird/core/utils.py
new file mode 100644
index 0000000..56a39a6
--- /dev/null
+++ b/bigbird/core/utils.py
@@ -0,0 +1,741 @@
+# Copyright 2020 The BigBird Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Helper and utility functions."""
+
+import re
+
+from absl import logging
+import numpy as np
+import tensorflow.compat.v2 as tf
+
+
+############################### SHAPE UTILS ####################################
+
+
+def get_shape_list(tensor, expected_rank=None, name=None):
+ """Returns a list of the shape of tensor, preferring static dimensions.
+
+ Args:
+ tensor: A tf.Tensor object to find the shape of.
+ expected_rank: (optional) int. The expected rank of `tensor`. If this is
+ specified and the `tensor` has a different rank, and exception will be
+ thrown.
+ name: Optional name of the tensor for the error message.
+
+ Returns:
+ A list of dimensions of the shape of tensor. All static dimensions will
+ be returned as python integers, and dynamic dimensions will be returned
+ as tf.Tensor scalars.
+ """
+ if not tf.executing_eagerly() and name is None:
+ name = tensor.name
+
+ if expected_rank is not None:
+ assert_rank(tensor, expected_rank, name)
+
+ shape = tensor.shape.as_list()
+
+ non_static_indexes = []
+ for (index, dim) in enumerate(shape):
+ if dim is None:
+ non_static_indexes.append(index)
+
+ if not non_static_indexes:
+ return shape
+
+ assert False, "Static shape not available for {}".format(tensor)
+
+ dyn_shape = tf.shape(tensor)
+ for index in non_static_indexes:
+ shape[index] = dyn_shape[index]
+ return shape
+
+
+def reshape_to_matrix(input_tensor):
+ """Reshapes a >= rank 2 tensor to a rank 2 tensor (i.e., a matrix)."""
+ ndims = input_tensor.shape.ndims
+ if ndims < 2:
+ raise ValueError("Input tensor must have at least rank 2. Shape = %s" %
+ (input_tensor.shape))
+ if ndims == 2:
+ return input_tensor
+
+ width = input_tensor.shape[-1]
+ output_tensor = tf.reshape(input_tensor, [-1, width])
+ return output_tensor
+
+
+def reshape_from_matrix(output_tensor, orig_shape_list):
+ """Reshapes a rank 2 tensor back to its original rank >= 2 tensor."""
+ if len(orig_shape_list) == 2:
+ return output_tensor
+
+ output_shape = get_shape_list(output_tensor)
+
+ orig_dims = orig_shape_list[0:-1]
+ width = output_shape[-1]
+
+ return tf.reshape(output_tensor, orig_dims + [width])
+
+
+def assert_rank(tensor, expected_rank, name=None):
+ """Raises an exception if the tensor rank is not of the expected rank.
+
+ Args:
+ tensor: A tf.Tensor to check the rank of.
+ expected_rank: Python integer or list of integers, expected rank.
+ name: Optional name of the tensor for the error message.
+
+ Raises:
+ ValueError: If the expected shape doesn't match the actual shape.
+ """
+ if not tf.executing_eagerly() and name is None:
+ name = tensor.name
+
+ expected_rank_dict = {}
+ if isinstance(expected_rank, int):
+ expected_rank_dict[expected_rank] = True
+ else:
+ for x in expected_rank:
+ expected_rank_dict[x] = True
+
+ actual_rank = tensor.shape.ndims
+ if actual_rank not in expected_rank_dict:
+ scope_name = tf.compat.v1.get_variable_scope().name
+ raise ValueError(
+ "For the tensor `{}` in scope `{}`, the actual rank "
+ "`{}` (shape = {}) is not equal to the expected rank `{}`".format(
+ name, scope_name, actual_rank, str(tensor.shape),
+ str(expected_rank)))
+
+
+############################### DENSE LAYERS ###################################
+
+
+def create_initializer(initializer_range=0.02):
+ """Creates a `truncated_normal_initializer` with the given range."""
+ return tf.compat.v1.truncated_normal_initializer(stddev=initializer_range)
+
+
+class Dense3dLayer(tf.compat.v1.layers.Layer):
+ """A dense layer with 3D kernel."""
+
+ def __init__(self,
+ num_attention_heads,
+ size_per_head,
+ initializer,
+ activation,
+ name=None,
+ head_first=False,
+ use_bias=True):
+ """Constructor for dense layer with 3D kernel.
+
+ Args:
+ num_attention_heads: The size of output dimension.
+ size_per_head: The size per attention head.
+ initializer: Kernel initializer.
+ activation: Actication function.
+ name: The name scope of this layer.
+ head_first: Whether to output head dimension before or after sequence dim.
+ use_bias: Whether the layer uses a bias vector.
+ """
+ super(Dense3dLayer, self).__init__(name=name)
+ self.num_attention_heads = num_attention_heads
+ self.size_per_head = size_per_head
+ self.initializer = initializer
+ self.activation = activation
+ self.head_first = head_first
+ self.use_bias = use_bias
+
+ self.w = None
+ self.b = None
+
+ def call(self, input_tensor):
+ """Constructor for dense layer with 3D kernel.
+
+ Args:
+ input_tensor: float Tensor of shape [batch, seq_length, hidden_size].
+
+ Returns:
+ float logits Tensor.
+ """
+ last_dim = get_shape_list(input_tensor)[-1]
+ if self.w is None:
+ self.w = tf.compat.v1.get_variable(
+ name="kernel",
+ shape=[last_dim, self.num_attention_heads * self.size_per_head],
+ initializer=self.initializer)
+ self.initializer = None
+ self._trainable_weights.append(self.w)
+ reshape_w = tf.reshape(
+ self.w, [last_dim, self.num_attention_heads, self.size_per_head])
+ if self.head_first:
+ ret = tf.einsum("abc,cde->adbe", input_tensor, reshape_w)
+ else:
+ ret = tf.einsum("abc,cde->abde", input_tensor, reshape_w)
+
+ if self.use_bias:
+ if self.b is None:
+ self.b = tf.compat.v1.get_variable(
+ name="bias",
+ shape=[self.num_attention_heads * self.size_per_head],
+ initializer=tf.zeros_initializer)
+ self._trainable_weights.append(self.b)
+ if self.head_first:
+ reshape_b = tf.reshape(
+ self.b, [1, self.num_attention_heads, 1, self.size_per_head])
+ else:
+ reshape_b = tf.reshape(
+ self.b, [self.num_attention_heads, self.size_per_head])
+ ret += reshape_b
+
+ if self.activation is not None:
+ return self.activation(ret)
+ else:
+ return ret
+
+
+class Dense3dProjLayer(tf.compat.v1.layers.Layer):
+ """A dense layer with 3D kernel for projection."""
+
+ def __init__(self,
+ num_attention_heads,
+ size_per_head,
+ initializer,
+ activation,
+ name=None,
+ use_bias=True):
+ """Constructor for dense layer with 3D kernel for projection.
+
+ Args:
+ num_attention_heads: The size of output dimension.
+ size_per_head: The size per attention head.
+ initializer: Kernel initializer.
+ activation: Actication function.
+ name: The name scope of this layer.
+ use_bias: Whether the layer uses a bias vector.
+ """
+ super(Dense3dProjLayer, self).__init__(name=name)
+ self.num_attention_heads = num_attention_heads
+ self.size_per_head = size_per_head
+ self.initializer = initializer
+ self.activation = activation
+ self.use_bias = use_bias
+
+ self.w = None
+ self.b = None
+
+ def call(self, input_tensor):
+ """Constructor for dense layer with 3D kernel for projection.
+
+ Args:
+ input_tensor: float Tensor of shape [batch,from_seq_length,
+ num_attention_heads, size_per_head].
+
+ Returns:
+ float logits Tensor.
+ """
+ hidden_size = self.num_attention_heads * self.size_per_head
+ if self.w is None:
+ self.w = tf.compat.v1.get_variable(
+ name="kernel",
+ shape=[hidden_size, hidden_size],
+ initializer=self.initializer)
+ self.initializer = None
+ self._trainable_weights.append(self.w)
+ reshape_w = tf.reshape(
+ self.w, [self.num_attention_heads, self.size_per_head, hidden_size])
+ ret = tf.einsum("BFNH,NHD->BFD", input_tensor, reshape_w)
+
+ if self.use_bias:
+ if self.b is None:
+ self.b = tf.compat.v1.get_variable(
+ name="bias",
+ shape=[hidden_size],
+ initializer=tf.zeros_initializer)
+ self._trainable_weights.append(self.b)
+ ret += self.b
+
+ if self.activation is not None:
+ return self.activation(ret)
+ else:
+ return ret
+
+
+class Dense2dLayer(tf.compat.v1.layers.Layer):
+ """A dense layer with 2D kernel."""
+
+ def __init__(self,
+ output_size,
+ initializer,
+ activation,
+ name=None,
+ use_bias=True):
+ """Constructor for dense layer with 2D kernel.
+
+ Args:
+ output_size: The size of output dimension.
+ initializer: Kernel initializer.
+ activation: Actication function.
+ name: The name scope of this layer.
+ use_bias: Whether the layer uses a bias vector.
+ """
+ super(Dense2dLayer, self).__init__(name=name)
+ self.output_size = output_size
+ self.initializer = initializer
+ self.activation = activation
+ self.use_bias = use_bias
+
+ self.w = None
+ self.b = None
+
+ def call(self, input_tensor):
+ """Forward pass for dense layer with 2D kernel.
+
+ Args:
+ input_tensor: Float tensor with rank 3.
+
+ Returns:
+ float logits Tensor.
+ """
+ if self.w is None:
+ last_dim = get_shape_list(input_tensor)[-1]
+ self.w = tf.compat.v1.get_variable(
+ name="kernel",
+ shape=[last_dim, self.output_size],
+ initializer=self.initializer)
+ self.initializer = None
+ self._trainable_weights.append(self.w)
+ ret = tf.einsum("abc,cd->abd", input_tensor, self.w)
+
+ if self.use_bias:
+ if self.b is None:
+ self.b = tf.compat.v1.get_variable(
+ name="bias",
+ shape=[self.output_size],
+ initializer=tf.zeros_initializer)
+ self._trainable_weights.append(self.b)
+ ret += self.b
+
+ if self.activation is not None:
+ return self.activation(ret)
+ else:
+ return ret
+
+
+def gelu(x):
+ """Gaussian Error Linear Unit.
+
+ This is a smoother version of the RELU.
+ Original paper: https://arxiv.org/abs/1606.08415
+ Args:
+ x: float Tensor to perform activation.
+
+ Returns:
+ `x` with the GELU activation applied.
+ """
+ cdf = 0.5 * (1.0 + tf.tanh(
+ (np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3)))))
+ return x * cdf
+
+
+def get_activation(activation_string):
+ """Maps a string to a Python function, e.g., "relu" => `tf.nn.relu`.
+
+ Args:
+ activation_string: String name of the activation function.
+
+ Returns:
+ A Python function corresponding to the activation function. If
+ `activation_string` is None, empty, or "linear", this will return None.
+ If `activation_string` is not a string, it will return `activation_string`.
+
+ Raises:
+ ValueError: The `activation_string` does not correspond to a known
+ activation.
+ """
+
+ # We assume that anything that"s not a string is already an activation
+ # function, so we just return it.
+ if not isinstance(activation_string, str):
+ return activation_string
+
+ if not activation_string:
+ return None
+
+ act = activation_string.lower()
+ if act == "linear":
+ return None
+ elif act == "relu":
+ return tf.nn.relu
+ elif act == "gelu":
+ return gelu
+ elif act == "tanh":
+ return tf.tanh
+ else:
+ raise ValueError("Unsupported activation: %s" % act)
+
+
+########################## NORM & DROPOUT LAYERS ###############################
+
+
+def dropout(input_tensor, dropout_prob, training=True):
+ """Perform dropout.
+
+ Args:
+ input_tensor: float Tensor.
+ dropout_prob: Python float. The probability of dropping out a value (NOT of
+ *keeping* a dimension as in `tf.nn.dropout`).
+ training: Boolean indicating whether the call is training or inference.
+
+ Returns:
+ A version of `input_tensor` with dropout applied.
+ """
+ if not training or dropout_prob is None or dropout_prob == 0.0:
+ return input_tensor
+
+ output = tf.nn.dropout(input_tensor, rate=dropout_prob)
+ return output
+
+
+class NormLayer(tf.compat.v1.layers.Layer):
+ """Replacement for contrib_layers.layer_norm."""
+
+ def __init__(self, name="LayerNorm"):
+ super(NormLayer, self).__init__(name=name)
+ self.beta = None
+ self.gamma = None
+
+ def call(self, input_tensor):
+ inputs = tf.convert_to_tensor(input_tensor)
+ inputs_shape = get_shape_list(inputs)
+ inputs_rank = len(inputs_shape)
+ dtype = inputs.dtype.base_dtype
+ norm_axis = inputs_rank - 1
+ params_shape = [inputs_shape[norm_axis]]
+
+ # Allocate parameters for the beta and gamma of the normalization.
+ if self.beta is None:
+ self.beta = tf.compat.v1.get_variable(
+ "beta",
+ shape=params_shape,
+ dtype=dtype,
+ initializer=tf.zeros_initializer(),
+ trainable=True)
+ self._trainable_weights.append(self.beta)
+ if self.gamma is None:
+ self.gamma = tf.compat.v1.get_variable(
+ "gamma",
+ shape=params_shape,
+ dtype=dtype,
+ initializer=tf.ones_initializer(),
+ trainable=True)
+ self._trainable_weights.append(self.gamma)
+ # Compute norm along last axis
+ mean, variance = tf.nn.moments(inputs, [norm_axis], keepdims=True)
+ # Compute layer normalization using the batch_normalization function.
+ # Note that epsilon must be increased for float16 due to the limited
+ # representable range.
+ variance_epsilon = 1e-12 if dtype != tf.float16 else 1e-3
+ outputs = tf.nn.batch_normalization(
+ inputs,
+ mean,
+ variance,
+ offset=self.beta,
+ scale=self.gamma,
+ variance_epsilon=variance_epsilon)
+ outputs.set_shape(inputs_shape)
+ return outputs
+
+
+############################# EMBEDDING LAYER ##################################
+
+
+class EmbeddingLayer(tf.compat.v1.layers.Layer):
+ """An embedding layer."""
+
+ def __init__(self,
+ vocab_size,
+ emb_dim,
+ initializer,
+ scale_emb=False,
+ use_token_type=False,
+ num_token_types=16,
+ use_position_embeddings=True,
+ max_position_embeddings=4096,
+ dropout_prob=0.0,
+ name="embeddings"):
+ super(EmbeddingLayer, self).__init__(name=name)
+ self.vocab_size = vocab_size
+ self.emb_dim = emb_dim
+ self.scale_emb = scale_emb
+ self.num_token_types = num_token_types
+ self.max_position_embeddings = max_position_embeddings
+ self.dropout_prob = dropout_prob
+
+ with tf.compat.v1.variable_scope(name):
+ self.word_embeddings = tf.compat.v1.get_variable(
+ "word_embeddings", [vocab_size, emb_dim],
+ dtype=tf.float32, initializer=initializer)
+ self._trainable_weights.append(self.word_embeddings)
+
+ if use_token_type:
+ self.token_type_table = tf.compat.v1.get_variable(
+ "token_type_embeddings", [num_token_types, emb_dim],
+ dtype=tf.float32, initializer=initializer)
+ self._trainable_weights.append(self.token_type_table)
+ else:
+ self.token_type_table = None
+
+ if use_position_embeddings:
+ self.position_embeddings = tf.compat.v1.get_variable(
+ "position_embeddings", [max_position_embeddings, emb_dim],
+ dtype=tf.float32, initializer=initializer)
+ self._trainable_weights.append(self.position_embeddings)
+ else:
+ self.position_embeddings = None
+
+ def call(self,
+ input_ids,
+ seq_length,
+ start_pos=0,
+ token_type_ids=None,
+ training=None):
+ if input_ids is None:
+ return None
+
+ # subtoken embedding
+ output = tf.nn.embedding_lookup(params=self.word_embeddings, ids=input_ids)
+
+ if self.scale_emb:
+ output = output * self.emb_dim ** 0.5
+
+ if self.token_type_table is not None:
+ # This vocab will be small so we always do one-hot here, since it is
+ # always faster for a small vocabulary.
+ one_hot_ids = tf.one_hot(token_type_ids, depth=self.num_token_types)
+ token_type_embeddings = tf.tensordot(
+ one_hot_ids, self.token_type_table, 1)
+ output += token_type_embeddings
+
+ if self.position_embeddings is not None:
+ # assert_op = tf.compat.v1.assert_less_equal(
+ # start_pos + seq_length, self.max_position_embeddings)
+ # with tf.control_dependencies([assert_op]):
+ # So `position_embeddings` is effectively an embedding table for
+ # position [0, 1, 2, ..., max_position_embeddings-1], and the current
+ # sequence has positions [0, 1, 2, ... seq_length-1], so we can just
+ # perform a slice.
+ position_embeddings = tf.slice(self.position_embeddings, [start_pos, 0],
+ [seq_length, self.emb_dim])
+ output += tf.expand_dims(position_embeddings, axis=0)
+
+ if training and self.dropout_prob > 0:
+ output = tf.nn.dropout(output, self.dropout_prob)
+ return output
+
+ def linear(self, x):
+ """Computes logits by running x through a linear layer.
+
+ Args:
+ x: A float32 tensor with shape [..., hidden_size]
+ Returns:
+ float32 tensor with shape [..., vocab_size].
+ """
+ with tf.compat.v1.name_scope("presoftmax_linear"):
+ logits = tf.tensordot(x, self.word_embeddings, [[-1], [1]])
+ return logits
+
+
+########################## TPU/CHECKPOINT UTILS ################################
+
+
+def get_estimator(config, model_fn, keep_checkpoint_max=10):
+ """Create TPUEstimator object for given config and model_fn."""
+ tpu_cluster_resolver = None
+ if config["use_tpu"] and config["tpu_name"]:
+ tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
+ config["tpu_name"],
+ zone=config["tpu_zone"],
+ project=config["gcp_project"])
+
+ # Batch size book-keeping
+ # Estimators handle batch sizes differently among GPUs and TPUs
+ # GPU: Estimator needs per core batch size
+ # TPU: Estimator needs total batch size, i.e. num_cores * per core batch size
+ config_train_batch_size = config["train_batch_size"] # For estimator
+ config_eval_batch_size = config["eval_batch_size"] # For estimator
+ effective_train_batch_size = config["train_batch_size"] # For human
+ effective_eval_batch_size = config["eval_batch_size"] # For human
+ if config["use_tpu"]:
+ sliced_eval_mode = tf.compat.v1.estimator.tpu.InputPipelineConfig.SLICED
+ distribute_strategy = None
+ config_train_batch_size *= config["num_tpu_cores"]
+ config_eval_batch_size *= config["num_tpu_cores"]
+ effective_train_batch_size = config_train_batch_size
+ effective_eval_batch_size = config_eval_batch_size
+ else:
+ sliced_eval_mode = tf.compat.v1.estimator.tpu.InputPipelineConfig.PER_HOST_V1
+ distribute_strategy = tf.distribute.MirroredStrategy(devices=None)
+ effective_train_batch_size *= distribute_strategy.num_replicas_in_sync
+ # effective_eval_batch_size *= distribute_strategy.num_replicas_in_sync
+
+ is_per_host = tf.compat.v1.estimator.tpu.InputPipelineConfig.PER_HOST_V2
+ run_config = tf.compat.v1.estimator.tpu.RunConfig(
+ cluster=tpu_cluster_resolver,
+ master=config["master"],
+ model_dir=config["output_dir"],
+ save_checkpoints_steps=config["save_checkpoints_steps"],
+ keep_checkpoint_max=keep_checkpoint_max,
+ train_distribute=distribute_strategy,
+ tpu_config=tf.compat.v1.estimator.tpu.TPUConfig(
+ tpu_job_name=config["tpu_job_name"],
+ iterations_per_loop=config["iterations_per_loop"],
+ num_shards=config["num_tpu_cores"],
+ per_host_input_for_training=is_per_host,
+ eval_training_input_configuration=sliced_eval_mode))
+
+ if config["init_checkpoint"]:
+ ckpt_var_list = tf.compat.v1.train.list_variables(config["init_checkpoint"])
+ ckpt_var_list = {
+ name: shape for name, shape in ckpt_var_list
+ if not re.findall("(Adam|Adafactor|global_step)", name)
+ }
+ vars_to_warm_start = "({})".format("|".join(ckpt_var_list.keys()))
+ warm_start_settings = tf.estimator.WarmStartSettings(
+ ckpt_to_initialize_from=config["init_checkpoint"],
+ vars_to_warm_start=vars_to_warm_start)
+ else:
+ ckpt_var_list = {}
+ warm_start_settings = None
+ config["ckpt_var_list"] = ckpt_var_list
+
+ # If no TPU, this will fall back to normal Estimator on CPU or GPU.
+ estimator = tf.compat.v1.estimator.tpu.TPUEstimator(
+ use_tpu=config["use_tpu"],
+ model_fn=model_fn,
+ config=run_config,
+ train_batch_size=config_train_batch_size,
+ eval_batch_size=config_eval_batch_size,
+ warm_start_from=warm_start_settings)
+
+ # assign batch sizes
+ estimator.train_batch_size = effective_train_batch_size
+ estimator.eval_batch_size = effective_eval_batch_size
+
+ return estimator
+
+
+def log_variables(variables, ckpt_var_list):
+ """Log trainable variables."""
+ logging.info("**** Trainable Variables ****")
+
+ model_var_list = {var.name: var.get_shape().as_list() for var in variables}
+ num_params = sum(np.prod(shape) for shape in model_var_list.values())
+ length = max(len(name) for name in model_var_list) + 2
+ line = "{{:<{}}}{{:<13}}{{}}".format(length)
+
+ logging.info("The model has {} trainable variables "
+ "({:,} parameters):\n".format(len(model_var_list), num_params))
+ logging.info(line.format("Name", "Initialized", "Shape"))
+ logging.info(line.format("----", "-----------", "-----"))
+
+ ckpt_var_list = ckpt_var_list.copy()
+ for name, shape in model_var_list.items():
+ name = name.split(":")[0]
+ if name in ckpt_var_list:
+ warm_started = "from ckpt"
+ del ckpt_var_list[name]
+ else:
+ warm_started = "random"
+ logging.info(line.format(name, warm_started, shape))
+
+ if ckpt_var_list:
+ logging.warning(
+ "The warm start checkpoint contained %d variables that were not used "
+ "for the model:\n", len(ckpt_var_list))
+ for name, shape in ckpt_var_list.items():
+ logging.warning(line.format(name, "not used", shape))
+
+
+def add_scalars_to_summary(summary_dir, scalar_tensors_dict):
+ """Creates a host_call function that writes summaries on TPU."""
+
+ # All tensors outfed from TPU should preserve batch size dimension.
+ scalar_tensors_dict = {
+ k: tf.reshape(v, [1]) for k, v in scalar_tensors_dict.items()
+ }
+
+ def host_call_fn(**kwargs):
+ writer = tf.summary.create_file_writer(summary_dir, max_queue=1000)
+ always_record = tf.summary.record_if(True)
+ with writer.as_default(), always_record:
+ for name, scalar in kwargs.items():
+ tf.summary.scalar(name, tf.reduce_mean(scalar),
+ tf.compat.v1.train.get_or_create_global_step())
+ return tf.compat.v1.summary.all_v2_summary_ops()
+
+ return host_call_fn, scalar_tensors_dict
+
+
+########################## DEFAULT CONFIG UTILS ################################
+
+
+def get_default_config():
+ """Default values for BigBird."""
+
+ default_config = {
+ # transformer basic configs
+ "attention_probs_dropout_prob": 0.1,
+ "hidden_act": "gelu",
+ "hidden_dropout_prob": 0.1,
+ "hidden_size": 768,
+ "initializer_range": 0.02,
+ "intermediate_size": 3072,
+ "max_position_embeddings": 4096,
+ "num_attention_heads": 12,
+ "num_hidden_layers": 12,
+ "type_vocab_size": 2,
+ "use_bias": True,
+ "rescale_embedding": False,
+ "scope": "bert",
+ # sparse mask configs
+ "attention_type": "block_sparse",
+ "norm_type": "postnorm",
+ "block_size": 16,
+ "num_rand_blocks": 3,
+ # common bert configs
+ "max_encoder_length": 1024,
+ "max_decoder_length": 64,
+ "couple_encoder_decoder": False,
+ "beam_size": 5,
+ "alpha": 0.7,
+ "label_smoothing": 0.1,
+ "weight_decay_rate": 0.01,
+ "optimizer_beta1": 0.9,
+ "optimizer_beta2": 0.999,
+ "optimizer_epsilon": 1e-6,
+ # TPU settings
+ "use_tpu": True,
+ "tpu_name": None,
+ "tpu_zone": None,
+ "tpu_job_name": None,
+ "gcp_project": None,
+ "master": None,
+ "num_tpu_cores": 8,
+ "iterations_per_loop": "1000",
+ }
+
+ return default_config
diff --git a/bigbird/pretrain/__init__.py b/bigbird/pretrain/__init__.py
new file mode 100644
index 0000000..f6cd7c8
--- /dev/null
+++ b/bigbird/pretrain/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2020 The BigBird Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
diff --git a/bigbird/pretrain/run_pretraining.py b/bigbird/pretrain/run_pretraining.py
new file mode 100644
index 0000000..a494908
--- /dev/null
+++ b/bigbird/pretrain/run_pretraining.py
@@ -0,0 +1,668 @@
+# Copyright 2020 The BigBird Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Run masked LM/next sentence pre-training for BigBird."""
+
+import os
+import time
+
+from absl import app
+from absl import logging
+from bigbird.core import flags
+from bigbird.core import modeling
+from bigbird.core import optimization
+from bigbird.core import utils
+import numpy as np
+import tensorflow.compat.v2 as tf
+import tensorflow_datasets as tfds
+import tensorflow_text as tft
+
+import sentencepiece as spm
+
+
+FLAGS = flags.FLAGS
+
+## Required parameters
+
+flags.DEFINE_string(
+ "data_dir", "tfds://wiki40b/en",
+ "The input data dir. Should contain the TFRecord files. "
+ "Can be TF Dataset with prefix tfds://")
+
+flags.DEFINE_string(
+ "output_dir", "/tmp/bigb",
+ "The output directory where the model checkpoints will be written.")
+
+## Other parameters
+flags.DEFINE_string(
+ "init_checkpoint", None,
+ "Initial checkpoint (usually from a pre-trained BigBird model).")
+
+flags.DEFINE_integer(
+ "max_encoder_length", 512,
+ "The maximum total input sequence length after SentencePiece tokenization. "
+ "Sequences longer than this will be truncated, and sequences shorter "
+ "than this will be padded. Must match data generation.")
+
+flags.DEFINE_integer(
+ "max_predictions_per_seq", 75,
+ "Maximum number of masked LM predictions per sequence. "
+ "Must match data generation.")
+
+flags.DEFINE_float(
+ "masked_lm_prob", 0.15,
+ "Masked LM probability.")
+
+flags.DEFINE_string(
+ "substitute_newline", " ",
+ "Replace newline charachter from text with supplied string.")
+
+flags.DEFINE_bool(
+ "do_train", True,
+ "Whether to run training.")
+
+flags.DEFINE_bool(
+ "do_eval", False,
+ "Whether to run eval on the dev set.")
+
+flags.DEFINE_bool(
+ "do_export", False,
+ "Whether to export the model as TF SavedModel.")
+
+flags.DEFINE_integer(
+ "train_batch_size", 4,
+ "Local batch size for training. "
+ "Total batch size will be multiplied by number gpu/tpu cores available.")
+
+flags.DEFINE_integer(
+ "eval_batch_size", 4,
+ "Local batch size for eval. "
+ "Total batch size will be multiplied by number gpu/tpu cores available.")
+
+flags.DEFINE_string(
+ "optimizer", "AdamWeightDecay",
+ "Optimizer to use. Can be Adafactor, Adam, and AdamWeightDecay.")
+
+flags.DEFINE_float(
+ "learning_rate", 1e-4,
+ "The initial learning rate for Adam.")
+
+flags.DEFINE_integer(
+ "num_train_steps", 100000,
+ "Total number of training steps to perform.")
+
+flags.DEFINE_integer(
+ "num_warmup_steps", 10000,
+ "Number of steps to perform linear warmup.")
+
+flags.DEFINE_integer(
+ "save_checkpoints_steps", 1000,
+ "How often to save the model checkpoint.")
+
+flags.DEFINE_integer(
+ "max_eval_steps", 100,
+ "Maximum number of eval steps.")
+
+flags.DEFINE_bool(
+ "preprocessed_data", False,
+ "Whether TFRecord data is already tokenized and masked.")
+
+flags.DEFINE_bool(
+ "use_nsp", False,
+ "Whether to use next sentence prediction loss.")
+
+
+def input_fn_builder(data_dir, vocab_model_file, masked_lm_prob,
+ max_encoder_length, max_predictions_per_seq,
+ preprocessed_data, substitute_newline, is_training,
+ tmp_dir=None):
+ """Creates an `input_fn` closure to be passed to TPUEstimator."""
+
+ sp_model = spm.SentencePieceProcessor()
+ sp_proto = tf.io.gfile.GFile(vocab_model_file, "rb").read()
+ sp_model.LoadFromSerializedProto(sp_proto)
+ vocab_size = sp_model.GetPieceSize()
+ word_start_subtoken = np.array(
+ [sp_model.IdToPiece(i)[0] == "▁" for i in range(vocab_size)])
+
+ feature_shapes = {
+ "input_ids": [max_encoder_length],
+ "segment_ids": [max_encoder_length],
+ "masked_lm_positions": [max_predictions_per_seq],
+ "masked_lm_ids": [max_predictions_per_seq],
+ "masked_lm_weights": [max_predictions_per_seq],
+ "next_sentence_labels": [1]
+ }
+
+ def _decode_record(record):
+ """Decodes a record to a TensorFlow example."""
+ name_to_features = {
+ "input_ids":
+ tf.io.FixedLenFeature([max_encoder_length], tf.int64),
+ "segment_ids":
+ tf.io.FixedLenFeature([max_encoder_length], tf.int64),
+ "masked_lm_positions":
+ tf.io.FixedLenFeature([max_predictions_per_seq], tf.int64),
+ "masked_lm_ids":
+ tf.io.FixedLenFeature([max_predictions_per_seq], tf.int64),
+ "masked_lm_weights":
+ tf.io.FixedLenFeature([max_predictions_per_seq], tf.float32),
+ "next_sentence_labels":
+ tf.io.FixedLenFeature([1], tf.int64),
+ }
+ example = tf.io.parse_single_example(record, name_to_features)
+
+ # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
+ # So cast all int64 to int32.
+ for name in list(example.keys()):
+ t = example[name]
+ if t.dtype == tf.int64:
+ t = tf.cast(t, tf.int32)
+ example[name] = t
+
+ return example
+
+ def do_masking(example):
+ text = example["text"]
+ tokenizer = tft.SentencepieceTokenizer(
+ model=tf.io.gfile.GFile(vocab_model_file, "rb").read())
+ if substitute_newline:
+ text = tf.strings.regex_replace(text, "\n", substitute_newline)
+ subtokens = tokenizer.tokenize(text)
+ (subtokens, masked_lm_positions, masked_lm_ids,
+ masked_lm_weights) = tf.compat.v1.py_func(
+ numpy_masking, [subtokens], [tf.int32, tf.int32, tf.int32, tf.float32],
+ stateful=False)
+ features = {
+ "input_ids": subtokens,
+ "segment_ids": tf.zeros_like(subtokens),
+ "masked_lm_positions": masked_lm_positions,
+ "masked_lm_ids": masked_lm_ids,
+ "masked_lm_weights": masked_lm_weights,
+ "next_sentence_labels": tf.zeros([1], dtype=tf.int64),
+ }
+ return features
+
+ def numpy_masking(subtokens):
+ # Find a random span in text
+ end_pos = max_encoder_length - 2 + np.random.randint(
+ max(1, len(subtokens) - max_encoder_length - 2))
+ start_pos = max(0, end_pos - max_encoder_length + 2)
+ subtokens = subtokens[start_pos:end_pos]
+
+ # The start might be inside a word so fix it
+ # such that span always starts at a word
+ word_begin_mark = word_start_subtoken[subtokens]
+ word_begins_pos = np.flatnonzero(word_begin_mark).astype(np.int32)
+ if word_begins_pos.size == 0:
+ # if no word boundary present, we do not do whole word masking
+ # and we fall back to random masking.
+ word_begins_pos = np.arange(len(subtokens), dtype=np.int32)
+ word_begin_mark = np.logical_not(word_begin_mark)
+ print(subtokens, start_pos, end_pos, word_begin_mark)
+ correct_start_pos = word_begins_pos[0]
+ subtokens = subtokens[correct_start_pos:]
+ word_begin_mark = word_begin_mark[correct_start_pos:]
+ word_begins_pos = word_begins_pos - correct_start_pos
+ num_tokens = len(subtokens)
+
+ # @e want to do whole word masking so split by word boundary
+ words = np.split(np.arange(num_tokens, dtype=np.int32), word_begins_pos)[1:]
+ assert len(words) == len(word_begins_pos)
+
+ # Decide elements to mask
+ num_to_predict = min(
+ max_predictions_per_seq,
+ max(1, int(round(len(word_begins_pos) * masked_lm_prob))))
+ masked_lm_positions = np.concatenate(np.random.choice(
+ np.array([[]] + words, dtype=np.object)[1:],
+ num_to_predict, replace=False), 0)
+ # but this might have excess subtokens than max_predictions_per_seq
+ if len(masked_lm_positions) > max_predictions_per_seq:
+ masked_lm_positions = masked_lm_positions[:max_predictions_per_seq+1]
+ # however last word can cross word boundaries, remove crossing words
+ truncate_masking_at = np.flatnonzero(
+ word_begin_mark[masked_lm_positions])[-1]
+ masked_lm_positions = masked_lm_positions[:truncate_masking_at]
+
+ # sort masking positions
+ masked_lm_positions = np.sort(masked_lm_positions)
+ masked_lm_ids = subtokens[masked_lm_positions]
+
+ # replance input token with [MASK] 80%, random 10%, or leave it as it is.
+ randomness = np.random.rand(len(masked_lm_positions))
+ mask_index = masked_lm_positions[randomness < 0.8]
+ random_index = masked_lm_positions[randomness > 0.9]
+
+ subtokens[mask_index] = 67 # id of masked token
+ subtokens[random_index] = np.random.randint( # ignore special tokens
+ 101, vocab_size, len(random_index), dtype=np.int32)
+
+ # add [CLS] (65) and [SEP] (66) tokens
+ subtokens = np.concatenate([
+ np.array([65], dtype=np.int32), subtokens,
+ np.array([66], dtype=np.int32)
+ ])
+
+ # pad everything to correct shape
+ pad_inp = max_encoder_length - num_tokens - 2
+ subtokens = np.pad(subtokens, [0, pad_inp], "constant")
+
+ pad_out = max_predictions_per_seq - len(masked_lm_positions)
+ masked_lm_weights = np.pad(
+ np.ones_like(masked_lm_positions, dtype=np.float32),
+ [0, pad_out], "constant")
+ masked_lm_positions = np.pad(
+ masked_lm_positions + 1, [0, pad_out], "constant")
+ masked_lm_ids = np.pad(masked_lm_ids, [0, pad_out], "constant")
+
+ return subtokens, masked_lm_positions, masked_lm_ids, masked_lm_weights
+
+ def input_fn(params):
+ """The actual input function."""
+ batch_size = params["batch_size"]
+
+ # Load dataset and handle tfds separately
+ split = "train" if is_training else "test"
+ if "tfds://" == data_dir[:7]:
+ d = tfds.load(data_dir[7:], split=split,
+ shuffle_files=is_training,
+ data_dir=tmp_dir)
+ else:
+ input_files = tf.io.gfile.glob(
+ os.path.join(data_dir, "{}.tfrecord*".format(split)))
+
+ # For training, we want a lot of parallel reading and shuffling.
+ # For eval, we want no shuffling and parallel reading doesn't matter.
+ if is_training:
+ d = tf.data.Dataset.from_tensor_slices(tf.constant(input_files))
+ d = d.shuffle(buffer_size=len(input_files))
+
+ # Non deterministic mode means that the interleaving is not exact.
+ # This adds even more randomness to the training pipeline.
+ d = d.interleave(tf.data.TFRecordDataset,
+ deterministic=False,
+ num_parallel_calls=tf.data.experimental.AUTOTUNE)
+ else:
+ d = tf.data.TFRecordDataset(input_files)
+
+ if preprocessed_data:
+ d = d.map(_decode_record,
+ num_parallel_calls=tf.data.experimental.AUTOTUNE)
+ else:
+ d = d.map(do_masking,
+ num_parallel_calls=tf.data.experimental.AUTOTUNE)
+
+ if is_training:
+ d = d.shuffle(buffer_size=10000, reshuffle_each_iteration=True)
+ d = d.repeat()
+
+ d = d.padded_batch(batch_size, feature_shapes,
+ drop_remainder=True) # For static shape
+ return d
+
+ return input_fn
+
+
+def serving_input_fn_builder(batch_size, max_encoder_length,
+ vocab_model_file, substitute_newline):
+ """Creates an `input_fn` closure for exported SavedModel."""
+ def dynamic_padding(inp, min_size):
+ pad_size = tf.maximum(min_size - tf.shape(inp)[1], 0)
+ paddings = [[0, 0], [0, pad_size]]
+ return tf.pad(inp, paddings)
+
+ def input_fn():
+ # text input
+ text = tf.compat.v1.placeholder(tf.string, [batch_size], name="input_text")
+
+ # text tokenize
+ tokenizer = tft.SentencepieceTokenizer(
+ model=tf.io.gfile.GFile(vocab_model_file, "rb").read())
+ if substitute_newline:
+ text = tf.strings.regex_replace(text, "\n", substitute_newline)
+ ids = tokenizer.tokenize(text)
+ if isinstance(ids, tf.RaggedTensor):
+ ids = ids.to_tensor(0)
+
+ # text padding: Pad only if necessary and reshape properly
+ padded_ids = dynamic_padding(ids, max_encoder_length)
+ ids = tf.slice(padded_ids, [0, 0], [batch_size, max_encoder_length])
+
+ receiver_tensors = {"input": text}
+ features = {"input_ids": tf.cast(ids, tf.int32, name="input_ids")}
+
+ return tf.estimator.export.ServingInputReceiver(
+ features=features, receiver_tensors=receiver_tensors)
+
+ return input_fn
+
+
+def model_fn_builder(bert_config):
+ """Returns `model_fn` closure for TPUEstimator."""
+
+ def model_fn(features, labels, mode, params): # pylint: disable=unused-argument
+ """The `model_fn` for TPUEstimator."""
+
+ is_training = (mode == tf.estimator.ModeKeys.TRAIN)
+
+ model = modeling.BertModel(bert_config)
+ masked_lm = MaskedLMLayer(
+ bert_config["hidden_size"], bert_config["vocab_size"], model.embeder,
+ initializer=utils.create_initializer(bert_config["initializer_range"]),
+ activation_fn=utils.get_activation(bert_config["hidden_act"]))
+ next_sentence = NSPLayer(
+ bert_config["hidden_size"],
+ initializer=utils.create_initializer(bert_config["initializer_range"]))
+
+ sequence_output, pooled_output = model(
+ features["input_ids"], training=is_training,
+ token_type_ids=features.get("segment_ids"))
+
+ masked_lm_loss, masked_lm_log_probs = masked_lm(
+ sequence_output,
+ label_ids=features.get("masked_lm_ids"),
+ label_weights=features.get("masked_lm_weights"),
+ masked_lm_positions=features.get("masked_lm_positions"))
+
+ next_sentence_loss, next_sentence_log_probs = next_sentence(
+ pooled_output, features.get("next_sentence_labels"))
+
+ total_loss = masked_lm_loss
+ if bert_config["use_nsp"]:
+ total_loss += next_sentence_loss
+
+ tvars = tf.compat.v1.trainable_variables()
+ utils.log_variables(tvars, bert_config["ckpt_var_list"])
+
+ output_spec = None
+ if mode == tf.estimator.ModeKeys.TRAIN:
+
+ learning_rate = optimization.get_linear_warmup_linear_decay_lr(
+ init_lr=bert_config["learning_rate"],
+ num_train_steps=bert_config["num_train_steps"],
+ num_warmup_steps=bert_config["num_warmup_steps"])
+
+ optimizer = optimization.get_optimizer(bert_config, learning_rate)
+
+ global_step = tf.compat.v1.train.get_global_step()
+
+ gradients = optimizer.compute_gradients(total_loss, tvars)
+ train_op = optimizer.apply_gradients(gradients, global_step=global_step)
+
+ output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec(
+ mode=mode,
+ loss=total_loss,
+ train_op=train_op,
+ host_call=utils.add_scalars_to_summary(
+ bert_config["output_dir"], {"learning_rate": learning_rate}))
+
+ elif mode == tf.estimator.ModeKeys.EVAL:
+
+ def metric_fn(masked_lm_loss_value, masked_lm_log_probs, masked_lm_ids,
+ masked_lm_weights, next_sentence_loss_value,
+ next_sentence_log_probs, next_sentence_labels):
+ """Computes the loss and accuracy of the model."""
+ masked_lm_predictions = tf.argmax(
+ masked_lm_log_probs, axis=-1, output_type=tf.int32)
+ masked_lm_accuracy = tf.compat.v1.metrics.accuracy(
+ labels=masked_lm_ids,
+ predictions=masked_lm_predictions,
+ weights=masked_lm_weights)
+ masked_lm_mean_loss = tf.compat.v1.metrics.mean(
+ values=masked_lm_loss_value)
+
+ next_sentence_predictions = tf.argmax(
+ next_sentence_log_probs, axis=-1, output_type=tf.int32)
+ next_sentence_accuracy = tf.compat.v1.metrics.accuracy(
+ labels=next_sentence_labels, predictions=next_sentence_predictions)
+ next_sentence_mean_loss = tf.compat.v1.metrics.mean(
+ values=next_sentence_loss_value)
+
+ return {
+ "masked_lm_accuracy": masked_lm_accuracy,
+ "masked_lm_loss": masked_lm_mean_loss,
+ "next_sentence_accuracy": next_sentence_accuracy,
+ "next_sentence_loss": next_sentence_mean_loss,
+ }
+
+ eval_metrics = (metric_fn, [
+ masked_lm_loss, masked_lm_log_probs, features["masked_lm_ids"],
+ features["masked_lm_weights"], next_sentence_loss,
+ next_sentence_log_probs, features["next_sentence_labels"]
+ ])
+ output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec(
+ mode=mode,
+ loss=total_loss,
+ eval_metrics=eval_metrics)
+ else:
+
+ output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec(
+ mode=mode,
+ predictions={
+ "log-probabilities": masked_lm_log_probs,
+ "seq-embeddings": sequence_output
+ })
+
+ return output_spec
+
+ return model_fn
+
+
+class MaskedLMLayer(tf.compat.v1.layers.Layer):
+ """Get loss and log probs for the masked LM."""
+
+ def __init__(self,
+ hidden_size,
+ vocab_size,
+ embeder,
+ initializer=None,
+ activation_fn=None,
+ name="cls/predictions"):
+ super(MaskedLMLayer, self).__init__(name=name)
+ self.hidden_size = hidden_size
+ self.vocab_size = vocab_size
+ self.embeder = embeder
+
+ # We apply one more non-linear transformation before the output layer.
+ # This matrix is not used after pre-training.
+ self.extra_layer = utils.Dense2dLayer(
+ hidden_size, initializer,
+ activation_fn, "transform")
+ self.norm_layer = utils.NormLayer("transform")
+
+ # The output weights are the same as the input embeddings, but there is
+ # an output-only bias for each token.
+ self.output_bias = tf.compat.v1.get_variable(
+ name+"/output_bias",
+ shape=[vocab_size],
+ initializer=tf.zeros_initializer())
+
+ @property
+ def trainable_weights(self):
+ self._trainable_weights = (self.extra_layer +
+ self.norm_layer.trainable_weights +
+ [self.output_bias])
+ return self._trainable_weights
+
+ def call(self, input_tensor,
+ label_ids=None,
+ label_weights=None,
+ masked_lm_positions=None):
+ if masked_lm_positions is not None:
+ input_tensor = tf.gather(input_tensor, masked_lm_positions, batch_dims=1)
+
+ # We apply one more non-linear transformation before the output layer.
+ # This matrix is not used after pre-training.
+ with tf.compat.v1.variable_scope("transform") as sc:
+ input_tensor = self.extra_layer(input_tensor, scope=sc)
+ input_tensor = self.norm_layer(input_tensor, scope=sc)
+
+ # The output weights are the same as the input embeddings, but there is
+ # an output-only bias for each token.
+ logits = self.embeder.linear(input_tensor)
+ logits = tf.nn.bias_add(logits, self.output_bias)
+ log_probs = tf.nn.log_softmax(logits, axis=-1)
+
+ if label_ids is not None:
+ one_hot_labels = tf.one_hot(
+ label_ids, depth=self.vocab_size, dtype=tf.float32)
+
+ # The `positions` tensor might be zero-padded (if the sequence is too
+ # short to have the maximum number of predictions). The `label_weights`
+ # tensor has a value of 1.0 for every real prediction and 0.0 for the
+ # padding predictions.
+ per_example_loss = -tf.reduce_sum(log_probs * one_hot_labels, axis=-1)
+ numerator = tf.reduce_sum(label_weights * per_example_loss)
+ denominator = tf.reduce_sum(label_weights) + 1e-5
+ loss = numerator / denominator
+ else:
+ loss = tf.constant(0.0)
+
+ return loss, log_probs
+
+
+class NSPLayer(tf.compat.v1.layers.Layer):
+ """Get loss and log probs for the next sentence prediction."""
+
+ def __init__(self,
+ hidden_size,
+ initializer=None,
+ name="cls/seq_relationship"):
+ super(NSPLayer, self).__init__(name=name)
+ self.hidden_size = hidden_size
+
+ # Simple binary classification. Note that 0 is "next sentence" and 1 is
+ # "random sentence". This weight matrix is not used after pre-training.
+ with tf.compat.v1.variable_scope(name):
+ self.output_weights = tf.compat.v1.get_variable(
+ "output_weights",
+ shape=[2, hidden_size],
+ initializer=initializer)
+ self._trainable_weights.append(self.output_weights)
+ self.output_bias = tf.compat.v1.get_variable(
+ "output_bias", shape=[2], initializer=tf.zeros_initializer())
+ self._trainable_weights.append(self.output_bias)
+
+ def call(self, input_tensor, next_sentence_labels=None):
+ logits = tf.matmul(input_tensor, self.output_weights, transpose_b=True)
+ logits = tf.nn.bias_add(logits, self.output_bias)
+ log_probs = tf.nn.log_softmax(logits, axis=-1)
+
+ if next_sentence_labels is not None:
+ labels = tf.reshape(next_sentence_labels, [-1])
+ one_hot_labels = tf.one_hot(labels, depth=2, dtype=tf.float32)
+ per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
+ loss = tf.reduce_mean(per_example_loss)
+ else:
+ loss = tf.constant(0.0)
+ return loss, log_probs
+
+
+def main(_):
+
+ if not FLAGS.do_train and not FLAGS.do_eval and not FLAGS.do_export:
+ raise ValueError(
+ "At least one of `do_train`, `do_eval` must be True.")
+
+ bert_config = flags.as_dictionary()
+
+ if FLAGS.max_encoder_length > bert_config["max_position_embeddings"]:
+ raise ValueError(
+ "Cannot use sequence length %d because the BERT model "
+ "was only trained up to sequence length %d" %
+ (FLAGS.max_encoder_length, bert_config["max_position_embeddings"]))
+
+ tf.io.gfile.makedirs(FLAGS.output_dir)
+ if FLAGS.do_train:
+ flags.save(os.path.join(FLAGS.output_dir, "pretrain.config"))
+
+ model_fn = model_fn_builder(bert_config)
+ estimator = utils.get_estimator(bert_config, model_fn)
+
+ if FLAGS.do_train:
+ logging.info("***** Running training *****")
+ logging.info(" Batch size = %d", estimator.train_batch_size)
+ logging.info(" Num steps = %d", FLAGS.num_train_steps)
+ train_input_fn = input_fn_builder(
+ data_dir=FLAGS.data_dir,
+ vocab_model_file=FLAGS.vocab_model_file,
+ masked_lm_prob=FLAGS.masked_lm_prob,
+ max_encoder_length=FLAGS.max_encoder_length,
+ max_predictions_per_seq=FLAGS.max_predictions_per_seq,
+ preprocessed_data=FLAGS.preprocessed_data,
+ substitute_newline=FLAGS.substitute_newline,
+ tmp_dir=os.path.join(FLAGS.output_dir, "tfds"),
+ is_training=True)
+ estimator.train(input_fn=train_input_fn, max_steps=FLAGS.num_train_steps)
+
+ if FLAGS.do_eval:
+ logging.info("***** Running evaluation *****")
+ logging.info(" Batch size = %d", estimator.eval_batch_size)
+
+ eval_input_fn = input_fn_builder(
+ data_dir=FLAGS.data_dir,
+ vocab_model_file=FLAGS.vocab_model_file,
+ masked_lm_prob=FLAGS.masked_lm_prob,
+ max_encoder_length=FLAGS.max_encoder_length,
+ max_predictions_per_seq=FLAGS.max_predictions_per_seq,
+ preprocessed_data=FLAGS.preprocessed_data,
+ substitute_newline=FLAGS.substitute_newline,
+ tmp_dir=os.path.join(FLAGS.output_dir, "tfds"),
+ is_training=False)
+
+ # Run continuous evaluation for latest checkpoint as training progresses.
+ last_evaluated = None
+ while True:
+ latest = tf.train.latest_checkpoint(FLAGS.output_dir)
+ if latest == last_evaluated:
+ if not latest:
+ logging.info("No checkpoints found yet.")
+ else:
+ logging.info("Latest checkpoint %s already evaluated.", latest)
+ time.sleep(300)
+ continue
+ else:
+ logging.info("Evaluating check point %s", latest)
+ last_evaluated = latest
+
+ current_step = int(os.path.basename(latest).split("-")[1])
+ output_eval_file = os.path.join(
+ FLAGS.output_dir, "eval_results_{}.txt".format(current_step))
+ result = estimator.evaluate(input_fn=eval_input_fn,
+ steps=FLAGS.max_eval_steps,
+ checkpoint_path=latest)
+
+ with tf.io.gfile.GFile(output_eval_file, "w") as writer:
+ logging.info("***** Eval results *****")
+ for key in sorted(result.keys()):
+ logging.info(" %s = %s", key, str(result[key]))
+ writer.write("%s = %s\n" % (key, str(result[key])))
+
+ if FLAGS.do_export:
+ logging.info("***** Running export *****")
+
+ serving_input_fn = serving_input_fn_builder(
+ batch_size=FLAGS.eval_batch_size,
+ vocab_model_file=FLAGS.vocab_model_file,
+ max_encoder_length=FLAGS.max_encoder_length,
+ substitute_newline=FLAGS.substitute_newline)
+
+ estimator.export_saved_model(
+ os.path.join(FLAGS.output_dir, "export"), serving_input_fn)
+
+
+if __name__ == "__main__":
+ tf.compat.v1.disable_v2_behavior()
+ app.run(main)
diff --git a/bigbird/summarization/__init__.py b/bigbird/summarization/__init__.py
new file mode 100644
index 0000000..f6cd7c8
--- /dev/null
+++ b/bigbird/summarization/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2020 The BigBird Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
diff --git a/bigbird/summarization/pegasus_large.sh b/bigbird/summarization/pegasus_large.sh
new file mode 100644
index 0000000..9e1250b
--- /dev/null
+++ b/bigbird/summarization/pegasus_large.sh
@@ -0,0 +1,32 @@
+#!/bin/bash
+
+# TF_XLA_FLAGS=--tf_xla_auto_jit=2
+python3 bigbird/summarization/run_summarization.py \
+ --data_dir="tfds://scientific_papers/pubmed" \
+ --output_dir="$GCP_EXP_BUCKET"summarization/pubmed \
+ --attention_type=block_sparse \
+ --couple_encoder_decoder=False \
+ --max_encoder_length=3072 \
+ --max_decoder_length=256 \
+ --num_attention_heads=16 \
+ --num_hidden_layers=16 \
+ --hidden_size=1024 \
+ --intermediate_size=4096 \
+ --block_size=64 \
+ --scope=pegasus \
+ --norm_type=prenorm \
+ --hidden_act=relu \
+ --use_bias=False \
+ --rescale_embedding=True \
+ --vocab_model_file=pegasus \
+ --substitute_newline="" \
+ --train_batch_size=2 \
+ --eval_batch_size=2 \
+ --do_train=True \
+ --do_eval=False \
+ --use_tpu=True \
+ --tpu_name=bigbird \
+ --tpu_zone=europe-west4-a \
+ --gcp_project="$GCP_PROJECT_NAME" \
+ --num_tpu_cores=128 \
+ --init_checkpoint=gs://bigbird-transformer/summarization/pubmed/pegasus/model.ckpt-0
diff --git a/bigbird/summarization/roberta_base.sh b/bigbird/summarization/roberta_base.sh
new file mode 100644
index 0000000..996bd77
--- /dev/null
+++ b/bigbird/summarization/roberta_base.sh
@@ -0,0 +1,25 @@
+#!/bin/bash
+
+# TF_XLA_FLAGS=--tf_xla_auto_jit=2
+python3 bigbird/summarization/run_summarization.py \
+ --data_dir="tfds://scientific_papers/pubmed" \
+ --output_dir="$GCP_EXP_BUCKET"summarization/pubmed \
+ --attention_type=block_sparse \
+ --couple_encoder_decoder=True \
+ --max_encoder_length=3072 \
+ --max_decoder_length=256 \
+ --num_attention_heads=12 \
+ --num_hidden_layers=12 \
+ --hidden_size=768 \
+ --intermediate_size=3072 \
+ --block_size=64 \
+ --train_batch_size=4 \
+ --eval_batch_size=4 \
+ --do_train=True \
+ --do_eval=False \
+ --use_tpu=True \
+ --tpu_name=bigbird \
+ --tpu_zone=europe-west4-a \
+ --gcp_project="$GCP_PROJECT_NAME" \
+ --num_tpu_cores=64 \
+ --init_checkpoint=gs://bigbird-transformer/pretrain/bigbr_base/model.ckpt-0
diff --git a/bigbird/summarization/run_summarization.py b/bigbird/summarization/run_summarization.py
new file mode 100644
index 0000000..8b3a2e1
--- /dev/null
+++ b/bigbird/summarization/run_summarization.py
@@ -0,0 +1,533 @@
+# Copyright 2020 The BigBird Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Run summarization fine-tuning for BigBird.."""
+
+import os
+import time
+
+from absl import app
+from absl import logging
+from bigbird.core import flags
+from bigbird.core import modeling
+from bigbird.core import optimization
+from bigbird.core import utils
+import tensorflow.compat.v2 as tf
+import tensorflow_datasets as tfds
+import tensorflow_text as tft
+
+
+from rouge_score import rouge_scorer
+
+FLAGS = flags.FLAGS
+
+## Required parameters
+
+flags.DEFINE_string(
+ "data_dir", "tfds://scientific_papers/pubmed",
+ "The input data dir. Should contain the TFRecord files. "
+ "Can be TF Dataset with prefix tfds://")
+
+flags.DEFINE_string(
+ "output_dir", "/tmp/bigb",
+ "The output directory where the model checkpoints will be written.")
+
+## Other parameters
+
+flags.DEFINE_string(
+ "init_checkpoint", None,
+ "Initial checkpoint (usually from a pre-trained BigBird model).")
+
+flags.DEFINE_integer(
+ "max_encoder_length", 128,
+ "The maximum total input sequence length after SentencePiece tokenization. "
+ "Sequences longer than this will be truncated, and sequences shorter "
+ "than this will be padded.")
+
+flags.DEFINE_integer(
+ "max_decoder_length", 128,
+ "The maximum total input sequence length after SentencePiece tokenization. "
+ "Sequences longer than this will be truncated, and sequences shorter "
+ "than this will be padded.")
+
+flags.DEFINE_string(
+ "substitute_newline", None,
+ "Replace newline charachter from text with supplied string.")
+
+flags.DEFINE_bool(
+ "do_train", True,
+ "Whether to run training.")
+
+flags.DEFINE_bool(
+ "do_eval", False,
+ "Whether to run eval on the dev set.")
+
+flags.DEFINE_bool(
+ "do_export", False,
+ "Whether to export the model as TF SavedModel.")
+
+flags.DEFINE_integer(
+ "train_batch_size", 8,
+ "Local batch size for training. "
+ "Total batch size will be multiplied by number gpu/tpu cores available.")
+
+flags.DEFINE_integer(
+ "eval_batch_size", 8,
+ "Local batch size for eval. "
+ "Total batch size will be multiplied by number gpu/tpu cores available.")
+
+flags.DEFINE_string(
+ "optimizer", "Adafactor",
+ "Optimizer to use. Can be Adafactor, Adam, and AdamWeightDecay.")
+
+flags.DEFINE_float(
+ "learning_rate", 0.32,
+ "The initial learning rate for Adam.")
+
+flags.DEFINE_integer(
+ "num_train_steps", 1000,
+ "Total number of training steps to perform.")
+
+flags.DEFINE_integer(
+ "num_warmup_steps", 100,
+ "Number of steps to perform linear warmup.")
+
+flags.DEFINE_integer(
+ "save_checkpoints_steps", 2000,
+ "How often to save the model checkpoint.")
+
+flags.DEFINE_integer(
+ "max_eval_steps", 100,
+ "Maximum number of eval steps.")
+
+flags.DEFINE_bool(
+ "couple_encoder_decoder", False,
+ "Whether to tie encoder and decoder weights.")
+
+flags.DEFINE_integer(
+ "beam_size", 5,
+ "Beam size for decoding.")
+
+flags.DEFINE_float(
+ "alpha", 0.8,
+ "Strength of length normalization for beam search.")
+
+flags.DEFINE_float(
+ "label_smoothing", 0.1,
+ "Label smoothing for prediction cross entropy loss.")
+
+
+def input_fn_builder(data_dir, vocab_model_file, max_encoder_length,
+ max_decoder_length, substitute_newline, is_training,
+ tmp_dir=None):
+ """Creates an `input_fn` closure to be passed to TPUEstimator."""
+
+ def _decode_record(record):
+ """Decodes a record to a TensorFlow example."""
+ name_to_features = {
+ "document": tf.io.FixedLenFeature([], tf.string),
+ "summary": tf.io.FixedLenFeature([], tf.string),
+ }
+ example = tf.io.parse_single_example(record, name_to_features)
+ return example["document"], example["summary"]
+
+ def _tokenize_example(document, summary):
+ tokenizer = tft.SentencepieceTokenizer(
+ model=tf.io.gfile.GFile(vocab_model_file, "rb").read())
+ if substitute_newline:
+ document = tf.strings.regex_replace(document, "\n", substitute_newline)
+ # Remove space before special tokens.
+ document = tf.strings.regex_replace(document, r" ([<\[]\S+[>\]])", b"\\1")
+ document_ids = tokenizer.tokenize(document)
+ if isinstance(document_ids, tf.RaggedTensor):
+ document_ids = document_ids.to_tensor(0)
+ document_ids = document_ids[:max_encoder_length]
+
+ # Remove newline optionally
+ if substitute_newline:
+ summary = tf.strings.regex_replace(summary, "\n", substitute_newline)
+ # Remove space before special tokens.
+ summary = tf.strings.regex_replace(summary, r" ([<\[]\S+[>\]])", b"\\1")
+ summary_ids = tokenizer.tokenize(summary)
+ # Add [EOS] (1) special tokens.
+ suffix = tf.constant([1])
+ summary_ids = tf.concat([summary_ids, suffix], axis=0)
+ if isinstance(summary_ids, tf.RaggedTensor):
+ summary_ids = summary_ids.to_tensor(0)
+ summary_ids = summary_ids[:max_decoder_length]
+
+ return document_ids, summary_ids
+
+ def input_fn(params):
+ """The actual input function."""
+ batch_size = params["batch_size"]
+
+ # Load dataset and handle tfds separately
+ split = "train" if is_training else "validation"
+ if "tfds://" == data_dir[:7]:
+ d = tfds.load(data_dir[7:], split=split, data_dir=tmp_dir,
+ shuffle_files=is_training, as_supervised=True)
+ else:
+ input_files = tf.io.gfile.glob(
+ os.path.join(data_dir, "{}.tfrecord*".format(split)))
+
+ # For training, we want a lot of parallel reading and shuffling.
+ # For eval, we want no shuffling and parallel reading doesn't matter.
+ if is_training:
+ d = tf.data.Dataset.from_tensor_slices(tf.constant(input_files))
+ d = d.shuffle(buffer_size=len(input_files))
+
+ # Non deterministic mode means that the interleaving is not exact.
+ # This adds even more randomness to the training pipeline.
+ d = d.interleave(tf.data.TFRecordDataset,
+ deterministic=False,
+ num_parallel_calls=tf.data.experimental.AUTOTUNE)
+ else:
+ d = tf.data.TFRecordDataset(input_files)
+
+ d = d.map(_decode_record,
+ num_parallel_calls=tf.data.experimental.AUTOTUNE,
+ deterministic=is_training)
+
+ d = d.map(_tokenize_example,
+ num_parallel_calls=tf.data.experimental.AUTOTUNE,
+ deterministic=is_training)
+
+ if is_training:
+ d = d.shuffle(buffer_size=10000, reshuffle_each_iteration=True)
+ d = d.repeat()
+ d = d.padded_batch(batch_size, ([max_encoder_length], [max_decoder_length]),
+ drop_remainder=True) # For static shape
+ return d
+
+ return input_fn
+
+
+def serving_input_fn_builder(batch_size, max_encoder_length,
+ vocab_model_file, substitute_newline):
+ """Creates an `input_fn` closure for exported SavedModel."""
+ def dynamic_padding(inp, min_size):
+ pad_size = tf.maximum(min_size - tf.shape(inp)[1], 0)
+ paddings = [[0, 0], [0, pad_size]]
+ return tf.pad(inp, paddings)
+
+ def input_fn():
+ # text input
+ text = tf.compat.v1.placeholder(tf.string, [batch_size], name="input_text")
+
+ # text tokenize
+ tokenizer = tft.SentencepieceTokenizer(
+ model=tf.io.gfile.GFile(vocab_model_file, "rb").read())
+ if substitute_newline:
+ text = tf.strings.regex_replace(text, "\n", substitute_newline)
+ # Remove space before special tokens.
+ text = tf.strings.regex_replace(text, r" ([<\[]\S+[>\]])", b"\\1")
+ ids = tokenizer.tokenize(text)
+ if isinstance(ids, tf.RaggedTensor):
+ ids = ids.to_tensor(0)
+
+ # text padding: Pad only if necessary and reshape properly
+ padded_ids = dynamic_padding(ids, max_encoder_length)
+ ids = tf.slice(padded_ids, [0, 0], [batch_size, max_encoder_length])
+
+ receiver_tensors = {"input": text}
+ features = {"input_ids": tf.cast(ids, tf.int32, name="input_ids")}
+
+ return tf.estimator.export.ServingInputReceiver(
+ features=features, receiver_tensors=receiver_tensors)
+
+ return input_fn
+
+
+def model_fn_builder(transformer_config):
+ """Returns `model_fn` closure for TPUEstimator."""
+
+ def model_fn(features, labels, mode, params): # pylint: disable=unused-argument
+ """The `model_fn` for TPUEstimator."""
+
+ if isinstance(features, dict):
+ if not labels and "target_ids" in features:
+ labels = features["target_ids"]
+ features = features["input_ids"]
+
+ is_training = (mode == tf.estimator.ModeKeys.TRAIN)
+
+ model = modeling.TransformerModel(transformer_config)
+ (llh, logits, pred_ids), _ = model(features, target_ids=labels,
+ training=is_training)
+
+ total_loss = padded_cross_entropy_loss(
+ logits, labels,
+ transformer_config["label_smoothing"],
+ transformer_config["vocab_size"])
+
+ tvars = tf.compat.v1.trainable_variables()
+ utils.log_variables(tvars, transformer_config["ckpt_var_list"])
+
+ output_spec = None
+ if mode == tf.estimator.ModeKeys.TRAIN:
+
+ learning_rate = optimization.get_linear_warmup_rsqrt_decay_lr(
+ init_lr=transformer_config["learning_rate"],
+ hidden_size=transformer_config["hidden_size"],
+ num_warmup_steps=transformer_config["num_warmup_steps"])
+
+ optimizer = optimization.get_optimizer(transformer_config, learning_rate)
+
+ global_step = tf.compat.v1.train.get_global_step()
+
+ if not transformer_config["use_bias"]:
+ logging.info("Fixing position embedding, i.e. not trainable.")
+ posemb = "pegasus/embeddings/position_embeddings"
+ tvars = list(filter(lambda v: v.name.split(":")[0] != posemb, tvars))
+
+ gradients = optimizer.compute_gradients(total_loss, tvars)
+ train_op = optimizer.apply_gradients(gradients, global_step=global_step)
+
+ output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec(
+ mode=mode,
+ loss=total_loss,
+ train_op=train_op,
+ host_call=utils.add_scalars_to_summary(
+ transformer_config["output_dir"],
+ {"learning_rate": learning_rate}))
+
+ elif mode == tf.estimator.ModeKeys.EVAL:
+
+ tokenizer = tft.SentencepieceTokenizer(
+ model=tf.io.gfile.GFile(transformer_config["vocab_model_file"],
+ "rb").read())
+
+ def rouge_py_func(label_sent, pred_sent):
+ """Approximate ROUGE scores, always run externally for final scores."""
+ scorer = rouge_scorer.RougeScorer(
+ ["rouge1", "rouge2", "rougeLsum"],
+ use_stemmer=True)
+ r1, r2, rl = 0.0, 0.0, 0.0
+ for ls, ps in zip(label_sent, pred_sent):
+ score = scorer.score(ls.decode("utf-8"), ps.decode("utf-8"))
+ r1 += score["rouge1"].fmeasure
+ r2 += score["rouge2"].fmeasure
+ rl += score["rougeLsum"].fmeasure
+ return r1/len(label_sent), r2/len(label_sent), rl/len(label_sent)
+
+ def metric_fn(loss, log_probs, label_ids, pred_ids):
+ loss = tf.compat.v1.metrics.mean(values=loss)
+ log_probs = tf.compat.v1.metrics.mean(
+ values=log_probs,
+ weights=tf.cast(tf.not_equal(label_ids, 0), tf.float32))
+ metric_dict = {
+ "prediction_loss": loss,
+ "log_likelihood": log_probs,
+ }
+
+ if not transformer_config["use_tpu"]:
+ # Approximate ROUGE scores if not running on tpus.
+ # Always run externally for final scores.
+ label_sent = tokenizer.detokenize(label_ids)
+ label_sent = tf.strings.regex_replace(label_sent, r"([<\[]\S+[>\]])",
+ b" \\1")
+ pred_sent = tokenizer.detokenize(pred_ids)
+ pred_sent = tf.strings.regex_replace(pred_sent, r"([<\[]\S+[>\]])",
+ b" \\1")
+ if transformer_config["substitute_newline"]:
+ label_sent = tf.strings.regex_replace(
+ label_sent, transformer_config["substitute_newline"], "\n")
+ pred_sent = tf.strings.regex_replace(
+ pred_sent, transformer_config["substitute_newline"], "\n")
+ rouge_value = tf.compat.v1.py_func(
+ func=rouge_py_func,
+ inp=[label_sent, pred_sent],
+ Tout=[tf.float64, tf.float64, tf.float64],
+ stateful=False)
+ rouge_value = tf.cast(rouge_value, tf.float32)
+ rouge1 = tf.compat.v1.metrics.mean(values=rouge_value[0])
+ rouge2 = tf.compat.v1.metrics.mean(values=rouge_value[1])
+ rougeL = tf.compat.v1.metrics.mean(values=rouge_value[2]) # pylint: disable=invalid-name
+
+ metric_dict.update({
+ "eval/Rouge-1": rouge1,
+ "eval/Rouge-2": rouge2,
+ "eval/Rouge-L": rougeL,
+ })
+ return metric_dict
+
+ eval_metrics = (metric_fn,
+ [total_loss, llh, labels, pred_ids])
+ output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec(
+ mode=mode,
+ loss=total_loss,
+ eval_metrics=eval_metrics)
+ else:
+
+ prediction_dict = {"pred_ids": pred_ids}
+ if not transformer_config["use_tpu"]:
+ tokenizer = tft.SentencepieceTokenizer(
+ model=tf.io.gfile.GFile(transformer_config["vocab_model_file"],
+ "rb").read())
+ pred_sent = tokenizer.detokenize(pred_ids)
+ # Add a space before special tokens.
+ pred_sent = tf.strings.regex_replace(
+ pred_sent, r"([<\[]\S+[>\]])", b" \\1")
+ if transformer_config["substitute_newline"]:
+ pred_sent = tf.strings.regex_replace(
+ pred_sent, transformer_config["substitute_newline"], "\n")
+ prediction_dict.update({"pred_sent": pred_sent})
+
+ output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec(
+ mode=mode,
+ predictions=prediction_dict)
+
+ return output_spec
+
+ return model_fn
+
+
+def padded_cross_entropy_loss(logits, labels, smoothing, vocab_size):
+ """Calculate cross entropy loss while ignoring padding.
+
+ Args:
+ logits: Tensor of size [batch_size, length_logits, vocab_size]
+ labels: Tensor of size [batch_size, length_labels]
+ smoothing: Label smoothing constant, used to determine the on and off values
+ vocab_size: int size of the vocabulary
+ Returns:
+ Returns the cross entropy loss and weight tensors: float32 tensors with
+ shape [batch_size, max(length_logits, length_labels)]
+ """
+ with tf.name_scope("loss"):
+
+ if labels is not None:
+ # Calculate smoothing cross entropy
+ with tf.name_scope("smoothing_cross_entropy"):
+ confidence = 1.0 - smoothing
+ vocab_float = tf.cast(vocab_size - 1, tf.float32)
+ low_confidence = (1.0 - confidence) / vocab_float
+ soft_targets = tf.one_hot(
+ labels,
+ depth=vocab_size,
+ on_value=confidence,
+ off_value=low_confidence)
+ xentropy = tf.nn.softmax_cross_entropy_with_logits(
+ logits=logits, labels=soft_targets)
+
+ # Calculate the best (lowest) possible value of cross entropy, and
+ # subtract from the cross entropy loss.
+ normalizing_constant = -(
+ confidence * tf.math.log(confidence) + vocab_float *
+ low_confidence * tf.math.log(low_confidence + 1e-20))
+ xentropy -= normalizing_constant
+
+ weights = tf.cast(tf.not_equal(labels, 0), tf.float32)
+ loss = tf.reduce_sum(xentropy) / tf.reduce_sum(weights)
+
+ else:
+ loss = tf.constant(0.0)
+
+ return loss
+
+
+def main(_):
+
+ if not FLAGS.do_train and not FLAGS.do_eval and not FLAGS.do_export:
+ raise ValueError(
+ "At least one of `do_train`, `do_eval` must be True.")
+
+ transformer_config = flags.as_dictionary()
+
+ if FLAGS.max_encoder_length > transformer_config["max_position_embeddings"]:
+ raise ValueError(
+ "Cannot use sequence length %d because the model "
+ "was only trained up to sequence length %d" %
+ (FLAGS.max_encoder_length,
+ transformer_config["max_position_embeddings"]))
+
+ tf.io.gfile.makedirs(FLAGS.output_dir)
+ if FLAGS.do_train:
+ flags.save(os.path.join(FLAGS.output_dir, "summarization.config"))
+
+ model_fn = model_fn_builder(transformer_config)
+ estimator = utils.get_estimator(transformer_config, model_fn)
+
+ if FLAGS.do_train:
+ logging.info("***** Running training *****")
+ logging.info(" Batch size = %d", estimator.train_batch_size)
+ logging.info(" Num steps = %d", FLAGS.num_train_steps)
+ train_input_fn = input_fn_builder(
+ data_dir=FLAGS.data_dir,
+ vocab_model_file=FLAGS.vocab_model_file,
+ max_encoder_length=FLAGS.max_encoder_length,
+ max_decoder_length=FLAGS.max_decoder_length,
+ substitute_newline=FLAGS.substitute_newline,
+ tmp_dir=os.path.join(FLAGS.output_dir, "tfds"),
+ is_training=True)
+ estimator.train(input_fn=train_input_fn, max_steps=FLAGS.num_train_steps)
+
+ if FLAGS.do_eval:
+ logging.info("***** Running evaluation *****")
+ logging.info(" Batch size = %d", estimator.eval_batch_size)
+
+ eval_input_fn = input_fn_builder(
+ data_dir=FLAGS.data_dir,
+ vocab_model_file=FLAGS.vocab_model_file,
+ max_encoder_length=FLAGS.max_encoder_length,
+ max_decoder_length=FLAGS.max_decoder_length,
+ substitute_newline=FLAGS.substitute_newline,
+ tmp_dir=os.path.join(FLAGS.output_dir, "tfds"),
+ is_training=False)
+
+ # Run continuous evaluation for latest checkpoint as training progresses.
+ last_evaluated = None
+ while True:
+ latest = tf.train.latest_checkpoint(FLAGS.output_dir)
+ if latest == last_evaluated:
+ if not latest:
+ logging.info("No checkpoints found yet.")
+ else:
+ logging.info("Latest checkpoint %s already evaluated.", latest)
+ time.sleep(300)
+ continue
+ else:
+ logging.info("Evaluating check point %s", latest)
+ last_evaluated = latest
+
+ current_step = int(os.path.basename(latest).split("-")[1])
+ output_eval_file = os.path.join(
+ FLAGS.output_dir, "eval_results_{}.txt".format(current_step))
+ result = estimator.evaluate(input_fn=eval_input_fn,
+ steps=FLAGS.max_eval_steps,
+ checkpoint_path=latest)
+
+ with tf.io.gfile.GFile(output_eval_file, "w") as writer:
+ logging.info("***** Eval results *****")
+ for key in sorted(result.keys()):
+ logging.info(" %s = %s", key, str(result[key]))
+ writer.write("%s = %s\n" % (key, str(result[key])))
+
+ if FLAGS.do_export:
+ logging.info("***** Running export *****")
+
+ serving_input_fn = serving_input_fn_builder(
+ batch_size=FLAGS.eval_batch_size,
+ vocab_model_file=FLAGS.vocab_model_file,
+ max_encoder_length=FLAGS.max_encoder_length,
+ substitute_newline=FLAGS.substitute_newline)
+
+ estimator.export_saved_model(
+ os.path.join(FLAGS.output_dir, "export"), serving_input_fn)
+
+
+if __name__ == "__main__":
+ tf.compat.v1.disable_v2_behavior()
+ app.run(main)
diff --git a/bigbird/vocab/gpt2.model b/bigbird/vocab/gpt2.model
new file mode 100644
index 0000000..486f369
Binary files /dev/null and b/bigbird/vocab/gpt2.model differ
diff --git a/bigbird/vocab/pegasus.model b/bigbird/vocab/pegasus.model
new file mode 100644
index 0000000..cb9f8aa
Binary files /dev/null and b/bigbird/vocab/pegasus.model differ
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..2340851
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,9 @@
+absl-py
+natsort
+numpy
+rouge-score
+sentencepiece
+tensorflow
+tensorflow-text
+tensor2tensor
+tfds-nightly
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000..774fdf4
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,45 @@
+# Copyright 2020 The BigBird Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Install BigBird."""
+
+import setuptools
+
+# Get install requirements from the REQUIREMENTS file.
+with open('requirements.txt') as fp:
+ _REQUIREMENTS = fp.read().splitlines()
+
+# Get the long description from the README file.
+with open('README.md') as fp:
+ _LONG_DESCRIPTION = fp.read()
+
+setuptools.setup(
+ name='bigbird',
+ version='0.0.1',
+ description='Big Bird: Transformers for Long Sequences',
+ long_description=_LONG_DESCRIPTION,
+ long_description_content_type='text/markdown',
+ author='Google Inc.',
+ author_email='no-reply@google.com',
+ url='http://github.com/google-research/bigbird',
+ license='Apache 2.0',
+ packages=[
+ 'bigbird', 'bigbird.core', 'bigbird.classifier',
+ 'bigbird.pretrain', 'bigbird.summarization'
+ ],
+ package_data={'bigbird': ['vocab/*']},
+ scripts=[],
+ install_requires=_REQUIREMENTS,
+ keywords='deeplearning machinelearning nlp classifier qa summarization transformer pretraining',
+)