From 8cadd5df0386a298687f7b861f174c11bf4aa4bd Mon Sep 17 00:00:00 2001 From: fatih <34196005+fcakyon@users.noreply.github.com> Date: Wed, 10 Nov 2021 21:07:37 +0300 Subject: [PATCH] initial upload (#1) * initial commit * read logging level from env * fix comment * update setup.cfg * update environment yml * rename train to run * relocate modules * add init file * update readme --- .github/workflows/ci.yml | 85 ++++ .gitignore | 140 +++++ LICENSE | 21 + README.md | 242 ++++++++- configs/default/config.json | 48 ++ configs/default/config.yaml | 35 ++ configs/paper/berturk-qatask-tquad2.yaml | 36 ++ configs/paper/mt5base-3task-both-tquad2.yaml | 35 ++ .../paper/mt5base-3task-highlight-tquad2.yaml | 35 ++ .../paper/mt5base-3task-prepend-tquad2.yaml | 35 ++ configs/paper/mt5base-qatask-tquad2.yaml | 35 ++ configs/paper/mt5base-qgtask-both-tquad2.yaml | 35 ++ .../mt5base-qgtask-highlight-tquad2.yaml | 35 ++ .../paper/mt5base-qgtask-prepend-tquad2.yaml | 35 ++ configs/paper/mt5small-3task-both-tquad2.yaml | 35 ++ core/__init__.py | 0 core/api.py | 297 +++++++++++ core/argument_parsers.py | 162 ++++++ core/bert_api.py | 129 +++++ core/collator.py | 81 +++ core/dataset_parsers.py | 195 +++++++ core/evaluate.py | 289 +++++++++++ core/generate.py | 163 ++++++ core/pipelines.py | 242 +++++++++ environment.yml | 26 + hf/__init__.py | 0 hf/model.py | 106 ++++ prepare_data.py | 193 +++++++ pyproject.toml | 17 + requirements.txt | 13 + run.py | 230 +++++++++ setup.cfg | 5 + tests/__init__.py | 0 tests/test_config.yaml | 40 ++ tr_non_suffixes | 246 +++++++++ utils/__init__.py | 5 + utils/file.py | 141 +++++ utils/neptune.py | 27 + utils/nlp.py | 481 ++++++++++++++++++ utils/torch.py | 56 ++ utils/wandb.py | 22 + 41 files changed, 4051 insertions(+), 2 deletions(-) create mode 100644 .github/workflows/ci.yml create mode 100644 .gitignore create mode 100644 LICENSE create mode 100644 configs/default/config.json create mode 100644 configs/default/config.yaml create mode 100644 configs/paper/berturk-qatask-tquad2.yaml create mode 100644 configs/paper/mt5base-3task-both-tquad2.yaml create mode 100644 configs/paper/mt5base-3task-highlight-tquad2.yaml create mode 100644 configs/paper/mt5base-3task-prepend-tquad2.yaml create mode 100644 configs/paper/mt5base-qatask-tquad2.yaml create mode 100644 configs/paper/mt5base-qgtask-both-tquad2.yaml create mode 100644 configs/paper/mt5base-qgtask-highlight-tquad2.yaml create mode 100644 configs/paper/mt5base-qgtask-prepend-tquad2.yaml create mode 100644 configs/paper/mt5small-3task-both-tquad2.yaml create mode 100644 core/__init__.py create mode 100644 core/api.py create mode 100644 core/argument_parsers.py create mode 100644 core/bert_api.py create mode 100644 core/collator.py create mode 100644 core/dataset_parsers.py create mode 100644 core/evaluate.py create mode 100644 core/generate.py create mode 100644 core/pipelines.py create mode 100644 environment.yml create mode 100644 hf/__init__.py create mode 100644 hf/model.py create mode 100644 prepare_data.py create mode 100644 pyproject.toml create mode 100644 requirements.txt create mode 100644 run.py create mode 100644 setup.cfg create mode 100644 tests/__init__.py create mode 100644 tests/test_config.yaml create mode 100644 tr_non_suffixes create mode 100644 utils/__init__.py create mode 100644 utils/file.py create mode 100644 utils/neptune.py create mode 100644 utils/nlp.py create mode 100644 utils/torch.py create mode 100644 utils/wandb.py diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 
100644 index 0000000..e7ed3be --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,85 @@ +name: CI + +on: + push: + branches: [main] + pull_request: + branches: [main] + +jobs: + build: + runs-on: ubuntu-latest + strategy: + matrix: + operating-system: [ubuntu-latest] + python-version: [3.8, 3.9] + fail-fast: false + + steps: + - name: Checkout + uses: actions/checkout@v2 + + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + + - name: Restore Ubuntu cache + uses: actions/cache@v1 + if: matrix.operating-system == 'ubuntu-latest' + with: + path: ~/.cache/pip + key: ${{ matrix.os }}-${{ matrix.python-version }}-${{ hashFiles('**/setup.py')}} + restore-keys: ${{ matrix.os }}-${{ matrix.python-version }}- + + - name: Restore MacOS cache + uses: actions/cache@v1 + if: matrix.operating-system == 'macos-latest' + with: + path: ~/Library/Caches/pip + key: ${{ matrix.os }}-${{ matrix.python-version }}-${{ hashFiles('**/setup.py')}} + restore-keys: ${{ matrix.os }}-${{ matrix.python-version }}- + + - name: Restore Windows cache + uses: actions/cache@v1 + if: matrix.operating-system == 'windows-latest' + with: + path: ~\AppData\Local\pip\Cache + key: ${{ matrix.os }}-${{ matrix.python-version }}-${{ hashFiles('**/setup.py')}} + restore-keys: ${{ matrix.os }}-${{ matrix.python-version }}- + + - name: Update pip + run: python -m pip install --upgrade pip + + - name: Install PyTorch on Linux and Windows + if: > + matrix.operating-system == 'ubuntu-latest' || + matrix.operating-system == 'windows-latest' + run: > + pip install torch==1.10.0+cpu + -f https://download.pytorch.org/whl/torch_stable.html + + - name: Install PyTorch on MacOS + if: matrix.operating-system == 'macos-latest' + run: pip install torch==1.10.0 + + - name: Install requirements + run: > + pip install -r requirements.txt + + - name: Lint with flake8, black and isort + run: | + pip install "black==21.7b0" "flake8==3.9.2" "isort==5.9.2" + # stop the build if there are Python syntax errors or undefined names + flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics + black . --check --config pyproject.toml + isort -c . + # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide + flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics + + - name: Test prepare_data and train + run: | + # prepare_data + python prepare_data.py tests/test_config.yaml + # train + python run.py tests/test_config.yaml diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..0631120 --- /dev/null +++ b/.gitignore @@ -0,0 +1,140 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. 
+*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# other +.idea/ +.vscode +*.pt +data/ +mt5_qg_tokenizer/ +runs/ +*tokenizer/ +wandb +.neptune \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..6fb8020 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2021 obss + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md index 49ed13b..9eb07a4 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,240 @@ -# turkish-question-generation -Automated question generation and question answering from Turkish texts using text-to-text transformers +
+

+ Turkish Question Generation +

+ +

+ Official implementation for "Automated question generation & question answering from Turkish texts using text-to-text transformers".

+
+ + +
+ +install + + +```bash +git clone https://github.com/obss/turkish-question-generation.git +cd turkish-question-generation +pip install -r requirements.txt +``` +
+ +
+ +train + + +- start a training using args: + +```bash +python run.py --model_name_or_path google/mt5-small --output_dir runs/exp1 --do_train --do_eval --tokenizer_name_or_path mt5_qg_tokenizer --per_device_train_batch_size 4 --gradient_accumulation_steps 2 --learning_rate 1e-4 --seed 42 --save_total_limit 1 +``` + +- download [json config](configs/default/config.json) file and start a training: + +```bash +python run.py config.json +``` + +- download [yaml config](configs/default/config.yaml) file and start a training: + +```bash +python run.py config.yaml +``` + +
+ +
+ +evaluate + + +- arrange related params in config: + +```yaml +do_train: false +do_eval: true +eval_dataset_list: ["tquad2-valid", "xquad.tr"] +prepare_data: true +mt5_task_list: ["qa", "qg", "ans_ext"] +mt5_qg_format: "both" +no_cuda: false +``` + +- start an evaluation: + +```bash +python run.py config.yaml +``` + +
+ +
+ +neptune + + +- install neptune: + +```bash +pip install neptune-client +``` + +- download [config](configs/default/config.yaml) file and arrange neptune params: + +```yaml +run_name: 'exp1' +neptune_project: 'name/project' +neptune_api_token: 'YOUR_API_TOKEN' +``` + +- start a training: + +```bash +python run.py config.yaml +``` + +
+ +
+ +wandb + + +- install wandb: + +```bash +pip install wandb +``` + +- download [config](configs/default/config.yaml) file and arrange wandb params: + +```yaml +run_name: 'exp1' +wandb_project: 'turque' +``` + +- start a training: + +```bash +python run.py config.yaml +``` + +
+ +
+ +finetuned checkpoints + + +[model_url1]: https://drive.google.com/file/d/10hHFuavHCofDczGSzsH1xPHgTgAocOl1/view?usp=sharing +[model_url2]: https://huggingface.co/google/mt5-small +[data_url1]: https://github.com/okanvk/Turkish-Reading-Comprehension-Question-Answering-Dataset/blob/master/data/2018%20%2B%202020%20veri%20k%C3%BCmesi/final_train_data_v2.json +[data_url2]: https://github.com/okanvk/Turkish-Reading-Comprehension-Question-Answering-Dataset/blob/master/data/2018%20%2B%202020%20veri%20k%C3%BCmesi/final_dev_data_v2.json +[data_url3]: https://github.com/deepmind/xquad/blob/master/xquad.tr.json + + +|Name |Model |data
train |params (M) |model size (GB) | +|--- |--- |--- |--- |--- | +|[turque-s1][model_url1] |[mt5-small][model_url2] |[tquad-train][data_url1]+[tquad-val][data_url2]+[xquad.tr][data_url3] |300M |1.2GB | +
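The checkpoint above can also be exercised programmatically through the `TurQue` pipeline added in `core/api.py`. A minimal sketch, assuming the repository root is on `PYTHONPATH` and that `turque-s1` (or a local path to the checkpoint downloaded from the link above) is resolvable by `hf/model.py`:

```python
# Minimal sketch: load the finetuned checkpoint via the TurQue pipeline (core/api.py).
# Assumptions: the repo root is on PYTHONPATH and "turque-s1" (or a local checkpoint
# path) can be resolved by hf/model.py.
from core.api import TurQue

turque = TurQue(model_url_or_path="turque-s1", use_cuda=False)

context = "Osman Bey 1258 yılında Söğüt’te doğdu."
print(turque(task="question-generation", context=context))
# -> list of {"answer": ..., "question": ...} dicts, one per extracted answer
```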
+ +
+ +format + + +- answer extraction: + +input: +``` +" Osman Bey 1258 yılında Söğüt’te doğdu. Osman Bey 1 Ağustos 1326’da Bursa’da hayatını kaybetmiştir.1281 yılında Osman Bey 23 yaşında iken Ahi teşkilatından olan Şeyh Edebali’nin kızı Malhun Hatun ile evlendi." +``` + +target: +``` + 1258 Söğüt’te +``` + +- question answering: + +input: +``` +"question: Osman Bey nerede doğmuştur? context: Osman Bey 1258 yılında Söğüt’te doğdu. Osman Bey 1 Ağustos 1326’da Bursa’da hayatını kaybetmiştir.1281 yılında Osman Bey 23 yaşında iken Ahi teşkilatından olan Şeyh Edebali’nin kızı Malhun Hatun ile evlendi." +``` + +target: +``` +"Söğüt’te" +``` + +- question generation (prepend): + +input: +``` +"answer: Söğüt’te context: Osman Bey 1258 yılında Söğüt’te doğdu. Osman Bey 1 Ağustos 1326’da Bursa’da hayatını kaybetmiştir.1281 yılında Osman Bey 23 yaşında iken Ahi teşkilatından olan Şeyh Edebali’nin kızı Malhun Hatun ile evlendi." +``` + +target: +``` +"Osman Bey nerede doğmuştur?" +``` + +- question generation (highlight): + +input: +``` +"generate question: Osman Bey 1258 yılında Söğüt’te doğdu. Osman Bey 1 Ağustos 1326’da Bursa’da hayatını kaybetmiştir.1281 yılında Osman Bey 23 yaşında iken Ahi teşkilatından olan Şeyh Edebali’nin kızı Malhun Hatun ile evlendi." +``` + +target: +``` +"Osman Bey nerede doğmuştur?" +``` + +- question generation (both): + +input: +``` +"answer: Söğüt’te context: Osman Bey 1258 yılında Söğüt’te doğdu. Osman Bey 1 Ağustos 1326’da Bursa’da hayatını kaybetmiştir.1281 yılında Osman Bey 23 yaşında iken Ahi teşkilatından olan Şeyh Edebali’nin kızı Malhun Hatun ile evlendi." +``` + +target: +``` +"Osman Bey nerede doğmuştur?" +``` +
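These source/target formats are built internally by the `TurQue` pipeline in `core/api.py`, so the three tasks can also be called directly from Python. A minimal sketch of the corresponding calls (assuming the repository root is on `PYTHONPATH` and a checkpoint resolvable by `hf/model.py`):

```python
# Minimal sketch: the three tasks exposed by TurQue.__call__ (core/api.py).
# Assumptions: repo root on PYTHONPATH; "turque-s1" is a resolvable checkpoint.
from core.api import TurQue

turque = TurQue(model_url_or_path="turque-s1", use_cuda=False)

context = (
    "Osman Bey 1258 yılında Söğüt’te doğdu. Osman Bey 1 Ağustos 1326’da "
    "Bursa’da hayatını kaybetmiştir."
)

# answer extraction: returns [{"answer": ...}, ...]
answers = turque(task="answer-extraction", context=context)

# question answering: returns [{"answer": ...}]
answer = turque(task="question-answering", context=context, question="Osman Bey nerede doğmuştur?")

# question generation: returns [{"answer": ..., "question": ...}, ...]
questions = turque(task="question-generation", context=context)
```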
+ +
+ +paper configs + + +You can find the config files used in the paper under [configs/paper](configs/paper). + +
+ +
+ +contributing + + +Before opening a PR: + +- Install required development packages: + +```bash +pip install "black==21.7b0" "flake8==3.9.2" "isort==5.9.2" +``` + +- Reformat with black and isort: + +```bash +black . --config pyproject.toml +isort . +``` + +
diff --git a/configs/default/config.json b/configs/default/config.json new file mode 100644 index 0000000..b367344 --- /dev/null +++ b/configs/default/config.json @@ -0,0 +1,48 @@ +{ + "model_name_or_path": "google/mt5-small", + "tokenizer_path": "mt5_small_tokenizer", + "label_smoothing_factor": 0, + "freeze_embeddings": false, + "run_name": "exp1", + "wandb_project": null, + "neptune_project": null, + "neptune_api_token": null, + "train_dataset_list": [ + "tquad2-train" + ], + "valid_dataset_list": [ + "tquad2-valid" + ], + "eval_dataset_list": [ + "tquad2-valid", + "xquad.tr" + ], + "train_file_path": "data/train_data_multitask_mt5.pt", + "valid_file_path": "data/valid_data_multitask_mt5.pt", + "max_source_length": 512, + "max_target_length": 80, + "prepare_data": true, + "mt5_task_list": [ + "qa", + "qg", + "ans_ext" + ], + "mt5_qg_format": "highlight", + "output_dir": "runs/exp1", + "do_train": true, + "do_eval": true, + "evaluation_strategy": "steps", + "eval_steps": 2000, + "per_device_train_batch_size": 4, + "per_device_eval_batch_size": 4, + "gradient_accumulation_steps": 1, + "eval_accumulation_steps": 1, + "learning_rate": 1e-4, + "num_train_epochs": 10, + "save_total_limit": 1, + "no_cuda": false, + "seed": 42, + "fp16": false, + "fp16_full_eval": false, + "adafactor": true +} \ No newline at end of file diff --git a/configs/default/config.yaml b/configs/default/config.yaml new file mode 100644 index 0000000..c125fa8 --- /dev/null +++ b/configs/default/config.yaml @@ -0,0 +1,35 @@ +model_name_or_path: "google/mt5-small" +tokenizer_path: "mt5_small_tokenizer" +label_smoothing_factor: 0 +freeze_embeddings: false +run_name: "exp1" +wandb_project: null +neptune_project: null +neptune_api_token: null +train_dataset_list: ["tquad2-train"] +valid_dataset_list: ["tquad2-valid"] +eval_dataset_list: ["tquad2-valid", "xquad.tr"] +train_file_path: "data/train_data_multitask_mt5.pt" +valid_file_path: "data/valid_data_multitask_mt5.pt" +max_source_length: 512 +max_target_length: 80 +prepare_data: true +mt5_task_list: ["qa", "qg", "ans_ext"] +mt5_qg_format: "highlight" +output_dir: "runs/exp1" +do_train: true +do_eval: true +evaluation_strategy: "steps" +eval_steps: 2000 +per_device_train_batch_size: 4 +per_device_eval_batch_size: 4 +gradient_accumulation_steps: 1 +eval_accumulation_steps: 1 +learning_rate: 1.0e-4 +num_train_epochs: 10 +save_total_limit: 1 +no_cuda: false +seed: 42 +fp16: false +fp16_full_eval: false +adafactor: true \ No newline at end of file diff --git a/configs/paper/berturk-qatask-tquad2.yaml b/configs/paper/berturk-qatask-tquad2.yaml new file mode 100644 index 0000000..dc5a610 --- /dev/null +++ b/configs/paper/berturk-qatask-tquad2.yaml @@ -0,0 +1,36 @@ +model_name_or_path: "dbmdz/bert-base-turkish-cased" +tokenizer_path: "tokenizers/bert-base-turkish-cased_tokenizer" +label_smoothing: 0 +freeze_embeddings: false +run_name: "turque-bertbase-adamw-1e4-3ep-tquad2train" +wandb_project: null +wandb_id: null +neptune_project: null +neptune_run: null +neptune_api_token: +train_dataset_list: ["tquad2-train"] +valid_dataset_list: ["tquad2-valid"] +eval_dataset_list: ["tquad2-valid", "xquad.tr"] +train_file_path: "data/train_data.pt" +valid_file_path: "data/valid_data.pt" +max_source_length: 512 +max_target_length: 64 +prepare_data: true +output_dir: "runs/bert-base/turque-bertbase-adamw-1e4-3ep-tquad2train" +do_train: true +do_eval: true +evaluation_strategy: "steps" +eval_steps: 300 +logging_steps: 300 +per_device_train_batch_size: 32 +per_device_eval_batch_size: 32 
+gradient_accumulation_steps: 1 +eval_accumulation_steps: 8 +learning_rate: 1.0e-4 +num_train_epochs: 3 +save_total_limit: 1 +no_cuda: false +seed: 42 +fp16: false +fp16_full_eval: false +adafactor: false \ No newline at end of file diff --git a/configs/paper/mt5base-3task-both-tquad2.yaml b/configs/paper/mt5base-3task-both-tquad2.yaml new file mode 100644 index 0000000..4963108 --- /dev/null +++ b/configs/paper/mt5base-3task-both-tquad2.yaml @@ -0,0 +1,35 @@ +model_name_or_path: "google/mt5-base" +tokenizer_path: "tokenizers/mt5-base" +label_smoothing_factor: 0 +freeze_embeddings: false +run_name: "turque-mt5base-3task-adamw-tquad2train" +wandb_project: null +neptune_project: null +neptune_api_token: null +train_dataset_list: ["tquad2-train"] +valid_dataset_list: ["tquad2-valid"] +eval_dataset_list: ["tquad2-valid", "xquad.tr"] +train_file_path: "data/train_data.pt" +valid_file_path: "data/valid_data.pt" +max_source_length: 512 +max_target_length: 64 +prepare_data: true +mt5_task_list: ["qa", "qg", "ans_ext"] +mt5_qg_format: "both" +output_dir: "runs/mt5-base/3task-adamw-tquad2train" +do_train: true +do_eval: true +evaluation_strategy: "steps" +eval_steps: 300 +per_device_train_batch_size: 32 +per_device_eval_batch_size: 32 +gradient_accumulation_steps: 8 +eval_accumulation_steps: 1 +learning_rate: 1.0e-3 +num_train_epochs: 15 +save_total_limit: 1 +no_cuda: false +seed: 42 +fp16: false +fp16_full_eval: false +adafactor: false \ No newline at end of file diff --git a/configs/paper/mt5base-3task-highlight-tquad2.yaml b/configs/paper/mt5base-3task-highlight-tquad2.yaml new file mode 100644 index 0000000..74def9d --- /dev/null +++ b/configs/paper/mt5base-3task-highlight-tquad2.yaml @@ -0,0 +1,35 @@ +model_name_or_path: "google/mt5-base" +tokenizer_path: "tokenizers/mt5-base" +label_smoothing_factor: 0 +freeze_embeddings: false +run_name: "turque-mt5base-3task-highlight-adamw-tquad2train" +wandb_project: null +neptune_project: null +neptune_api_token: null +train_dataset_list: ["tquad2-train"] +valid_dataset_list: ["tquad2-valid"] +eval_dataset_list: ["tquad2-valid", "xquad.tr"] +train_file_path: "data/train_data.pt" +valid_file_path: "data/valid_data.pt" +max_source_length: 512 +max_target_length: 64 +prepare_data: true +mt5_task_list: ["qa", "qg", "ans_ext"] +mt5_qg_format: "highlight" +output_dir: "runs/mt5-base/3task-highlight-adamw-tquad2train" +do_train: true +do_eval: true +evaluation_strategy: "steps" +eval_steps: 300 +per_device_train_batch_size: 32 +per_device_eval_batch_size: 32 +gradient_accumulation_steps: 8 +eval_accumulation_steps: 1 +learning_rate: 1.0e-3 +num_train_epochs: 15 +save_total_limit: 1 +no_cuda: false +seed: 42 +fp16: false +fp16_full_eval: false +adafactor: false \ No newline at end of file diff --git a/configs/paper/mt5base-3task-prepend-tquad2.yaml b/configs/paper/mt5base-3task-prepend-tquad2.yaml new file mode 100644 index 0000000..ec19cd3 --- /dev/null +++ b/configs/paper/mt5base-3task-prepend-tquad2.yaml @@ -0,0 +1,35 @@ +model_name_or_path: "google/mt5-base" +tokenizer_path: "tokenizers/mt5-base" +label_smoothing_factor: 0 +freeze_embeddings: false +run_name: "turque-mt5base-3task-prepend-adamw-tquad2train" +wandb_project: null +neptune_project: null +neptune_api_token: null +train_dataset_list: ["tquad2-train"] +valid_dataset_list: ["tquad2-valid"] +eval_dataset_list: ["tquad2-valid", "xquad.tr"] +train_file_path: "data/train_data.pt" +valid_file_path: "data/valid_data.pt" +max_source_length: 512 +max_target_length: 64 +prepare_data: true +mt5_task_list: 
["qa", "qg", "ans_ext"] +mt5_qg_format: "prepend" +output_dir: "runs/mt5-base/3task-prepend-adamw-tquad2train" +do_train: true +do_eval: true +evaluation_strategy: "steps" +eval_steps: 300 +per_device_train_batch_size: 32 +per_device_eval_batch_size: 32 +gradient_accumulation_steps: 8 +eval_accumulation_steps: 1 +learning_rate: 1.0e-3 +num_train_epochs: 15 +save_total_limit: 1 +no_cuda: false +seed: 42 +fp16: false +fp16_full_eval: false +adafactor: false \ No newline at end of file diff --git a/configs/paper/mt5base-qatask-tquad2.yaml b/configs/paper/mt5base-qatask-tquad2.yaml new file mode 100644 index 0000000..623b68e --- /dev/null +++ b/configs/paper/mt5base-qatask-tquad2.yaml @@ -0,0 +1,35 @@ +model_name_or_path: "google/mt5-base" +tokenizer_path: "tokenizers/mt5-base" +label_smoothing_factor: 0 +freeze_embeddings: false +run_name: "turque-mt5base-qa-adamw-tquad2train" +wandb_project: null +neptune_project: null +neptune_api_token: null +train_dataset_list: ["tquad2-train"] +valid_dataset_list: ["tquad2-valid"] +eval_dataset_list: ["tquad2-valid", "xquad.tr"] +train_file_path: "data/train_data.pt" +valid_file_path: "data/valid_data.pt" +max_source_length: 512 +max_target_length: 64 +prepare_data: true +mt5_task_list: ["qa"] +mt5_qg_format: "both" +output_dir: "runs/mt5-base/qa-adamw-tquad2train" +do_train: true +do_eval: true +evaluation_strategy: "steps" +eval_steps: 300 +per_device_train_batch_size: 32 +per_device_eval_batch_size: 32 +gradient_accumulation_steps: 8 +eval_accumulation_steps: 1 +learning_rate: 1.0e-3 +num_train_epochs: 15 +save_total_limit: 1 +no_cuda: false +seed: 42 +fp16: false +fp16_full_eval: false +adafactor: false \ No newline at end of file diff --git a/configs/paper/mt5base-qgtask-both-tquad2.yaml b/configs/paper/mt5base-qgtask-both-tquad2.yaml new file mode 100644 index 0000000..1c64949 --- /dev/null +++ b/configs/paper/mt5base-qgtask-both-tquad2.yaml @@ -0,0 +1,35 @@ +model_name_or_path: "google/mt5-base" +tokenizer_path: "tokenizers/mt5-base" +label_smoothing_factor: 0 +freeze_embeddings: false +run_name: "turque-mt5base-qg-adamw-tquad2train" +wandb_project: null +neptune_project: null +neptune_api_token: null +train_dataset_list: ["tquad2-train"] +valid_dataset_list: ["tquad2-valid"] +eval_dataset_list: ["tquad2-valid", "xquad.tr"] +train_file_path: "data/train_data.pt" +valid_file_path: "data/valid_data.pt" +max_source_length: 512 +max_target_length: 64 +prepare_data: true +mt5_task_list: ["qg"] +mt5_qg_format: "both" +output_dir: "runs/mt5-base/qg-adamw-tquad2train" +do_train: true +do_eval: true +evaluation_strategy: "steps" +eval_steps: 300 +per_device_train_batch_size: 32 +per_device_eval_batch_size: 32 +gradient_accumulation_steps: 8 +eval_accumulation_steps: 1 +learning_rate: 1.0e-3 +num_train_epochs: 15 +save_total_limit: 1 +no_cuda: false +seed: 42 +fp16: false +fp16_full_eval: false +adafactor: false \ No newline at end of file diff --git a/configs/paper/mt5base-qgtask-highlight-tquad2.yaml b/configs/paper/mt5base-qgtask-highlight-tquad2.yaml new file mode 100644 index 0000000..06cfdbb --- /dev/null +++ b/configs/paper/mt5base-qgtask-highlight-tquad2.yaml @@ -0,0 +1,35 @@ +model_name_or_path: "google/mt5-base" +tokenizer_path: "tokenizers/mt5-base" +label_smoothing_factor: 0 +freeze_embeddings: false +run_name: "turque-mt5base-qg-highlight-adamw-tquad2train" +wandb_project: null +neptune_project: null +neptune_api_token: null +train_dataset_list: ["tquad2-train"] +valid_dataset_list: ["tquad2-valid"] +eval_dataset_list: ["tquad2-valid", 
"xquad.tr"] +train_file_path: "data/train_data.pt" +valid_file_path: "data/valid_data.pt" +max_source_length: 512 +max_target_length: 64 +prepare_data: true +mt5_task_list: ["qg"] +mt5_qg_format: "highlight" +output_dir: "runs/mt5-base/qg-highlight-adamw-tquad2train" +do_train: true +do_eval: true +evaluation_strategy: "steps" +eval_steps: 300 +per_device_train_batch_size: 32 +per_device_eval_batch_size: 32 +gradient_accumulation_steps: 8 +eval_accumulation_steps: 1 +learning_rate: 1.0e-3 +num_train_epochs: 15 +save_total_limit: 1 +no_cuda: false +seed: 42 +fp16: false +fp16_full_eval: false +adafactor: false \ No newline at end of file diff --git a/configs/paper/mt5base-qgtask-prepend-tquad2.yaml b/configs/paper/mt5base-qgtask-prepend-tquad2.yaml new file mode 100644 index 0000000..d40605a --- /dev/null +++ b/configs/paper/mt5base-qgtask-prepend-tquad2.yaml @@ -0,0 +1,35 @@ +model_name_or_path: "google/mt5-base" +tokenizer_path: "tokenizers/mt5-base" +label_smoothing_factor: 0 +freeze_embeddings: false +run_name: "turque-mt5base-qg-prepend-adamw-tquad2train" +wandb_project: null +neptune_project: null +neptune_api_token: null +train_dataset_list: ["tquad2-train"] +valid_dataset_list: ["tquad2-valid"] +eval_dataset_list: ["tquad2-valid", "xquad.tr"] +train_file_path: "data/train_data.pt" +valid_file_path: "data/valid_data.pt" +max_source_length: 512 +max_target_length: 64 +prepare_data: true +mt5_task_list: ["qg"] +mt5_qg_format: "prepend" +output_dir: "runs/mt5-base/qg-prepend-adamw-tquad2train" +do_train: true +do_eval: true +evaluation_strategy: "steps" +eval_steps: 300 +per_device_train_batch_size: 32 +per_device_eval_batch_size: 32 +gradient_accumulation_steps: 8 +eval_accumulation_steps: 1 +learning_rate: 1.0e-3 +num_train_epochs: 15 +save_total_limit: 1 +no_cuda: false +seed: 42 +fp16: false +fp16_full_eval: false +adafactor: false \ No newline at end of file diff --git a/configs/paper/mt5small-3task-both-tquad2.yaml b/configs/paper/mt5small-3task-both-tquad2.yaml new file mode 100644 index 0000000..f6f0f6b --- /dev/null +++ b/configs/paper/mt5small-3task-both-tquad2.yaml @@ -0,0 +1,35 @@ +model_name_or_path: "google/mt5-small" +tokenizer_path: "tokenizers/mt5-small" +label_smoothing_factor: 0 +freeze_embeddings: false +run_name: "turque-mt5small-adamw-1e3-15ep-tquad2train" +wandb_project: null +neptune_project: null +neptune_api_token: null +train_dataset_list: ["tquad2-train"] +valid_dataset_list: ["tquad2-valid"] +eval_dataset_list: ["tquad2-valid", "xquad.tr"] +train_file_path: "data/train_data.pt" +valid_file_path: "data/valid_data.pt" +max_source_length: 512 +max_target_length: 64 +prepare_data: true +mt5_task_list: ["qa", "qg", "ans_ext"] +mt5_qg_format: "both" +output_dir: "runs/mt5-small/3task/adamw-1e3-15ep-both-tquad2train" +do_train: true +do_eval: true +evaluation_strategy: "steps" +eval_steps: 300 +per_device_train_batch_size: 64 +per_device_eval_batch_size: 64 +gradient_accumulation_steps: 4 +eval_accumulation_steps: 1 +learning_rate: 1.0e-3 +num_train_epochs: 15 +save_total_limit: 1 +no_cuda: false +seed: 42 +fp16: false +fp16_full_eval: false +adafactor: false \ No newline at end of file diff --git a/core/__init__.py b/core/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/core/api.py b/core/api.py new file mode 100644 index 0000000..eba3199 --- /dev/null +++ b/core/api.py @@ -0,0 +1,297 @@ +import itertools +import logging +import os +import re +from collections import OrderedDict +from typing import Dict, List, Optional, Union + +from tqdm 
import tqdm + +from hf.model import MT5Model +from utils.file import load_json +from utils.nlp import ( + add_start_end_to_answer_list_per_sentence, + normalize_text, + postprocess_answer_extraction_output, + prepare_answer_extraction_samples, + prepare_qa_sample, + prepare_qg_samples, + sentence_tokenize, +) + +logger = logging.getLogger(__name__) +logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=os.environ.get("LOGLEVEL", "INFO").upper(), +) + + +class TurQue: + def __init__( + self, + model_url_or_path: str = None, + use_cuda: bool = None, + max_source_length: int = 512, + max_target_length: int = 80, + generate_num_beams: int = 4, + top_k: int = None, + top_p: float = None, + qg_format: str = "highlight", + ): + model_url_or_path = "turque-s1" if model_url_or_path is None else model_url_or_path + mt5_model = MT5Model(model_url_or_path, use_cuda=use_cuda) + self.model = mt5_model.model + self.tokenizer = mt5_model.tokenizer + self.model_type = mt5_model.type + + self.max_source_length = max_source_length + self.max_target_length = max_target_length + self.generate_num_beams = generate_num_beams + self.top_k = top_k + self.top_p = top_p + self.qg_format = qg_format + + def __call__( + self, + task: str, + context: str, + question: Optional[str] = None, + answer_list: Optional[Union[List[str], List[Dict]]] = None, + ): + context = normalize_text(context) + if task == "answer-extraction": + # return answer list using context + answer_list = self._generate_answer_list_from_context(context) + output = [{"answer": answer["text"]} for answer in answer_list] + return output + elif task == "question-answering": + # return answer list using context and question + question = normalize_text(question) + answer = self._generate_answer_from_context_and_question(question, context) + output = [{"answer": answer}] + return output + elif task == "question-generation": + # return question list using context + if answer_list is None: + answer_list = self._generate_answer_list_from_context(context) + + if not answer_list: + return [{"answer": None, "question": None}] + else: + _answer_list = [] + for answer in answer_list: + answer["text"] = normalize_text(answer["text"]) + _answer_list.append(answer) + answer_list = _answer_list + + samples = prepare_qg_samples(context, answer_list, qg_format=self.qg_format) + + if samples[0]["answer"] is None: + return [{"answer": None, "question": None}] + else: + inputs = [sample["source_text"] for sample in samples] + + # single generation without padding is 5 times faster than padding + batch generation + question_list = [] + for input in inputs: + question = self._generate_question_list([input], padding=False)[0] + question_list.append(question) + + output = [ + {"answer": sample["answer"], "question": question} + for sample, question in zip(samples, question_list) + ] + return output + else: + raise NameError( + f"{task} is not defined. 
'task' must be one of ['answer-extraction', 'question-answering', 'question-generation']" + ) + + def _generate_question_list(self, inputs, padding=True, truncation=True): + inputs = self._tokenize(inputs, padding=padding, truncation=truncation) + + outs = self.model.generate( + input_ids=inputs["input_ids"].to(self.model.device), + attention_mask=inputs["attention_mask"].to(self.model.device), + max_length=self.max_target_length, + num_beams=self.generate_num_beams, + top_k=self.top_k, + top_p=self.top_p, + ) + + question_list = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in outs] + return question_list + + def _generate_answer_list_from_context(self, context): + samples = prepare_answer_extraction_samples(context=context) + inputs = [sample["source_text"] for sample in samples] + + inputs = self._tokenize(inputs, padding=True, truncation=True) + + outs = self.model.generate( + input_ids=inputs["input_ids"].to(self.model.device), + attention_mask=inputs["attention_mask"].to(self.model.device), + max_length=self.max_target_length, + ) + + output_list = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in outs] + + answer_list_per_sentence = [] + for output in output_list: + # postprocess answer extraction output list + answer_text_list = postprocess_answer_extraction_output(output) + answer_list = [{"text": normalize_text(answer_text)} for answer_text in answer_text_list] + answer_list_per_sentence.append(answer_list) + + sentence_list = sentence_tokenize(context) + answer_list = add_start_end_to_answer_list_per_sentence(sentence_list, answer_list_per_sentence) + + return answer_list + + def _tokenize(self, inputs, padding=True, truncation=True, add_special_tokens=True): + inputs = self.tokenizer.batch_encode_plus( + inputs, + max_length=self.max_source_length, + add_special_tokens=add_special_tokens, + truncation=truncation, + padding="max_length" if padding else False, + pad_to_max_length=padding, + return_tensors="pt", + ) + return inputs + + def _generate_answer_from_context_and_question(self, question, context): + sample = prepare_qa_sample(context, question) + source_text = sample["source_text"] + inputs = self._tokenize([source_text], padding=False) + + outs = self.model.generate( + input_ids=inputs["input_ids"].to(self.model.device), + attention_mask=inputs["attention_mask"].to(self.model.device), + max_length=self.max_target_length, + ) + + answer = self.tokenizer.decode(outs[0], skip_special_tokens=True) + + return answer + + def qg_from_file(self, path_or_dict: str, use_answers: bool = False, **kwargs): + """performs question-generation using the contexts and answers + from squad formatted json file or answers extracted by the model""" + + for k, v in kwargs.items(): + setattr(self, k, v) + task = "question-generation" + + # read data from path or dict + if isinstance(path_or_dict, str): + data = load_json(path_or_dict)["data"] + else: + data = path_or_dict["data"] + + out = {"data": []} + for article in tqdm(data, desc="Question Generation from articles"): + out_article = {"paragraphs": [], "title": article.get("title")} + # iterate over each paragraph + for paragraph in article["paragraphs"]: + context = paragraph["context"] + out_para = {"context": context, "qas": []} + if use_answers and paragraph["qas"]: + answer_list = [] + paragraph["qas"] = sorted(paragraph["qas"], key=lambda k: k["answers"][0]["answer_start"]) + for qa in paragraph["qas"]: + answer = qa["answers"][0] + answer_list.append(answer) + elif use_answers and not paragraph["qas"]: 
# pass if paragraph["qas"] is empty + continue + else: + answer_list = None + qg_out = self(task=task, context=context, answer_list=answer_list) + if qg_out[0]["question"] is None: + continue + for qa, gold_qa in zip(qg_out, paragraph["qas"]): + if use_answers: + qa["gold_question"] = gold_qa["question"] + out_para["qas"].append(qa) + out_article["paragraphs"].append(out_para) + out["data"].append(out_article) + return out + + def qa_from_file(self, path_or_dict: str, **kwargs): + """performs question-answering using the contexts and questions + from squad formatted json file""" + + for k, v in kwargs.items(): + setattr(self, k, v) + task = "question-answering" + + # read data from path or dict + if isinstance(path_or_dict, str): + data = load_json(path_or_dict)["data"] + else: + data = path_or_dict["data"] + + out = {"data": []} + for article in tqdm(data, desc="Question Answering from articles"): + out_article = {"paragraphs": [], "title": article.get("title")} + # iterate over each paragraph + for paragraph in article["paragraphs"]: + context = paragraph["context"] + out_para = {"context": context, "qas": []} + + # extract questions from dataset paragraph and get answer from model + for qa in paragraph["qas"]: + question = qa.get("question") + if question is not None: + qa_out = self(task=task, context=context, question=question)[0] + # append q&a pair into out_para + qa_out["gold_answer"] = qa["answers"][0]["text"] + qa_out["question"] = question + out_para["qas"].append(qa_out) + else: + logger.warning("skipping a paragraph without questions.") + + out_article["paragraphs"].append(out_para) + out["data"].append(out_article) + return out + + def ans_ext_from_file(self, path_or_dict: str, **kwargs): + """performs answer-extraction using the contexts from squad formatted json file""" + + for k, v in kwargs.items(): + setattr(self, k, v) + task = "answer-extraction" + + # read data from path or dict + if isinstance(path_or_dict, str): + data = load_json(path_or_dict)["data"] + else: + data = path_or_dict["data"] + + out = {"data": []} + for article in tqdm(data, desc="Answer Extraction from articles"): + out_article = {"paragraphs": [], "title": article.get("title")} + # iterate over each paragraph + for paragraph in article["paragraphs"]: + context = paragraph["context"] + out_para = {"context": context, "gold_answer_list": [], "predicted_answer_list": []} + + if paragraph["qas"]: + gold_answer_list = [] + paragraph["qas"] = sorted(paragraph["qas"], key=lambda k: k["answers"][0]["answer_start"]) + for qa in paragraph["qas"]: + answer = qa["answers"][0] + gold_answer_list.append(answer["text"]) + else: + logger.warning("skipping a paragraph without q/a's.") + # extract answers + ans_ext_out = self(task=task, context=context) + # add gold and predicted answers + predicted_answer_list = [output["answer"] for output in ans_ext_out] + out_para["gold_answer_list"] = gold_answer_list + out_para["predicted_answer_list"] = predicted_answer_list + + out_article["paragraphs"].append(out_para) + out["data"].append(out_article) + return out diff --git a/core/argument_parsers.py b/core/argument_parsers.py new file mode 100644 index 0000000..f79089b --- /dev/null +++ b/core/argument_parsers.py @@ -0,0 +1,162 @@ +import logging +import os +import sys +from dataclasses import dataclass, field +from pathlib import Path +from typing import Dict, List, Optional + +from transformers import HfArgumentParser, TrainingArguments + +from utils.file import read_yaml + +logger = logging.getLogger(__name__) 
+logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=os.environ.get("LOGLEVEL", "INFO").upper(), +) + + +@dataclass +class ModelArguments: + """ + Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. + """ + + model_name_or_path: str = field( + default="google/mt5-small", + metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}, + ) + tokenizer_path: Optional[str] = field( + default=None, metadata={"help": "Pretrained tokenizer path to be saved/loaded"} + ) + cache_dir: Optional[str] = field( + default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"} + ) + + wandb_project: Optional[str] = field(default=None, metadata={"help": "Wandb project for experiment tracking"}) + wandb_id: Optional[str] = field( + default=None, metadata={"help": "Wandb run id, will be used when 'do_train=False', 'do_eval=True'"} + ) + neptune_project: Optional[str] = field(default=None, metadata={"help": "Neptune project for experiment tracking"}) + neptune_run: Optional[str] = field( + default=None, metadata={"help": "Neptune run id, will be used when 'do_train=False', 'do_eval=True'"} + ) + neptune_api_token: Optional[str] = field( + default=None, metadata={"help": "Neptune api token for experiment tracking"} + ) + model_type: str = field(default=None, metadata={"help": "'mt5' or 'bert'"}) + + +@dataclass +class DataArguments: + """ + Arguments pertaining to what data we are going to input our model for training and eval. + """ + + train_file_path: Optional[str] = field( + default="data/train_data_multitask_mt5.pt", + metadata={"help": "name for cached train dataset"}, + ) + valid_file_path: Optional[str] = field( + default="data/valid_data_multitask_mt5.pt", + metadata={"help": "name for cached valid dataset"}, + ) + + prepare_data: bool = field( + default=True, + metadata={ + "help": "Runs prepare_data.py before starting to train. Set if false if you alreade prepared data before." + }, + ) + + +@dataclass +class ExtendedTrainingArguments(TrainingArguments): + train_dataset_list: Optional[List[str]] = field( + default_factory=lambda: ["tquad-train"], + metadata={"help": "dataset name list of the training"}, + ) + valid_dataset_list: Optional[List[str]] = field( + default_factory=lambda: ["tquad-val"], + metadata={"help": "dataset name list of the validation"}, + ) + eval_dataset_list: Optional[List[str]] = field( + default_factory=lambda: ["tquad-val", "xquad.tr"], + metadata={"help": "dataset name list of the evaluation"}, + ) + freeze_embeddings: bool = field( + default=False, + metadata={"help": "Freeze token embeddings."}, + ) + max_source_length: Optional[int] = field( + default=512, + metadata={"help": "Max input length for the source text"}, + ) + max_target_length: Optional[int] = field( + default=80, + metadata={"help": "Max input length for the target text"}, + ) + mt5_task_list: Optional[List[str]] = field( + default_factory=lambda: ["qa", "qg", "ans_ext"], + metadata={"help": "task list for mt5"}, + ) + mt5_qg_format: Optional[str] = field( + default="highlight", + metadata={"help": 'mt5 qg format as "highlight", "prepend" or "both"'}, + ) + + +def parser(args_file_path: Optional[str] = None): + # See all possible arguments in src/transformers/training_args.py + # or by passing the --help flag to this script. + # We now keep distinct sets of args, for a cleaner separation of concerns. 
+ + parser = HfArgumentParser((ModelArguments, DataArguments, ExtendedTrainingArguments)) + + accepted_args_file_suffixes = (".json", ".yml", ".yaml") + + # validate args file suffix + args_file_suffix: Optional[str] = None + if args_file_path: + # If we pass args_file_path to the script and it's the path to a json/yml/yaml file, + # let's parse it to get our arguments. + assert type(args_file_path) == str, TypeError(f"invalid 'args_file_path': {args_file_path}") + args_file_suffix = Path(args_file_path).suffix + assert args_file_suffix in accepted_args_file_suffixes, TypeError( + f"""args file should be one of: / + {accepted_args_file_suffixes}, invalid args file format: {args_file_suffix}""" + ) + elif len(sys.argv) == 2: + # If we pass only one argument to the script and it's the path to a json/yml/yaml file, + # let's parse it to get our arguments. + args_file_path = sys.argv[1] + args_file_suffix = Path(args_file_path).suffix + assert args_file_suffix in accepted_args_file_suffixes, TypeError( + f"""args file should be one of: / + {accepted_args_file_suffixes}, invalid args file format: {args_file_suffix}""" + ) + + if args_file_suffix == ".json": + model_args, data_args, training_args = parser.parse_json_file(json_file=args_file_path) + elif args_file_suffix in (".yml", ".yaml"): + args_dict = read_yaml(args_file_path) + model_args, data_args, training_args = parser.parse_dict(args=args_dict) + else: + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + + if model_args.tokenizer_path is None: + model_name = (model_args.model_name_or_path).split("/")[-1] + model_args.tokenizer_path = model_name + "_tokenizer" + + # overwrite model type + if "mt5" in model_args.model_name_or_path: + model_type = "mt5" + elif "bert" in model_args.model_name_or_path: + model_type = "bert" + else: + logger.info("couldnt infer model type from 'model_name_or_path', assuming its 'mt5'.") + model_type = "mt5" + model_args.model_type = model_type + + return model_args, data_args, training_args diff --git a/core/bert_api.py b/core/bert_api.py new file mode 100644 index 0000000..8677388 --- /dev/null +++ b/core/bert_api.py @@ -0,0 +1,129 @@ +import logging +import os +from typing import Tuple, Union + +import numpy as np +import torch +from tqdm import tqdm + +from hf.model import BertModel +from utils.file import load_json + +logger = logging.getLogger(__name__) +logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=os.environ.get("LOGLEVEL", "INFO").upper(), +) + + +def _postprocess_output( + model_output, top_n_best_answers: int = 20, max_answer_length: int = 64 +) -> Tuple[Union[int, int]]: + """returns valid_answer_list (List[Dict]) with each elements having + score (float), answer_start (int), answer_end (int) keys""" + + start_logits = model_output.start_logits[0].cpu().detach().numpy() + end_logits = model_output.end_logits[0].cpu().detach().numpy() + # Gather the indices the best start/end logits: + start_indexes = np.argsort(start_logits)[-1 : -top_n_best_answers - 1 : -1].tolist() + end_indexes = np.argsort(end_logits)[-1 : -top_n_best_answers - 1 : -1].tolist() + valid_answers = [] + for start_index in start_indexes: + for end_index in end_indexes: + # Don't consider answers with a length that is either < 0 or > max_answer_length. 
+ if end_index < start_index or end_index - start_index + 1 > max_answer_length: + continue + valid_answers.append( + { + "score": start_logits[start_index] + end_logits[end_index], + "answer_start": start_index, + "answer_end": end_index, + } + ) + valid_answer_list = sorted(valid_answers, key=lambda x: x["score"], reverse=True) + return valid_answer_list + + +def qa_from_file( + path_or_dict: str, + model_url_or_path: str, + use_cuda: bool = True, + top_n_best_answers: int = 20, + max_source_lenght: int = 512, + max_target_length: int = 80, +): + model = BertModel(model_url_or_path, use_cuda=use_cuda) + + # read data from path or dict + if isinstance(path_or_dict, str): + data = load_json(path_or_dict)["data"] + else: + data = path_or_dict["data"] + + out = {"data": []} + for article in tqdm(data, desc="Answer Generation using Bert from articles"): + out_article = {"paragraphs": [], "title": article.get("title")} + # iterate over each paragraph + for paragraph in article["paragraphs"]: + context = paragraph["context"] + out_para = {"context": context, "qas": []} + + # extract questions from dataset paragraph and get answer from model + for qa in paragraph["qas"]: + question = qa.get("question") + if question is not None: + + inputs = model.tokenizer.encode_plus( + context, + question, + max_length=max_source_lenght, + add_special_tokens=True, + padding=False, + pad_to_max_length=False, + return_tensors="pt", + truncation=True, + ) + input_ids = inputs["input_ids"].tolist()[0] + + model_output = model.model(torch.tensor([input_ids]).to(model.device)) + # answer_dict_list = _postprocess_output( + # model_output, top_n_best_answers=top_n_best_answers, max_answer_length=max_target_length + # ) + answer_start_scores, answer_end_scores = model_output["start_logits"], model_output["end_logits"] + + answer_start = torch.argmax( + answer_start_scores + ) # Get the most likely beginning of answer with the argmax of the score + answer_end = ( + torch.argmax(answer_end_scores) + 1 + ) # Get the most likely end of answer with the argmax of the score + + # answer_list = [] + # for answer_dict in answer_dict_list: + # answer = model.tokenizer.convert_tokens_to_string( + # model.tokenizer.convert_ids_to_tokens( + # input_ids[answer_dict["answer_start"] : answer_dict["answer_end"]] + # ) + # ) + # answer_list.append(answer) + # append q&a pair into out_para + if answer_end > answer_start: + answer = model.tokenizer.convert_tokens_to_string( + model.tokenizer.convert_ids_to_tokens(input_ids[answer_start:answer_end]) + ) + else: + answer = "" + qa_out = { + "answer": answer, + # "alternative_answers": answer_list[1:], + "gold_answer": qa["answers"][0]["text"], + "question": question, + } + out_para["qas"].append(qa_out) + else: + logger.warning("skipping a paragraph without questions.") + + out_article["paragraphs"].append(out_para) + out["data"].append(out_article) + return out diff --git a/core/collator.py b/core/collator.py new file mode 100644 index 0000000..cd7aee9 --- /dev/null +++ b/core/collator.py @@ -0,0 +1,81 @@ +from typing import Dict, List + +import torch + +# taken from https://github.com/patil-suraj/question_generation/blob/master/data_collator.py + + +def trim_batch( + input_ids, + pad_token_id, + attention_mask=None, +): + """Remove columns that are populated exclusively by pad_token_id""" + keep_column_mask = input_ids.ne(pad_token_id).any(dim=0) + if attention_mask is None: + return input_ids[:, keep_column_mask] + else: + return (input_ids[:, keep_column_mask], attention_mask[:, 
keep_column_mask]) + + +# prepares lm_labels from target_ids, returns examples with keys as expected by the forward method +# this is necessary because the trainer directly passes this dict as arguments to the model +# so make sure the keys match the parameter names of the forward method +class T2TDataCollator: + def __init__(self, tokenizer, mode="training", using_tpu=False): + self.tokenizer = tokenizer + self.using_tpu = using_tpu + self.mode = mode + + def __call__(self, batch: List) -> Dict[str, torch.Tensor]: + """ + Take a list of samples from a Dataset and collate them into a batch. + Returns: + A dictionary of tensors + """ + input_ids = torch.stack([example["source_ids"] for example in batch]) + target_ids = torch.stack([example["target_ids"] for example in batch]) + attention_mask = torch.stack([example["attention_mask"] for example in batch]) + + pad_token_id = self.tokenizer.pad_token_id + + # don't trim on tpu, for some reason trimming leads to slower training on TPU + if not self.using_tpu: + input_ids, attention_mask = trim_batch(input_ids, pad_token_id, attention_mask=attention_mask) + target_ids = trim_batch(target_ids, pad_token_id) + + lm_labels = target_ids.clone() + decoder_input_ids = self._shift_right_t5(lm_labels) + if self.mode == "training": + lm_labels[lm_labels[:, :] == pad_token_id] = -100 + + params = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "labels": lm_labels, + "decoder_input_ids": decoder_input_ids, + } + + return params + + def _shift_right_t5(self, input_ids): + decoder_start_token_id = self.tokenizer.pad_token_id + pad_token_id = self.tokenizer.pad_token_id + + assert ( + decoder_start_token_id is not None + ), """self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id. + See T5 docs for more information""" + + # shift inputs to the right + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() + shifted_input_ids[..., 0] = decoder_start_token_id + + assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined." + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + assert torch.all(shifted_input_ids >= 0).item(), "Verify that `labels` has only positive values and -100" + + return shifted_input_ids diff --git a/core/dataset_parsers.py b/core/dataset_parsers.py new file mode 100644 index 0000000..9ee57df --- /dev/null +++ b/core/dataset_parsers.py @@ -0,0 +1,195 @@ +import os +from typing import Dict, List + +from datasets import Dataset + +from utils.file import load_json +from utils.nlp import ( + download_and_load_dataset, + normalize_text, + prepare_answer_extraction_samples, + prepare_qa_sample, + prepare_qg_samples, +) + +# former url train data was unfit for (pyarrow) datasets.Dataset (file maybe corrupt). 
+ +_TQUAD_URL = "https://github.com/fcakyon/turkish-qa-datasets/releases/download/0.0.1/" +_TQUAD1_DEV_FILE = "tquad_dev_data_v1.json" +_TQUAD1_TRAINING_FILE = "tquad_train_data_v1.json" +_TQUAD2_DEV_FILE = "tquad_dev_data_v2.json" +_TQUAD2_TRAINING_FILE = "tquad_train_data_v2.json" +_TQUAD_DEV_SMALL_FILE = "tquad_dev_small_data.json" +_TQUAD_LOCAL_DIR = "data/tquad/" +_XQUAD_TR_URL = "https://raw.githubusercontent.com/deepmind/xquad/master/" +_XQUAD_LOCAL_DIR = "data/xquad" + + +def load_tquad_data(data_name="tquad2-train") -> Dataset: + if "tquad1-train" in data_name: + tquad_url = os.path.join(_TQUAD_URL, _TQUAD1_TRAINING_FILE) + tquad_local_path = os.path.join(_TQUAD_LOCAL_DIR, _TQUAD1_TRAINING_FILE) + elif "tquad2-train" in data_name: + tquad_url = os.path.join(_TQUAD_URL, _TQUAD2_TRAINING_FILE) + tquad_local_path = os.path.join(_TQUAD_LOCAL_DIR, _TQUAD2_TRAINING_FILE) + elif "tquad1-valid" in data_name: + tquad_url = os.path.join(_TQUAD_URL, _TQUAD1_DEV_FILE) + tquad_local_path = os.path.join(_TQUAD_LOCAL_DIR, _TQUAD1_DEV_FILE) + elif "tquad2-valid" in data_name: + tquad_url = os.path.join(_TQUAD_URL, _TQUAD2_DEV_FILE) + tquad_local_path = os.path.join(_TQUAD_LOCAL_DIR, _TQUAD2_DEV_FILE) + elif "small" in data_name: + tquad_url = os.path.join(_TQUAD_URL, _TQUAD_DEV_SMALL_FILE) + tquad_local_path = os.path.join(_TQUAD_LOCAL_DIR, _TQUAD_DEV_SMALL_FILE) + else: + raise ValueError( + f"Unknown data_name {data_name}, must be one of ['tquad1-train', 'tquad2-train', 'tquad1-valid', 'tquad2-valid', 'tquad.small']" + ) + + return download_and_load_dataset(tquad_url, tquad_local_path) + + +def load_xquad_data(data_name="xquad.tr"): + """XQuad dataset has only validation split.""" + xquad_url = os.path.join(_XQUAD_TR_URL, data_name + ".json") + xquad_local_path = os.path.join(_XQUAD_LOCAL_DIR, data_name + ".json") + return download_and_load_dataset(xquad_url, xquad_local_path) + + +def prepare_data_for_bert(data: Dataset) -> List[Dict]: + """ + Args: + data: squad data + + Returns: Processed samples as list in bert input format. 
+ """ + samples = [] + data = data["data"] + + for group in data: + for passage in group["paragraphs"]: + context = passage["context"] + if passage["qas"]: + for qa in passage["qas"]: + question = qa["question"] + # for answer in qa["answers"]: + answer = qa["answers"][0] + + gold_text = answer["text"] + start_idx = answer["answer_start"] + end_idx = start_idx + len(gold_text) + + # sometimes squad answers are off by a character or two – fix this + if context[start_idx:end_idx] == gold_text: + answer["answer_end"] = end_idx + elif context[start_idx - 1 : end_idx - 1] == gold_text: + answer["answer_start"] = start_idx - 1 + answer["answer_end"] = end_idx - 1 # When the gold label is off by one character + elif context[start_idx - 2 : end_idx - 2] == gold_text: + answer["answer_start"] = start_idx - 2 + answer["answer_end"] = end_idx - 2 # When the gold label is off by two characters + elif context[start_idx - 3 : end_idx - 3] == gold_text: + answer["answer_start"] = start_idx - 3 + answer["answer_end"] = end_idx - 3 # When the gold label is off by three characters + else: + print( + f"skipping the answer|answer_start|context {answer['text']}|{answer['answer_start']}|{context} | for reason: 'answer indexes are off by a lot'" + ) + continue + + sample = {"context": context, "question": question, "answer": answer} + samples.append(sample) + return samples + + +def prepare_data_for_mt5( + data: Dataset, task_list: List[str] = ["ans_ext", "qa", "qg"], qg_format="highlight" +) -> List[Dict]: + """ + Args: + data: squad data + task_list: list of tasks to data be prepared + qg_format: "highlight", "prepend" or "both" + + Returns: Processed samples as list in mt5 input format. + """ + samples = [] + data = data["data"] + + for article in data: + for paragraph in article["paragraphs"]: + context = paragraph["context"] + question_list = [] + answer_list = [] + + if paragraph["qas"]: # pass if paragraph["qas"] is empty + for qa in paragraph["qas"]: + question = normalize_text(qa["question"]) + answer = qa["answers"][0] + answer["text"] = answer["text"] + qa_sample = prepare_qa_sample(context=context, question=question, answer=answer["text"]) + if qa_sample["target_text"] is not None: + if "qa" in task_list: + samples.append(qa_sample) + question_list.append(question) + answer_list.append(answer) + + if answer_list and question_list: + qg_samples = prepare_qg_samples( + context=context, answer_list=answer_list, question_list=question_list, qg_format=qg_format + ) + if qg_samples[0]["answer"] is not None: + if "qg" in task_list: + samples.extend(qg_samples) + + answer_extraction_samples = prepare_answer_extraction_samples(context=context, answer_list=answer_list) + for answer_extraction_sample in answer_extraction_samples: + if answer_extraction_sample["target_text"] is not None: + if "ans_ext" in task_list: + samples.extend(answer_extraction_samples) + return samples + + +def prepare_data( + data: Dict, target_format="mt5", mt5_task_list: List[str] = ["ans_ext", "qa", "qg"], mt5_qg_format="highlight" +): + """ + Args: + target_format (str): output format ('mt5' or 'bert') + mt5_task_list: list of tasks for mt5 data to be prepared + mt5_qg_format: "highlight", "prepend" or "both" + """ + if target_format == "mt5": + samples = prepare_data_for_mt5(data, mt5_task_list, mt5_qg_format) + elif target_format == "bert": + samples = prepare_data_for_bert(data) + return samples + + +def load_dataset(name_or_path: str): + if os.path.isfile(name_or_path): + data = load_json(name_or_path) + elif "tquad" in 
+        data = load_tquad_data(name_or_path)
+    elif "xquad" in name_or_path:
+        data = load_xquad_data(name_or_path)
+    else:
+        raise ValueError(f"Unknown dataset {name_or_path}.")
+
+    return data
+
+
+def load_and_prepare_dataset(
+    name_or_path: str,
+    target_format="mt5",
+    mt5_task_list: List[str] = ["ans_ext", "qa", "qg"],
+    mt5_qg_format="highlight",
+):
+    """
+    Args:
+        target_format (str): output format ('mt5' or 'bert')
+        mt5_task_list: list of tasks for mt5 data to be prepared
+        mt5_qg_format: "highlight", "prepend" or "both"
+    """
+    data = load_dataset(name_or_path)
+    return prepare_data(data, target_format=target_format, mt5_task_list=mt5_task_list, mt5_qg_format=mt5_qg_format)
diff --git a/core/evaluate.py b/core/evaluate.py
new file mode 100644
index 0000000..8499936
--- /dev/null
+++ b/core/evaluate.py
@@ -0,0 +1,289 @@
+import logging
+import os
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+from jury import Jury
+from jury.metrics import load_metric
+
+from core.dataset_parsers import load_dataset
+from core.generate import generate
+from utils.file import load_json, save_json
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(
+    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+    datefmt="%m/%d/%Y %H:%M:%S",
+    level=os.environ.get("LOGLEVEL", "INFO").upper(),
+)
+
+
+TASK_TO_METRIC = {
+    "ans_ext": [
+        load_metric("f1", task="language-generation"),
+        load_metric("precision", task="language-generation"),
+        load_metric("bertscore", compute_kwargs={"lang": "tr"}),
+    ],
+    "qa": [load_metric("squad"), load_metric("bertscore", compute_kwargs={"lang": "tr"})],
+    "qg": [
+        load_metric("bleu", compute_kwargs={"max_order": 1}),
+        load_metric("bleu", compute_kwargs={"max_order": 2}),
+        load_metric("bleu", compute_kwargs={"max_order": 3}),
+        load_metric("bleu", compute_kwargs={"max_order": 4}),
+        load_metric("rouge"),
+        load_metric("bertscore", compute_kwargs={"lang": "tr"}),
+    ],
+}
+
+
+class Evaluation:
+    r"""
+    Simple evaluation pipeline for text-based metrics. The metric set is selected per task from
+    ``TASK_TO_METRIC``: F1, precision and BERTScore for answer extraction ("ans_ext"), SQuAD and
+    BERTScore for QA, and BLEU-1..4, ROUGE and BERTScore for QG. ``run`` scores each sample with
+    the metrics of the task it belongs to.
+
+    Note:
+
+        If ``predictions`` and ``references`` are given as lists of strings, the order is received
+        as prediction & reference pairs and evaluation is done by prioritizing the order.
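+
+    Example (illustrative usage only, not from the original code):
+
+        evaluation = Evaluation()
+        scores = evaluation.run(
+            predictions=["Türkiye'nin başkenti neresidir?"],
+            references=["Türkiye'nin başkenti neresidir?"],
+            tasks=["qg"],
+        )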
+ """ + + def __init__(self, metrics: Optional[List[str]] = None): + self.metrics = metrics + + @staticmethod + def task_to_metric(task: str) -> List[str]: + return TASK_TO_METRIC.get(task) + + @staticmethod + def metric_to_task(metric: str) -> str: + for task, metrics in TASK_TO_METRIC.items(): + if metric in metrics: + return task + + @staticmethod + def get_tasks(): + return list(TASK_TO_METRIC.keys()) + + @staticmethod + def _get_task_related_samples( + desired_task: str, + predictions: Union[List[str], List[Dict]], + references: Union[List[str], List[Dict]], + tasks: Optional[List[str]] = None, + ): + if tasks is None: + return predictions, references + + selected_predictions = [] + selected_references = [] + for prediction, reference, task in zip(predictions, references, tasks): + if task == desired_task: + selected_predictions.append(prediction) + selected_references.append(reference) + return selected_predictions, selected_references + + def run( + self, + predictions: Union[List[str], List[Dict]], + references: Union[List[str], List[Dict]], + tasks: List[str], + ) -> Dict[str, Any]: + scores = {} + + for task in TASK_TO_METRIC: + metrics = self.task_to_metric(task) + scorer = Jury(metrics=metrics, run_concurrent=True) + selected_predictions, selected_references = self._get_task_related_samples( + desired_task=task, predictions=predictions, references=references, tasks=tasks + ) + save_json( + {"predictions": selected_predictions, "references": selected_references}, + task + "_outputs_during_training.json", + ) + task_scores = scorer(predictions=selected_predictions, references=selected_references) + for key, value in task_scores.items(): + scores[task + "_" + key] = value + + return scores + + +def _get_qa_predictions_references(data: Dict) -> Tuple[List, List]: + predictions = [] + references = [] + for article in data: + for paragraph in article["paragraphs"]: + for qa in paragraph["qas"]: + predictions.append(qa["answer"]) + references.append(qa["gold_answer"]) + return predictions, references + + +def _get_qg_predictions_references(data: Dict) -> Tuple[List, List]: + predictions = [] + references = [] + for article in data: + for paragraph in article["paragraphs"]: + for i, qa in enumerate(paragraph["qas"]): + predictions.append(qa["question"]) + references.append(qa["gold_question"]) + return predictions, references + + +def _get_ans_ext_predictions_references(data: Dict, dont_combine_answer_list: bool = False) -> Tuple[List, List]: + predictions = [] + references = [] + for article in data: + for paragraph in article["paragraphs"]: + if dont_combine_answer_list: + predictions.append(paragraph["predicted_answer_list"]) + references.append(paragraph["gold_answer_list"]) + else: + predictions.append(" ".join(paragraph["predicted_answer_list"])) + references.append(" ".join(paragraph["gold_answer_list"])) + return predictions, references + + +def evaluate_from_file(path: str, task: str, output: str = None, dont_combine_answer_list: bool = False): + # prepare predictions & references + data = load_json(path)["data"] + if task == "qa": + predictions, references = _get_qa_predictions_references(data) + elif task == "qg": + predictions, references = _get_qg_predictions_references(data) + elif task == "ans_ext": + predictions, references = _get_ans_ext_predictions_references(data, dont_combine_answer_list) + else: + raise ValueError("Unknown task. 
Must be one of [qa, qg, ans_ext]") + # prepare metrics + metrics = TASK_TO_METRIC.get(task) + + try: + # calculate scores + scorer = Jury(metrics=metrics, run_concurrent=False) + scores = scorer(predictions=predictions, references=references) + # log results + logger.info(scores) + # export result + if output is None: + export_path = str(Path(path).parent / (task + "_eval_scores.json")) + else: + export_path = output + save_json( + scores, + export_path, + ) + return scores + except Exception as e: + logger.warning(e) + return None + + +def evaluate_on_train_end(model_args, training_args): + logger.info("*** Evaluate on train end ***") + + if model_args.model_type == "mt5": + eval_tasks = training_args.mt5_task_list + elif model_args.model_type == "bert": + eval_tasks = ["qa_for_bert"] + + overall_results = {} + for eval_dataset in training_args.eval_dataset_list: + for eval_task in eval_tasks: + dataset_name = Path(eval_dataset).name + data = load_dataset(eval_dataset) + output_generation_file = os.path.join( + training_args.output_dir, dataset_name + "_" + eval_task + "_generation.json" + ) + logger.info(f"Evaluating on {dataset_name} for {eval_task} task.") + + if eval_task == "ans_ext": + generate( + path_or_dict=data, + output=output_generation_file, + model_url_or_path=training_args.output_dir, + use_cuda=not training_args.no_cuda, + task=eval_task, + max_source_length=training_args.max_source_length, + max_target_length=training_args.max_target_length, + seed=training_args.seed, + ) + output_eval_file = os.path.join( + training_args.output_dir, dataset_name + "_" + eval_task + "_eval_result.json" + ) + results = evaluate_from_file( + path=output_generation_file, task=eval_task, output=output_eval_file, dont_combine_answer_list=False + ) + if results is not None: + for key, value in results.items(): + overall_results["eval_" + dataset_name + "_" + eval_task + "_" + key] = value + elif eval_task == "qa": + generate( + path_or_dict=data, + output=output_generation_file, + model_url_or_path=training_args.output_dir, + use_cuda=not training_args.no_cuda, + task=eval_task, + max_source_length=training_args.max_source_length, + max_target_length=training_args.max_target_length, + seed=training_args.seed, + ) + output_eval_file = os.path.join( + training_args.output_dir, dataset_name + "_" + eval_task + "_eval_result.json" + ) + results = evaluate_from_file( + path=output_generation_file, + task=eval_task, + output=output_eval_file, + dont_combine_answer_list=False, + ) + for key, value in results.items(): + overall_results["eval_" + dataset_name + "_" + eval_task + "_" + key] = value + elif eval_task == "qg": + generate( + path_or_dict=data, + output=output_generation_file, + model_url_or_path=training_args.output_dir, + use_cuda=not training_args.no_cuda, + task=eval_task, + use_answers=True, + max_source_length=training_args.max_source_length, + max_target_length=training_args.max_target_length, + qg_format=training_args.mt5_qg_format, + seed=training_args.seed, + ) + output_eval_file = os.path.join( + training_args.output_dir, dataset_name + "_" + eval_task + "_eval_result.json" + ) + results = evaluate_from_file( + path=output_generation_file, + task=eval_task, + output=output_eval_file, + dont_combine_answer_list=False, + ) + for key, value in results.items(): + overall_results["eval_" + dataset_name + "_" + eval_task + "_" + key] = value + elif eval_task == "qa_for_bert": + generate( + path_or_dict=data, + output=output_generation_file, + model_url_or_path=training_args.output_dir, + 
use_cuda=not training_args.no_cuda, + task=eval_task, + max_source_length=training_args.max_source_length, + max_target_length=training_args.max_target_length, + seed=training_args.seed, + ) + output_eval_file = os.path.join( + training_args.output_dir, dataset_name + "_" + eval_task + "_eval_result.json" + ) + results = evaluate_from_file( + path=output_generation_file, + task="qa", + output=output_eval_file, + dont_combine_answer_list=False, + ) + for key, value in results.items(): + overall_results["eval_" + dataset_name + "_" + "qa" + "_" + key] = value + return overall_results diff --git a/core/generate.py b/core/generate.py new file mode 100644 index 0000000..fab77a3 --- /dev/null +++ b/core/generate.py @@ -0,0 +1,163 @@ +import json +from typing import Optional, Union + +from transformers import set_seed + +from core.api import TurQue +from core.bert_api import qa_from_file +from utils.file import save_json + + +def read_config(path: str): + with open(path, "r", encoding="utf-8") as jf: + config = json.load(jf) + return config + + +def generate_qg( + path_or_dict: str = None, + output: str = None, + model_url_or_path: str = None, + use_cuda: bool = True, + use_answers: str = None, + max_source_length: int = 512, + max_target_length: int = 80, + qg_format: str = "highlight", +): + use_answers = False if use_answers is None else use_answers + turque = TurQue( + model_url_or_path=model_url_or_path, + use_cuda=use_cuda, + max_source_length=max_source_length, + max_target_length=max_target_length, + qg_format=qg_format, + ) + result = turque.qg_from_file(path_or_dict=path_or_dict, use_answers=use_answers) + save_json(result, output) + + +def generate_qa( + path_or_dict: str = None, + output: str = None, + model_url_or_path: str = None, + use_cuda: bool = True, + max_source_length: int = 512, + max_target_length: int = 80, +): + turque = TurQue( + model_url_or_path=model_url_or_path, + use_cuda=use_cuda, + max_source_length=max_source_length, + max_target_length=max_target_length, + ) + result = turque.qa_from_file(path_or_dict=path_or_dict) + save_json(result, output) + + +def generate_ans_ext( + path_or_dict: str = None, + output: str = None, + model_url_or_path: str = None, + use_cuda: bool = True, + max_source_length: int = 512, + max_target_length: int = 80, +): + turque = TurQue( + model_url_or_path=model_url_or_path, + use_cuda=use_cuda, + max_source_length=max_source_length, + max_target_length=max_target_length, + ) + result = turque.ans_ext_from_file(path_or_dict=path_or_dict) + save_json(result, output) + + +def generate_qa_for_bert( + path_or_dict: str, + output: str, + model_url_or_path: str = None, + use_cuda: bool = True, + max_source_length: int = 512, + max_target_length: int = 80, +): + result = qa_from_file( + path_or_dict=path_or_dict, + model_url_or_path=model_url_or_path, + use_cuda=use_cuda, + max_source_lenght=max_source_length, + max_target_length=max_target_length, + ) + save_json(result, output) + + +def generate( + path_or_dict: str = None, + output: str = None, + model_url_or_path: str = None, + use_cuda: bool = True, + use_answers: Union[str, bool] = False, + task: str = "qa", + max_source_length: int = 512, + max_target_length: int = 80, + config: str = None, + qg_format: str = "highlight", + seed: Optional[int] = None, +): + """ + path_or_dict (str): path or dict for a squad formatted dataset + output (str): output path for generation json + use_cuda (bool): perform generation on cuda + use_answers (bool): use gold answer for qg + task (str): one of 'qa', 
'qg', 'ans_ext', 'qa_for_bert' + config (str): path to a json file + qg_format (str): 'highlight', 'prepend' or 'both' + seed (int): seed for randomized operations + """ + args = read_config(config) if config is not None else {} + path_or_dict = args.get("path_or_dict") if path_or_dict is None else path_or_dict + output = args.get("output") if output is None else output + model_url_or_path = args.get("model_url_or_path") if model_url_or_path is None else model_url_or_path + task = args.get("task") if task is None else task + use_answers = False if use_answers is None else use_answers + if seed is not None: + set_seed(seed) + if task == "qa": + generate_qa( + path_or_dict=path_or_dict, + output=output, + model_url_or_path=model_url_or_path, + use_cuda=use_cuda, + max_source_length=max_source_length, + max_target_length=max_target_length, + ) + elif task == "qg": + generate_qg( + path_or_dict=path_or_dict, + output=output, + model_url_or_path=model_url_or_path, + use_cuda=use_cuda, + use_answers=use_answers, + max_source_length=max_source_length, + max_target_length=max_target_length, + qg_format=qg_format, + ) + elif task == "ans_ext": + generate_ans_ext( + path_or_dict=path_or_dict, + output=output, + model_url_or_path=model_url_or_path, + use_cuda=use_cuda, + max_source_length=max_source_length, + max_target_length=max_target_length, + ) + elif task == "qa_for_bert": + generate_qa_for_bert( + path_or_dict=path_or_dict, + output=output, + model_url_or_path=model_url_or_path, + use_cuda=use_cuda, + max_source_length=max_source_length, + max_target_length=max_target_length, + ) + else: + raise ValueError(f"'task' should be one of ['qa', 'qg', 'ans_ext', 'qa_for_bert'] but given as {task}") diff --git a/core/pipelines.py b/core/pipelines.py new file mode 100644 index 0000000..62a1461 --- /dev/null +++ b/core/pipelines.py @@ -0,0 +1,242 @@ +import itertools +import logging +from typing import Dict, List, Optional, Union + +import torch +from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, PreTrainedModel, PreTrainedTokenizer + +from utils.nlp import sentence_tokenize + + +class QGPipeline: + def __init__( + self, + model: PreTrainedModel, + tokenizer: PreTrainedTokenizer, + ans_model: PreTrainedModel, + ans_tokenizer: PreTrainedTokenizer, + qg_format: str, + use_cuda: bool, + generate_max_length: int = 80, + generate_num_beams: int = 4, + ): + self.model = model + self.tokenizer = tokenizer + + self.ans_model = ans_model + self.ans_tokenizer = ans_tokenizer + + self.qg_format = qg_format + + self.device = "cuda" if torch.cuda.is_available() and use_cuda else "cpu" + self.model.to(self.device) + + if self.ans_model is not self.model: + self.ans_model.to(self.device) + + assert self.model.__class__.__name__ in ["MT5ForConditionalGeneration"] + + self.model_type = "mt5" + + self.generate_max_length = generate_max_length + self.generate_num_beams = generate_num_beams + + def __call__(self, inputs: str): + inputs = " ".join(inputs.split()) + sents, answers = self._extract_answers(inputs) + flat_answers = list(itertools.chain(*answers)) + + if len(flat_answers) == 0: + return [] + + qg_examples = self._prepare_inputs_for_qg_from_answers_hl(sents, answers) + + qg_inputs = [example["source_text"] for example in qg_examples] + questions = self._generate_questions(qg_inputs) + output = [{"answer": example["answer"], "question": que} for example, que in zip(qg_examples, questions)] + return output + + def _generate_questions(self, inputs): + inputs = self._tokenize(inputs, padding=True, 
truncation=True)
+
+        outs = self.model.generate(
+            input_ids=inputs["input_ids"].to(self.device),
+            attention_mask=inputs["attention_mask"].to(self.device),
+            max_length=self.generate_max_length,
+            num_beams=self.generate_num_beams,
+        )
+
+        questions = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in outs]
+        return questions
+
+    def _extract_answers(self, context):
+        sents, inputs = self._prepare_inputs_for_ans_extraction(context)
+
+        inputs = self._tokenize(inputs, padding=True, truncation=True)
+
+        outs = self.ans_model.generate(
+            input_ids=inputs["input_ids"].to(self.device),
+            attention_mask=inputs["attention_mask"].to(self.device),
+            max_length=self.generate_max_length,
+        )
+
+        dec = [self.ans_tokenizer.decode(ids, skip_special_tokens=True) for ids in outs]
+
+        # split on the answer separator token (assumed to be "<sep>", the extra token added during data preparation)
+        answers = [item.split("<sep>") for item in dec]
+        answers = [i[:-1] for i in answers]
+
+        return sents, answers
+
+    def _tokenize(self, inputs, padding=True, truncation=True, add_special_tokens=True, max_length=512):
+        inputs = self.tokenizer.batch_encode_plus(
+            inputs,
+            max_length=max_length,
+            add_special_tokens=add_special_tokens,
+            truncation=truncation,
+            padding="max_length" if padding else False,
+            pad_to_max_length=padding,
+            return_tensors="pt",
+        )
+        return inputs
+
+    def _prepare_inputs_for_ans_extraction(self, text):
+        sents = sentence_tokenize(text)
+
+        inputs = []
+        for i in range(len(sents)):
+            source_text = "extract answers:"
+            for j, sent in enumerate(sents):
+                if i == j:
+                    sent = "<hl> %s <hl>" % sent  # highlight the sentence answers will be extracted from
+                source_text = "%s %s" % (source_text, sent)
+                source_text = source_text.strip()
+
+            # if self.model_type == "mt5":
+            #     source_text = source_text + " </s>"
+            inputs.append(source_text)
+
+        return sents, inputs
+
+    def _prepare_inputs_for_qg_from_answers_hl(self, sents, answers):
+        inputs = []
+        for i, answer in enumerate(answers):
+            if len(answer) == 0:
+                continue
+            for answer_text in answer:
+                sent = sents[i]
+                sents_copy = sents[:]
+
+                answer_text = answer_text.strip()
+                # sometimes the extracted answers do not match the original text :/
+                try:
+                    ans_start_idx = sent.index(answer_text)
+
+                    sent = f"{sent[:ans_start_idx]} <hl> {answer_text} <hl> {sent[ans_start_idx + len(answer_text):]}"
+                    sents_copy[i] = sent
+
+                    source_text = " ".join(sents_copy)
+                    source_text = f"generate question: {source_text}"
+                    # if self.model_type == "mt5":
+                    #     source_text = source_text + " </s>"
+                except ValueError:
+                    continue
+
+                inputs.append({"answer": answer_text, "source_text": source_text})
+
+        return inputs
+
+
+class MultiTaskQAQGPipeline(QGPipeline):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+    def __call__(self, inputs: Union[Dict, str]):
+        if type(inputs) is str:
+            # do qg
+            return super().__call__(inputs)
+        else:
+            # do qa
+            return self._extract_answer(inputs["question"], inputs["context"])
+
+    def _prepare_inputs_for_qa(self, question, context):
+        source_text = f"question: {question} context: {context}"
+        # if self.model_type == "mt5":
+        #     source_text = source_text + " </s>"
+        return source_text
+
+    def _extract_answer(self, question, context):
+        source_text = self._prepare_inputs_for_qa(question, context)
+        inputs = self._tokenize([source_text], padding=False)
+
+        outs = self.model.generate(
+            input_ids=inputs["input_ids"].to(self.device),
+            attention_mask=inputs["attention_mask"].to(self.device),
+            max_length=self.generate_max_length,
+        )
+
+        answer = self.tokenizer.decode(outs[0], skip_special_tokens=True)
+
+        return answer
+
+
+SUPPORTED_TASKS = {
+    "multitask-qa-qg": {
+        "class": MultiTaskQAQGPipeline,
+        "default": {
+
"model": "obss/mt5-qa-qg", + }, + }, +} + + +def pipeline( + task: str, + model: Optional[PreTrainedModel] = None, + tokenizer: Optional[Union[str, PreTrainedTokenizer]] = None, + qg_format: Optional[str] = "highlight", + use_cuda: Optional[bool] = True, + **kwargs, +): + # Retrieve the task + if task not in SUPPORTED_TASKS: + raise KeyError("Unknown task {}, available tasks are {}".format(task, list(SUPPORTED_TASKS.keys()))) + + targeted_task = SUPPORTED_TASKS[task] + task_class = targeted_task["class"] + + # Use default model/config/tokenizer for the task if no model is provided + if model is None: + model = targeted_task["default"]["model"] + + # Try to infer tokenizer from model or config name (if provided as str) + if tokenizer is None: + if isinstance(model, str): + tokenizer = model + else: + # Impossible to guest what is the right tokenizer here + raise Exception( + "Impossible to guess which tokenizer to use. " + "Please provided a PretrainedTokenizer class or a path/identifier to a pretrained tokenizer." + ) + + # Instantiate tokenizer if needed + if isinstance(tokenizer, (str, tuple)): + if isinstance(tokenizer, tuple): + # For tuple we have (tokenizer name, {kwargs}) + tokenizer = AutoTokenizer.from_pretrained(tokenizer[0], **tokenizer[1]) + else: + tokenizer = AutoTokenizer.from_pretrained(tokenizer) + + # Instantiate model if needed + if isinstance(model, str): + model = AutoModelForSeq2SeqLM.from_pretrained(model) + + return task_class( + model=model, + tokenizer=tokenizer, + ans_model=model, + ans_tokenizer=tokenizer, + qg_format=qg_format, + use_cuda=use_cuda, + ) diff --git a/environment.yml b/environment.yml new file mode 100644 index 0000000..ab3e6c0 --- /dev/null +++ b/environment.yml @@ -0,0 +1,26 @@ +name: turque +#prefix: /your/custom/path/envs/turque +channels: + - anaconda +dependencies: + - anaconda::cudatoolkit=11 + - pip + - python=3.8 + - pytorch::pytorch + - pip: + - bert-score==0.3.10 + - black==21.7b0 + - datasets>=1.12.0,<2.0.0 + - flake8==3.9.2 + - gdown + - isort==5.9.2 + - jupyterlab==3.0.14 + - jury>=2.1.0,<3.0.0 + - protobuf>=3.17.3 + - pyyaml + - pysocks==1.5.6 + - rouge-score==0.0.4 + - sacrebleu==1.5.1 + - sentencepiece==0.1.96 + - transformers>=4.10.0,<5.0.0 + - trtokenizer==0.0.3 diff --git a/hf/__init__.py b/hf/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/hf/model.py b/hf/model.py new file mode 100644 index 0000000..0c8f39d --- /dev/null +++ b/hf/model.py @@ -0,0 +1,106 @@ +import logging +import os +from typing import Optional, Tuple + +import torch +from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, BertForQuestionAnswering, BertTokenizerFast +from transformers.hf_argparser import DataClass + +from utils.file import download_from_gdrive_and_unzip +from utils.torch import assert_not_all_frozen, freeze_embeds + +TURQUE_S1_GDRIVE_URL = "https://drive.google.com/uc?id=10hHFuavHCofDczGSzsH1xPHgTgAocOl1" + +logger = logging.getLogger(__name__) +logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=os.environ.get("LOGLEVEL", "INFO").upper(), +) + + +class MT5Model: + def __init__( + self, + model_name_or_path: str = "turque-s1", + tokenizer_name_or_path: str = None, + freeze_embeddings: bool = False, + cache_dir: Optional[str] = None, + use_cuda: bool = None, + ): + # try downloading pretrained files + if model_name_or_path == "turque-s1": + model_name_or_path = "data/pretrained/turque-s1/" + if not 
os.path.isfile("data/pretrained/turque-s1/pytorch_model.bin"): + download_from_gdrive_and_unzip(TURQUE_S1_GDRIVE_URL, model_name_or_path) + logger.info(f"pretrained model is downloaded to {model_name_or_path}") + else: + logger.info(f"using pretrained model at {model_name_or_path}") + model_name_or_path = "data/pretrained/turque-s1/" + + model = AutoModelForSeq2SeqLM.from_pretrained( + model_name_or_path, + cache_dir=cache_dir, + ) + if tokenizer_name_or_path is not None: + tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, cache_dir=cache_dir) + else: + try: + tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, cache_dir=cache_dir) + except: + tokenizer = AutoTokenizer.from_pretrained( + os.path.join(model_name_or_path, "tokenizer_config.json"), cache_dir=cache_dir + ) + assert model.__class__.__name__ in ["MT5ForConditionalGeneration"] + self.model = model + self.tokenizer = tokenizer + self.type = "mt5" + + self.device = "cuda" if torch.cuda.is_available() and use_cuda else "cpu" + self.model.to(self.device) + + if freeze_embeddings: + logger.info("freezing embeddings of the model") + freeze_embeds(self.model) + assert_not_all_frozen(self.model) + + self.model.resize_token_embeddings(len(self.tokenizer)) + + +class BertModel: + def __init__( + self, + model_name_or_path: str = "dbmdz/bert-base-turkish-cased", + tokenizer_name_or_path: str = None, + freeze_embeddings: bool = False, + cache_dir: Optional[str] = None, + use_cuda: bool = None, + ): + + model = BertForQuestionAnswering.from_pretrained( + model_name_or_path, + cache_dir=cache_dir, + ) + if tokenizer_name_or_path is not None: + tokenizer = BertTokenizerFast.from_pretrained(tokenizer_name_or_path, cache_dir=cache_dir) + else: + try: + tokenizer = BertTokenizerFast.from_pretrained(model_name_or_path, cache_dir=cache_dir) + except: + tokenizer = BertTokenizerFast.from_pretrained( + os.path.join(model_name_or_path, "tokenizer_config.json"), cache_dir=cache_dir + ) + assert model.__class__.__name__ in ["BertForQuestionAnswering"] + self.model = model + self.tokenizer: BertTokenizerFast = tokenizer + self.type = "bert" + + self.device = "cuda" if torch.cuda.is_available() and use_cuda else "cpu" + self.model.to(self.device) + + if freeze_embeddings: + logger.info("freezing embeddings of the model") + freeze_embeds(self.model) + assert_not_all_frozen(self.model) + + self.model.resize_token_embeddings(len(self.tokenizer)) diff --git a/prepare_data.py b/prepare_data.py new file mode 100644 index 0000000..279eaa7 --- /dev/null +++ b/prepare_data.py @@ -0,0 +1,193 @@ +import logging +import os +from pathlib import Path +from typing import List + +import pandas as pd +import torch +from datasets import Dataset + +from core.argument_parsers import parser +from core.dataset_parsers import load_and_prepare_dataset +from hf.model import BertModel, BertTokenizerFast, MT5Model + +logger = logging.getLogger(__name__) + + +class MT5DataProcessor: + def __init__(self, tokenizer, max_source_length=512, max_target_length=80): + self.tokenizer = tokenizer + self.max_source_length = max_source_length + self.max_target_length = max_target_length + + def process(self, dataset): + dataset = dataset.map(self._convert_to_features, batched=True) + + return dataset + + # tokenize the examples + def _convert_to_features(self, example_batch): + source_encoding = self.tokenizer.batch_encode_plus( + example_batch["source_text"], + max_length=self.max_source_length, + padding="max_length", + pad_to_max_length=True, + truncation=True, + ) 
+ target_encoding = self.tokenizer.batch_encode_plus( + example_batch["target_text"], + max_length=self.max_target_length, + padding="max_length", + pad_to_max_length=True, + truncation=True, + ) + + encodings = { + "source_ids": source_encoding["input_ids"], + "target_ids": target_encoding["input_ids"], + "attention_mask": source_encoding["attention_mask"], + } + + return encodings + + +class BertDataProcessor: + def __init__(self, tokenizer: BertTokenizerFast, max_source_length: int = 512): + self.tokenizer = tokenizer + self.max_source_length = max_source_length + + def process(self, dataset): + dataset = dataset.map(self._convert_to_features, batched=True) + + return dataset + + def _add_token_positions(self, encodings, answers): + start_positions = [] + end_positions = [] + for i in range(len(answers)): + start_positions.append(encodings.char_to_token(i, answers[i]["answer_start"])) + end_positions.append(encodings.char_to_token(i, answers[i]["answer_end"] - 1)) + + # if start position is None, the answer passage has been truncated + if start_positions[-1] is None: + start_positions[-1] = self.max_source_length + if end_positions[-1] is None: + end_positions[-1] = self.max_source_length + + encodings.update({"start_positions": start_positions, "end_positions": end_positions}) + + # tokenize the examples + def _convert_to_features(self, example_batch): + encodings = self.tokenizer( + example_batch["context"], + example_batch["question"], + max_length=self.max_source_length, + padding="max_length", + pad_to_max_length=True, + truncation=True, + ) + + self._add_token_positions(encodings, example_batch["answer"]) + + return encodings + + +def _read_datasets( + names: List[str], target_format="mt5", mt5_task_list: List[str] = ["ans_ext", "qa", "qg"], mt5_qg_format="highlight" +) -> Dataset: + """ + Args: + names: lisf of dataset subset names or paths + target_format (str): output format ('mt5' or 'bert') + mt5_task_list: list of tasks for mt5 data to be prepared + mt5_qg_format: "highlight", "prepend" or "both" + """ + data = [] + for name in names: + data.extend( + load_and_prepare_dataset( + name, target_format=target_format, mt5_task_list=mt5_task_list, mt5_qg_format=mt5_qg_format + ) + ) + data = Dataset.from_pandas(pd.DataFrame(data)) + return data + + +def main(args_file_path: str = None): + model_args, data_args, train_args = parser(args_file_path) + + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=os.environ.get("LOGLEVEL", "INFO").upper(), + ) + + # set datasets + train_dataset = _read_datasets( + names=train_args.train_dataset_list, + target_format=model_args.model_type, + mt5_task_list=train_args.mt5_task_list, + mt5_qg_format=train_args.mt5_qg_format, + ) + valid_dataset = _read_datasets( + names=train_args.valid_dataset_list, + target_format=model_args.model_type, + mt5_task_list=train_args.mt5_task_list, + mt5_qg_format=train_args.mt5_qg_format, + ) + + # set tokenizer + if model_args.model_type == "mt5": + model = MT5Model(model_args.model_name_or_path) + tokenizer = model.tokenizer + tokenizer.add_tokens(["", ""]) + elif model_args.model_type == "bert": + model = BertModel(model_args.model_name_or_path) + tokenizer = model.tokenizer + + # set processor + if model_args.model_type == "mt5": + processor = MT5DataProcessor( + tokenizer, max_source_length=train_args.max_source_length, max_target_length=train_args.max_target_length + ) + elif model_args.model_type == "bert": + processor = 
BertDataProcessor(tokenizer, max_source_length=train_args.max_source_length) + + # process datasets + train_dataset = processor.process(train_dataset) + valid_dataset = processor.process(valid_dataset) + + if model_args.model_type == "mt5": + columns = ["source_ids", "target_ids", "attention_mask"] + train_dataset.set_format(type="torch", columns=columns) + valid_dataset.set_format(type="torch", columns=columns) + elif model_args.model_type == "bert": + columns = ["start_positions", "end_positions", "input_ids", "attention_mask"] + train_dataset.set_format(type="torch") + valid_dataset.set_format(type="torch") + + # create train/valid file dirs + train_file_path = Path(str(data_args.train_file_path).strip()) + if not train_file_path.parent.exists(): + train_file_path.parent.mkdir(parents=True, exist_ok=True) + valid_file_path = Path(str(data_args.valid_file_path).strip()) + if not valid_file_path.parent.exists(): + valid_file_path.parent.mkdir(parents=True, exist_ok=True) + + # save train/valid files + torch.save(train_dataset, train_file_path) + logger.info(f"saved train dataset at {train_file_path}") + + torch.save(valid_dataset, valid_file_path) + logger.info(f"saved validation dataset at {valid_file_path}") + + # save tokenizer + tokenizer_path = model_args.tokenizer_path + if not os.path.exists(tokenizer_path): + os.mkdir(tokenizer_path) + tokenizer.save_pretrained(tokenizer_path) + logger.info(f"saved tokenizer at {tokenizer_path}") + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..3f398cf --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,17 @@ +[tool.black] +line-length = 120 +exclude = ''' +( + /( + .git + | .vscode + | .venv + | runs + | data + )/ +) +''' + +[tool.isort] +line_length = 120 +profile = "black" \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..90a5dd5 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,13 @@ +bert-score==0.3.10 +datasets>=1.12.0,<2.0.0 +gdown +jury>=2.1.0,<3.0.0 +trtokenizer==0.0.3 +protobuf>=3.17.3 +pysocks==1.5.6 +pyyaml +rouge-score==0.0.4 +sacrebleu==1.5.1 +sentencepiece==0.1.96 +torch==1.10.0 +transformers>=4.10.0,<5.0.0 diff --git a/run.py b/run.py new file mode 100644 index 0000000..29c8bfc --- /dev/null +++ b/run.py @@ -0,0 +1,230 @@ +import json +import logging +import os +from typing import Tuple + +import torch +import transformers +from transformers import Trainer as HFTrainer +from transformers import set_seed +from transformers.hf_argparser import DataClass +from transformers.optimization import Adafactor, AdamW +from transformers.trainer import Trainer + +from core.argument_parsers import parser +from core.collator import T2TDataCollator +from core.evaluate import evaluate_on_train_end +from hf.model import BertModel, MT5Model +from prepare_data import main as prepare_data +from utils.file import save_experiment_config +from utils.neptune import init_neptune, log_to_neptune +from utils.wandb import init_wandb, log_to_wandb + + +def setup_logger(args: DataClass) -> logging.Logger: + logger = logging.getLogger(__name__) + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=os.environ.get("LOGLEVEL", "INFO").upper() if args.local_rank in [-1, 0] else logging.WARN, + ) + logger.warning( + "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s", + args.local_rank, + args.device, + args.n_gpu, + 
bool(args.local_rank != -1), + args.fp16, + ) + logger.info("Training/evaluation parameters %s", args) + return logger + + +def check_output(args: DataClass, logger: logging.Logger = None) -> None: + if ( + os.path.exists(args.output_dir) + and os.listdir(args.output_dir) + and args.do_train + and not args.overwrite_output_dir + ): + raise ValueError( + f"Output directory ({args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome." + ) + + +def load_datasets(args: DataClass, train: bool, eval: bool, logger: logging.Logger) -> Tuple: + logger.info("loading dataset") + train_dataset = torch.load(args.train_file_path) if train else None + valid_dataset = torch.load(args.valid_file_path) if eval else None + logger.info("finished loading dataset") + + return train_dataset, valid_dataset + + +def main(args_file_path: str = None): + + model_args, data_args, training_args = parser(args_file_path) + + # check for output_dir with given arguments. + check_output(training_args) + + logger = setup_logger(training_args) + + # set seed + set_seed(training_args.seed) + + # initialize experiment tracking + report_to = [] + + if training_args.do_train: + wandb_status, wandb = init_wandb(project=model_args.wandb_project, name=training_args.run_name) + else: + wandb_status, wandb = init_wandb( + project=model_args.wandb_project, name=training_args.run_name, id=model_args.wandb_id + ) + neptune_status, neptune = init_neptune( + project=model_args.neptune_project, api_token=model_args.neptune_api_token, name=training_args.run_name + ) + + if wandb_status: + report_to.append("wandb") + if neptune_status: + report_to.append("neptune") + + training_args.report_to = report_to + + # disable wandb console logs + logging.getLogger("wandb.run_manager").setLevel(logging.WARNING) + + # prepare data() + if data_args.prepare_data: + prepare_data(args_file_path) + + # load model + if model_args.model_type == "mt5": + model = MT5Model( + model_name_or_path=model_args.model_name_or_path, + tokenizer_name_or_path=model_args.tokenizer_path, + freeze_embeddings=training_args.freeze_embeddings, + cache_dir=model_args.cache_dir, + use_cuda=True, + ) + elif model_args.model_type == "bert": + model = BertModel( + model_name_or_path=model_args.model_name_or_path, + tokenizer_name_or_path=model_args.tokenizer_path, + freeze_embeddings=training_args.freeze_embeddings, + cache_dir=model_args.cache_dir, + use_cuda=True, + ) + train_dataset, valid_dataset = load_datasets( + data_args, train=training_args.do_train, eval=training_args.do_eval, logger=logger + ) + + # set optimizer + if training_args.adafactor: + # as adviced in https://huggingface.co/transformers/main_classes/optimizer_schedules.html#adafactor-pytorch + optimizer = Adafactor( + model.model.parameters(), + scale_parameter=False, + relative_step=False, + warmup_init=False, + weight_decay=training_args.weight_decay, + lr=training_args.learning_rate, + ) + else: + optimizer = AdamW( + model.model.parameters(), weight_decay=training_args.weight_decay, lr=training_args.learning_rate + ) + + if model_args.model_type == "mt5": + # initialize data_collator + data_collator = T2TDataCollator( + tokenizer=model.tokenizer, mode="training", using_tpu=training_args.tpu_num_cores is not None + ) + + # fix https://discuss.huggingface.co/t/mt5-fine-tuning-keyerror-source-ids/5257/2 + training_args.remove_unused_columns = False if model_args.model_type == "mt5" else True + + # export experiment config + save_experiment_config(model_args, data_args, 
training_args) + + # start training + if training_args.do_train: + # init model + if model_args.model_type == "mt5": + trainer: Trainer = HFTrainer( + model=model.model, + args=training_args, + train_dataset=train_dataset, + eval_dataset=valid_dataset, + data_collator=data_collator, + optimizers=(optimizer, None), + ) + elif model_args.model_type == "bert": + trainer: Trainer = HFTrainer( + model=model.model, + args=training_args, + train_dataset=train_dataset, + eval_dataset=valid_dataset, + optimizers=(optimizer, None), + ) + + # perform training + trainer.train( + resume_from_checkpoint=model_args.model_name_or_path + if os.path.isdir(model_args.model_name_or_path) + else None + ) + trainer.save_model() + # For convenience, we also re-save the tokenizer to the same directory, + # so that you can share your model easily on huggingface.co/models =) + + model.tokenizer.save_pretrained(training_args.output_dir) + + # start evaluation + if training_args.do_eval and training_args.local_rank in [-1, 0]: + # arange neptune/wandb loggers + if training_args.do_train: + for callback in trainer.callback_handler.callbacks: + if isinstance(callback, transformers.integrations.WandbCallback): + wandb = callback._wandb + for callback in trainer.callback_handler.callbacks: + if isinstance(callback, transformers.integrations.NeptuneCallback): + neptune_run = callback._neptune_run + if not training_args.do_train: + if "neptune" in report_to: + neptune_run = neptune.init( + project=os.getenv("NEPTUNE_PROJECT"), + api_token=os.getenv("NEPTUNE_API_TOKEN"), + mode=os.getenv("NEPTUNE_CONNECTION_MODE", "async"), + name=os.getenv("NEPTUNE_RUN_NAME", None), + run=model_args.neptune_run, + ) + elif "wandb" in report_to: + wandb.init(project=model_args.wandb_project, name=model_args.run_name, id=model_args.wandb_id) + + # calculate evaluation results + overall_results = evaluate_on_train_end(model_args, training_args) + + # log to neptune/wandb + if "neptune" in report_to: + log_to_neptune(neptune_run, overall_results) + if "wandb" in report_to: + log_to_wandb(wandb, overall_results) + + +def _mp_fn(index): + # For xla_spawn (TPUs) + main() + + +def run_multi(args_dict): + with open("args.json", "w") as f: + json.dump(args_dict, f) + + main(args_file="args.json") + + +if __name__ == "__main__": + main() diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..a3efd0f --- /dev/null +++ b/setup.cfg @@ -0,0 +1,5 @@ +[flake8] +max-line-length = 120 +exclude =.git +ignore = I101,I201,F401,F403,S001,D100,D101,D102,D103,D104,D105,D106,D107,D200,D205,D400,W504,D202,E203,E722,W503,B006 +inline-quotes = " diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_config.yaml b/tests/test_config.yaml new file mode 100644 index 0000000..545a432 --- /dev/null +++ b/tests/test_config.yaml @@ -0,0 +1,40 @@ +model_name_or_path: "google/mt5-small" +tokenizer_path: "mt5_small_tokenizer" +label_smoothing_factor: 0 +freeze_embeddings: false +run_name: null +wandb_project: null +wandb_id: null +neptune_project: null +neptune_run: null +neptune_api_token: null +train_dataset_list: ["tquad.small"] +valid_dataset_list: ["tquad.small"] +eval_dataset_list: ["tquad.small"] +train_file_path: "data/train_data_multitask_mt5.pt" +valid_file_path: "data/valid_data_multitask_mt5.pt" +max_source_length: 512 +max_target_length: 64 +prepare_data: true +mt5_task_list: [ + "qa", + "qg", + "ans_ext" +] +mt5_qg_format: "highlight" +output_dir: "runs/exp1" +do_train: true +do_eval: 
true +evaluation_strategy: "steps" +eval_steps: 1 +eval_accumulation_steps: 1 +per_device_train_batch_size: 1 +per_device_eval_batch_size: 1 +gradient_accumulation_steps: 1 +learning_rate: 1.0e-4 +num_train_epochs: 1 +save_total_limit: 1 +no_cuda: true +seed: 42 +max_steps: 2 +overwrite_output_dir: true \ No newline at end of file diff --git a/tr_non_suffixes b/tr_non_suffixes new file mode 100644 index 0000000..a6bfd9e --- /dev/null +++ b/tr_non_suffixes @@ -0,0 +1,246 @@ +# This is a non-breaking prefix list for the Turkish language. +# The file is used for sentence tokenization (tr_tokenizer/tokenizer/SentenceTokenizer class). +# +#ASSUMPTION: +# +# Anything in this file, followed by a period (and an upper-case word), does NOT +# indicate an end-of-sentence marker. +# Special cases are included for prefixes that ONLY appear before 0-9 numbers. +# Any single upper case letter followed by a period is not a sentence ender +# Usually upper case letters are initials in a name. +A +B +C +D +E +F +G +H +I +J +K +L +M +N +O +P +Q +R +S +T +U +V +W +X +Y +Z +# Usually upper case letters are initials in a name (Turkish alphabet) +Ç +Ğ +İ +Ö +Ş +Ü +# Roman Numerals +I +II +III +IV +V +VI +VII +VIII +IX +X +XI +XII +XIII +XIV +XV +XVI +XVII +XVIII +XIX +XX + +# English -- but these work globally for all languages: +Mr +Mrs +No +pp +St +no +Sr +Jr +Bros +etc +vs +esp +Fig +fig +Jan +Feb +Mar +Apr +Jun +Jul +Aug +Sep +Sept +Oct +Okt +Nov +Dec +Ph.D +PhD +# in "et al." +al +cf +Inc +Ms +Gen +Sen +Prof +Dr +Corp +Co +# http://en.wiktionary.org/wiki/Category:Turkish_abbreviations +Av +no +Başk.yard +Bşk.yrd +Akad +Alb +Alm +anat +ant +Apt +Ar +Ar. Gör +ark +As +Asb +Asist +astr +astrol +Atğm +atm +Av +bağ +Bçvş +B.E +bitb +biy +bk +Bl +Bn +Bnb +bot +Böl +bs +Bşk +Bul +Bulg +C +Cad +coğ +çev +Çvş +D +dam +db +dbl +Doç +doğ +Dr +Dz. Kuv. K +dzş +e +Ecz +ed +ekon +Ens +Erm +F +f +Fak +Far +fel +fil +fiz +Fr +Gen +geom +gn +Gnkur +Gön +H.O +Hv. Kuv +Hv. Kuv. K +Hz +İbr +is +İsp +İt +Mah +Sok +LTD +ŞTİ +TİC +PAZ +SAN +CAD +SK +ORG +SAN +BEL +İST +JAN +EĞ +ÜNİV +TUG +SNR +M.E.B +KDZ +A.Ş +MD +GN +ŞB +BŞ +K.K +Yzb + +# Number indicators +# add #NUMERIC_ONLY# after the word if it should ONLY be non-breaking when a 0-9 digit follows it +hayır + +# Ordinals are (apparently) done with . in Turkish - "1." = "1st" in English +# Ordinals 1 to 10000 added to self.__non_breaking_suffixes, thus they were deleted from here +# Numbers that may be used in date were added +01 +02 +03 +04 +05 +06 +07 +08 +09 + +# (d. 998 - ö. 1068) +d +ö +(d +- ö +# Ömer b. Abdülazīz +b +# Ord. Prof. Dr. +Ord +# 14-17. Yüzyıllarda ?? +#(Çerkez, Abaza, vs.) Musluaman ?? +# Prof.Dr. Ernst +Prof.Dr +# Ss. Cyril and Methodius University +Ss +# Bulgar Bilimler Akademisi Başkanı Acad. Stefan Vodenicharov'dan aldı. 
+Acad \ No newline at end of file diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000..b2de5a9 --- /dev/null +++ b/utils/__init__.py @@ -0,0 +1,5 @@ +import os.path +from pathlib import Path + +PROJECT_ROOT = Path(os.path.dirname(__file__)).parent.resolve() +SOURCE_ROOT = PROJECT_ROOT diff --git a/utils/file.py b/utils/file.py new file mode 100644 index 0000000..f75e008 --- /dev/null +++ b/utils/file.py @@ -0,0 +1,141 @@ +import json +import os +import urllib.request +import zipfile +from dataclasses import asdict +from pathlib import Path +from typing import Dict + +import yaml + + +def safe_download(target_path: str, source_url: str, source_url2=None, min_bytes=1e0, error_msg="") -> None: + """Attempts to download file from source_url or source_url2, checks and removes incomplete downloads < min_bytes""" + file = Path(target_path) + assert_msg = f"Downloaded file '{file}' does not exist or size is < min_bytes={min_bytes}" + try: # url1 + print(f"Downloading {source_url} to {file}...") + urllib.request.urlretrieve( + source_url, + target_path, + ) + assert file.exists() and file.stat().st_size > min_bytes, assert_msg # check + except Exception as e: # url2 + file.unlink(missing_ok=True) # remove partial downloads + print(f"ERROR: {e}\nRe-attempting {source_url2 or source_url} to {file}...") + urllib.request.urlretrieve( + source_url2 or source_url, + target_path, + ) + finally: + if not file.exists() or file.stat().st_size < min_bytes: # check + file.unlink(missing_ok=True) # remove partial downloads + print(f"ERROR: {assert_msg}\n{error_msg}") + print("") + + +def attempt_download(source_url: str, target_path: str) -> None: + target_path = Path(str(target_path).strip().replace("'", "")) + if not target_path.exists(): + target_path.parent.mkdir(parents=True, exist_ok=True) + safe_download(target_path=str(target_path), source_url=source_url) + + +def load_json(load_path, object_hook=None): + """ + Loads json formatted data (given as "data") from load_path + Example inputs: + load_path: "dirname/squad.json" + """ + # read from path + with open(load_path, encoding="utf-8") as json_file: + data = json.load(json_file, object_hook=object_hook) + return data + + +def save_json(obj: Dict, path: str, encoding: str = "utf-8", indent: int = 4): + """ + Save dict as json file. + """ + with open(path, "w", encoding=encoding) as jf: + json.dump(obj, jf, indent=indent, default=str, ensure_ascii=False) + + +def read_yaml(yaml_path): + """ + Reads yaml file as dict. + """ + with open(yaml_path) as f: + yaml_data = yaml.load(f, Loader=yaml.FullLoader) + + return yaml_data + + +def save_yaml(dict_file, yaml_path): + """ + Saves dict as yaml file. + """ + + Path(yaml_path).parent.mkdir(parents=True, exist_ok=True) + + with open(yaml_path, "w") as file: + yaml.dump(dict_file, file) + + +def unzip(file_path: str, dest_dir: str): + """ + Unzips compressed .zip file. 
+ Example inputs: + file_path: 'data/01_alb_id.zip' + dest_dir: 'data/' + """ + + # unzip file + with zipfile.ZipFile(file_path) as zf: + zf.extractall(dest_dir) + + +def download_from_url(from_url: str, to_path: str): + + Path(to_path).parent.mkdir(parents=True, exist_ok=True) + + if not os.path.exists(to_path): + urllib.request.urlretrieve( + from_url, + to_path, + ) + + +def save_experiment_config(model_args, data_args, training_args): + experiment_config = {} + experiment_config.update(asdict(model_args)) + experiment_config.update(asdict(data_args)) + experiment_config.update(asdict(training_args)) + yaml_path = Path(training_args.output_dir) / "experiment_config.yaml" + save_yaml(experiment_config, yaml_path) + + +def download_from_gdrive(url: str, save_dir: str) -> str: + """ + Downloads file from gdrive, shows progress. + Example inputs: + url: 'https://drive.google.com/uc?id=10hHFuavHCofDczGSzsH1xPHgTgAocOl1' + save_dir: 'data/' + """ + import gdown + + # create save_dir if not present + Path(save_dir).mkdir(parents=True, exist_ok=True) + # download file + filepath = gdown.download(url, save_dir, quiet=False) + return filepath + + +def download_from_gdrive_and_unzip(url: str, save_dir: str) -> str: + save_dir = save_dir + os.path.sep + # download zip file + filepath = download_from_gdrive(url, save_dir) + # extract zip file + unzip(filepath, str(Path(filepath).parent)) + # remove zip file + os.remove(filepath) diff --git a/utils/neptune.py b/utils/neptune.py new file mode 100644 index 0000000..0f092ea --- /dev/null +++ b/utils/neptune.py @@ -0,0 +1,27 @@ +import os + + +def init_neptune(project: str = None, api_token: str = None, name: str = None): + status = False + neptune = None + if project is not None and api_token is not None: + try: + import neptune.new as neptune + + os.environ["NEPTUNE_PROJECT"] = project + os.environ["NEPTUNE_API_TOKEN"] = api_token + if name is not None: + os.environ["NEPTUNE_RUN_NAME"] = name + + status = True + except ImportError: + print("neptune not installed, skipping neptune logging. 
'pip install neptune-client' for neptune logging.") + except Exception as e: + print(e) + + return status, neptune + + +def log_to_neptune(neptune_run, dict): + for k, v in dict.items(): + neptune_run[k].log(v) diff --git a/utils/nlp.py b/utils/nlp.py new file mode 100644 index 0000000..7c0079d --- /dev/null +++ b/utils/nlp.py @@ -0,0 +1,481 @@ +import itertools +import logging +import os +import re +import string +from typing import Dict, List, Union + +from datasets import Dataset +from trtokenizer.tr_tokenizer import SentenceTokenizer + +from utils import SOURCE_ROOT +from utils.file import attempt_download, load_json + +logger = logging.getLogger(__name__) +logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=os.environ.get("LOGLEVEL", "INFO").upper(), +) + + +def sentence_tokenize(context: str) -> List[str]: + non_prefix_file_path = str(SOURCE_ROOT / "tr_non_suffixes") + sentence_tokenizer = SentenceTokenizer(non_breaking_prefix_file=non_prefix_file_path) + context = context.replace("\xa0", " ").replace("\ufeff", " ").replace("\t", " ") + + sentence_list = [] + for trtok_sentence in sentence_tokenizer.tokenize(context): + + pattern = re.escape(trtok_sentence) + pattern = " +".join(pattern.split()) + pattern += " {0,1}" # handle space between sentences + pattern = r"%s" % pattern + pattern += "\r{0,1}" # handle \r between sentences + pattern = r"%s" % pattern + pattern += "\n{0,1}" # handle \n between sentences + pattern = r"%s" % pattern + match_str = re.search(pattern, context) + start_idx, end_idx = match_str.span() + sentence = context[start_idx:end_idx] + sentence_list.append(sentence) + return sentence_list + + +def _get_correct_alignement(context, answer): + """Some original examples in SQuAD have indices wrong by 1 or 2 character. We test and fix this here.""" + gold_text = answer["text"] + start_idx = answer["answer_start"] + end_idx = start_idx + len(gold_text) + if context[start_idx:end_idx] == gold_text: + return start_idx, end_idx # When the gold label position is good + elif context[start_idx - 1 : end_idx - 1] == gold_text: + return start_idx - 1, end_idx - 1 # When the gold label is off by one character + elif context[start_idx - 2 : end_idx - 2] == gold_text: + return start_idx - 2, end_idx - 2 # When the gold label is off by two character + elif context[start_idx + 1 : end_idx + 1] == gold_text: + return start_idx + 1, end_idx + 1 # When the gold label is off by one character + elif context[start_idx + 2 : end_idx + 2] == gold_text: + return start_idx + 2, end_idx + 2 # When the gold label is off by two character + else: + raise ValueError() + + +def get_answer(qa: Dict): + """qa: each element of 'qas' field of squad formatted dataset paragraph""" + if qa.get("answer") is not None: + return qa["answer"] + try: + answer = qa["answers"][0] + except IndexError: + answer = None + return answer + + +def get_answer_indices_from_sentence(answer_text: str, sentence: str, loose_match: bool = False): + """ + first match is returned in match case + + Args: + answer_text (str) + sentence (str) + loose_match (bool) : if True, regex also matches substrings + + Returns: + None, if no match + {"text": answer_text, "answer_start": ans_start_idx, "answer_end": ans_end_idx}, if match + answer_start and answer_end are sentence-wise indexes, not context-wise + """ + # sometimes extracted answers does not match with the original text :/ + try: + pattern = r"(? 
".join(answer_list) + " " + else: + target_text = None + + for sentence_ind2, sentence in enumerate(sentence_list): + if sentence_ind == sentence_ind2: + sentence = f" {sentence} " + source_text = f"{source_text} {sentence}" + source_text = source_text.strip() + + sample = {"source_text": source_text, "target_text": target_text, "answer_list": answer_list} + if sample["target_text"] is None: + sample + samples.append(sample) + + return samples + + +def prepare_qa_sample(context: str, question: str, answer: str = None): + """ + Args: + context (str) (assumed to be normalized via normalize_text) + question (str) + answer (str) + """ + prepare_target = True if answer else False + + source_text = f"question: {question} context: {context}" + if prepare_target: + target_text = f"{answer}" + else: + target_text = None + return {"source_text": source_text, "target_text": target_text} + + +def prepare_qg_samples(context, answer_list: List[Dict], question_list: List[str] = None, qg_format: str = "highlight"): + """ + Args: + context (str) + question_list (List[str]) + answer_list: [ + {'text': str, 'answer_start': int}, + {'text': str, 'answer_start': int}, + ... + ] + qg_format: 'highlight', 'prepend' or 'both' + """ + # split into sentences + + try: + samples = [] + for ind, answer in enumerate(answer_list): + start_pos, end_pos = _get_correct_alignement(context, answer) + answer_text = answer["text"] + + if qg_format == "prepend": + source_text = f"answer: {answer_text} context: {context}" + elif qg_format == "highlight": + source_text = f"generate question: {context[:start_pos]} {answer_text} {context[end_pos:]}" + elif qg_format == "both": + source_text = ( + f"answer: {answer_text} context: {context[:start_pos]} {answer_text} {context[end_pos:]}" + ) + else: + raise ValueError(f"unsupported qg format: {qg_format}") + + if question_list: + question = question_list[ind] + else: + question = None + + samples.append({"answer": answer_text, "source_text": source_text, "target_text": question}) + + except ValueError: + sentence_list = sentence_tokenize(normalize_text(context)) + + if question_list: + answer_list_per_sentence, question_list_per_sentence = get_answer_list_per_sentence( + sentence_list, answer_list, question_list + ) + else: + answer_list_per_sentence = get_answer_list_per_sentence(sentence_list, answer_list) + + samples = [] + for sentence_ind, answer_list in enumerate(answer_list_per_sentence): + if not answer_list: + continue + for answer_ind, answer in enumerate(answer_list): + sentence = sentence_list[sentence_ind] + sentence_list_copy = sentence_list[:] + + answer_start = answer["answer_start"] + answer_text = answer["text"] + answer_end = answer["answer_end"] + + sentence = f"{sentence[:answer_start]} {answer_text} {sentence[answer_end:]}" + sentence_list_copy[sentence_ind] = sentence + highlighted_context = " ".join(sentence_list_copy) + + if qg_format == "prepend": + source_text = f"answer: {answer_text} context: {context}" + elif qg_format == "highlight": + source_text = f"generate question: {highlighted_context}" + elif qg_format == "both": + source_text = f"answer: {answer_text} context: {highlighted_context}" + else: + raise ValueError(f"unsupported qg format: {qg_format}") + + if question_list: + question_list = question_list_per_sentence[sentence_ind] + question = question_list[answer_ind] + else: + question = None + + samples.append({"answer": answer_text, "source_text": source_text, "target_text": question}) + + if not samples: + samples.append({"answer": None, 
"source_text": None, "target_text": None}) + + return samples + + +def postprocess_answer_extraction_output(answer_extraction_output: str): + """ + Args: + answer_extraction_output (str): decoded answer extraction output + + Returns: + answer_text_list (List[str]) + """ + # parse answers + answers = answer_extraction_output.split("")[:-1] + # normalize and append answers + answer_text_list = [] + for answer_text in answers: + # append if not present + if answer_text and (answer_text not in answer_text_list): + answer_text_list.append(answer_text) + return answer_text_list + + +def download_and_load_dataset(source_url: str, target_path: str) -> Dataset: + attempt_download(source_url=source_url, target_path=target_path) + data = load_json(target_path) + return data + + +def remove_punctuations(text: str) -> str: + regex = re.compile("[%s]" % re.escape(string.punctuation)) + text = regex.sub(" ", text) + return " ".join(text.split()) + + +def replace_multiple_spaces_with_single_whitespace(text: str) -> str: + return re.sub("\\s+", " ", text) + + +def remove_citations(text: str) -> str: + """ + Removes the citations that consist of a pair of brackets having a substring + containing at least one digit inside them. + Args: + text (str): + + Returns: + + """ + text = re.sub("\[[a-zA-Z]\]", "", text) + return re.sub(r"\[(\s|\w)*\d+(\s|\w)*\]", "", text) + + +def handle_ugly_case(text: str) -> str: + pattern = r"(?<=\d.)(/)(?=\d.)" + return re.sub(pattern, "-", text) + + +def normalize_text(text: str) -> str: + text = text.strip() + text = replace_multiple_spaces_with_single_whitespace(text) + text = remove_citations(text) + text = handle_ugly_case(text) + return text diff --git a/utils/torch.py b/utils/torch.py new file mode 100644 index 0000000..dfddc65 --- /dev/null +++ b/utils/torch.py @@ -0,0 +1,56 @@ +from typing import Iterable, List + +from torch import nn + +# taken from https://github.com/patil-suraj/question_generation/blob/master/utils.py + + +def grad_status(model: nn.Module) -> Iterable: + return (par.requires_grad for par in model.parameters()) + + +def freeze_params(model: nn.Module): + for par in model.parameters(): + par.requires_grad = False + + +def freeze_embeds(model: nn.Module): + """Freeze token embeddings and positional embeddings for bart, just token embeddings for t5.""" + try: + freeze_params(model.model.shared) + for d in [model.model.encoder, model.model.decoder]: + freeze_params(d.embed_positions) + freeze_params(d.embed_tokens) + except AttributeError: + freeze_params(model.shared) + for d in [model.encoder, model.decoder]: + freeze_params(d.embed_tokens) + + +def assert_not_all_frozen(model): + model_grads: List[bool] = list(grad_status(model)) + npars = len(model_grads) + assert any(model_grads), f"none of {npars} weights require grad" + + +def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100): + """From fairseq""" + if target.dim() == lprobs.dim() - 1: + target = target.unsqueeze(-1) + nll_loss = -lprobs.gather(dim=-1, index=target) + smooth_loss = -lprobs.sum(dim=-1, keepdim=True) + if ignore_index is not None: + pad_mask = target.eq(ignore_index) + nll_loss.masked_fill_(pad_mask, 0.0) + smooth_loss.masked_fill_(pad_mask, 0.0) + bs = pad_mask.long().sum() + else: + nll_loss = nll_loss.squeeze(-1) + smooth_loss = smooth_loss.squeeze(-1) + bs = lprobs.shape[0] + + nll_loss = nll_loss.sum() # mean()? Scared to break other math. 
+ smooth_loss = smooth_loss.sum() + eps_i = epsilon / lprobs.size(-1) + loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss + return loss / bs, nll_loss / bs diff --git a/utils/wandb.py b/utils/wandb.py new file mode 100644 index 0000000..9a211bd --- /dev/null +++ b/utils/wandb.py @@ -0,0 +1,22 @@ +def init_wandb(project: str = None, name: str = None, id: str = None): + status = False + wandb = None + if project is not None and name is not None: + try: + import wandb + + if not id: + wandb.init(project=project, name=name) + else: + wandb.init(project=project, name=name, id=id) + status = True + except ImportError: + print("wandb not installed, skipping wandb logging. 'pip install wandb' for wandb logging.") + except Exception as e: + print(e) + + return status, wandb + + +def log_to_wandb(wandb, dict): + wandb.log({**dict})
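+
+
+# Illustrative usage sketch (not part of the original module); the project/run
+# names and the logged value below are placeholders:
+#
+#   status, wandb = init_wandb(project="turque", name="exp1")
+#   if status:
+#       log_to_wandb(wandb, {"eval_bleu_1": 0.42})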