From 8cadd5df0386a298687f7b861f174c11bf4aa4bd Mon Sep 17 00:00:00 2001
From: fatih <34196005+fcakyon@users.noreply.github.com>
Date: Wed, 10 Nov 2021 21:07:37 +0300
Subject: [PATCH] initial upload (#1)
* initial commit
* read logging level from env
* fix comment
* update setup.cfg
* update environment yml
* rename train to run
* relocate modules
* add init file
* update readme
---
.github/workflows/ci.yml | 85 ++++
.gitignore | 140 +++++
LICENSE | 21 +
README.md | 242 ++++++++-
configs/default/config.json | 48 ++
configs/default/config.yaml | 35 ++
configs/paper/berturk-qatask-tquad2.yaml | 36 ++
configs/paper/mt5base-3task-both-tquad2.yaml | 35 ++
.../paper/mt5base-3task-highlight-tquad2.yaml | 35 ++
.../paper/mt5base-3task-prepend-tquad2.yaml | 35 ++
configs/paper/mt5base-qatask-tquad2.yaml | 35 ++
configs/paper/mt5base-qgtask-both-tquad2.yaml | 35 ++
.../mt5base-qgtask-highlight-tquad2.yaml | 35 ++
.../paper/mt5base-qgtask-prepend-tquad2.yaml | 35 ++
configs/paper/mt5small-3task-both-tquad2.yaml | 35 ++
core/__init__.py | 0
core/api.py | 297 +++++++++++
core/argument_parsers.py | 162 ++++++
core/bert_api.py | 129 +++++
core/collator.py | 81 +++
core/dataset_parsers.py | 195 +++++++
core/evaluate.py | 289 +++++++++++
core/generate.py | 163 ++++++
core/pipelines.py | 242 +++++++++
environment.yml | 26 +
hf/__init__.py | 0
hf/model.py | 106 ++++
prepare_data.py | 193 +++++++
pyproject.toml | 17 +
requirements.txt | 13 +
run.py | 230 +++++++++
setup.cfg | 5 +
tests/__init__.py | 0
tests/test_config.yaml | 40 ++
tr_non_suffixes | 246 +++++++++
utils/__init__.py | 5 +
utils/file.py | 141 +++++
utils/neptune.py | 27 +
utils/nlp.py | 481 ++++++++++++++++++
utils/torch.py | 56 ++
utils/wandb.py | 22 +
41 files changed, 4051 insertions(+), 2 deletions(-)
create mode 100644 .github/workflows/ci.yml
create mode 100644 .gitignore
create mode 100644 LICENSE
create mode 100644 configs/default/config.json
create mode 100644 configs/default/config.yaml
create mode 100644 configs/paper/berturk-qatask-tquad2.yaml
create mode 100644 configs/paper/mt5base-3task-both-tquad2.yaml
create mode 100644 configs/paper/mt5base-3task-highlight-tquad2.yaml
create mode 100644 configs/paper/mt5base-3task-prepend-tquad2.yaml
create mode 100644 configs/paper/mt5base-qatask-tquad2.yaml
create mode 100644 configs/paper/mt5base-qgtask-both-tquad2.yaml
create mode 100644 configs/paper/mt5base-qgtask-highlight-tquad2.yaml
create mode 100644 configs/paper/mt5base-qgtask-prepend-tquad2.yaml
create mode 100644 configs/paper/mt5small-3task-both-tquad2.yaml
create mode 100644 core/__init__.py
create mode 100644 core/api.py
create mode 100644 core/argument_parsers.py
create mode 100644 core/bert_api.py
create mode 100644 core/collator.py
create mode 100644 core/dataset_parsers.py
create mode 100644 core/evaluate.py
create mode 100644 core/generate.py
create mode 100644 core/pipelines.py
create mode 100644 environment.yml
create mode 100644 hf/__init__.py
create mode 100644 hf/model.py
create mode 100644 prepare_data.py
create mode 100644 pyproject.toml
create mode 100644 requirements.txt
create mode 100644 run.py
create mode 100644 setup.cfg
create mode 100644 tests/__init__.py
create mode 100644 tests/test_config.yaml
create mode 100644 tr_non_suffixes
create mode 100644 utils/__init__.py
create mode 100644 utils/file.py
create mode 100644 utils/neptune.py
create mode 100644 utils/nlp.py
create mode 100644 utils/torch.py
create mode 100644 utils/wandb.py
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
new file mode 100644
index 0000000..e7ed3be
--- /dev/null
+++ b/.github/workflows/ci.yml
@@ -0,0 +1,85 @@
+name: CI
+
+on:
+ push:
+ branches: [main]
+ pull_request:
+ branches: [main]
+
+jobs:
+ build:
+ runs-on: ${{ matrix.operating-system }}
+ strategy:
+ matrix:
+ operating-system: [ubuntu-latest]
+ python-version: [3.8, 3.9]
+ fail-fast: false
+
+ steps:
+ - name: Checkout
+ uses: actions/checkout@v2
+
+ - name: Set up Python
+ uses: actions/setup-python@v2
+ with:
+ python-version: ${{ matrix.python-version }}
+
+ - name: Restore Ubuntu cache
+ uses: actions/cache@v1
+ if: matrix.operating-system == 'ubuntu-latest'
+ with:
+ path: ~/.cache/pip
+ key: ${{ matrix.operating-system }}-${{ matrix.python-version }}-${{ hashFiles('**/setup.py') }}
+ restore-keys: ${{ matrix.operating-system }}-${{ matrix.python-version }}-
+
+ - name: Restore MacOS cache
+ uses: actions/cache@v1
+ if: matrix.operating-system == 'macos-latest'
+ with:
+ path: ~/Library/Caches/pip
+ key: ${{ matrix.operating-system }}-${{ matrix.python-version }}-${{ hashFiles('**/setup.py') }}
+ restore-keys: ${{ matrix.operating-system }}-${{ matrix.python-version }}-
+
+ - name: Restore Windows cache
+ uses: actions/cache@v1
+ if: matrix.operating-system == 'windows-latest'
+ with:
+ path: ~\AppData\Local\pip\Cache
+ key: ${{ matrix.operating-system }}-${{ matrix.python-version }}-${{ hashFiles('**/setup.py') }}
+ restore-keys: ${{ matrix.operating-system }}-${{ matrix.python-version }}-
+
+ - name: Update pip
+ run: python -m pip install --upgrade pip
+
+ - name: Install PyTorch on Linux and Windows
+ if: >
+ matrix.operating-system == 'ubuntu-latest' ||
+ matrix.operating-system == 'windows-latest'
+ run: >
+ pip install torch==1.10.0+cpu
+ -f https://download.pytorch.org/whl/torch_stable.html
+
+ - name: Install PyTorch on MacOS
+ if: matrix.operating-system == 'macos-latest'
+ run: pip install torch==1.10.0
+
+ - name: Install requirements
+ run: >
+ pip install -r requirements.txt
+
+ - name: Lint with flake8, black and isort
+ run: |
+ pip install "black==21.7b0" "flake8==3.9.2" "isort==5.9.2"
+ # stop the build if there are Python syntax errors or undefined names
+ flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
+ black . --check --config pyproject.toml
+ isort -c .
+ # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
+ flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
+
+ - name: Test prepare_data and train
+ run: |
+ # prepare_data
+ python prepare_data.py tests/test_config.yaml
+ # train
+ python run.py tests/test_config.yaml
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..0631120
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,140 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+pip-wheel-metadata/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+.python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# other
+.idea/
+.vscode
+*.pt
+data/
+mt5_qg_tokenizer/
+runs/
+*tokenizer/
+wandb
+.neptune
\ No newline at end of file
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..6fb8020
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2021 obss
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/README.md b/README.md
index 49ed13b..9eb07a4 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,240 @@
-# turkish-question-generation
-Automated question generation and question answering from Turkish texts using text-to-text transformers
+
+
+# Turkish Question Generation
+
+
+
+ Official implementation of "Automated question generation and question answering from Turkish texts using text-to-text transformers".
+
+
+
+
+
+
+## install
+
+
+```bash
+git clone https://github.com/obss/turkish-question-generation.git
+cd turkish-question-generation
+pip install -r requirements.txt
+```
+
+
+
+
+## train
+
+
+- start training using CLI args:
+
+```bash
+python run.py --model_name_or_path google/mt5-small --output_dir runs/exp1 --do_train --do_eval --tokenizer_name_or_path mt5_qg_tokenizer --per_device_train_batch_size 4 --gradient_accumulation_steps 2 --learning_rate 1e-4 --seed 42 --save_total_limit 1
+```
+
+- download the [json config](configs/default/config.json) file and start training:
+
+```bash
+python run.py config.json
+```
+
+- download the [yaml config](configs/default/config.yaml) file and start training:
+
+```bash
+python run.py config.yaml
+```
+
+
+
+
+
+## evaluate
+
+
+- set the related params in the config:
+
+```yaml
+do_train: false
+do_eval: true
+eval_dataset_list: ["tquad2-valid", "xquad.tr"]
+prepare_data: true
+mt5_task_list: ["qa", "qg", "ans_ext"]
+mt5_qg_format: "both"
+no_cuda: false
+```
+
+- start an evaluation:
+
+```bash
+python run.py config.yaml
+```
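+
+- the same pipeline can also be driven from Python; a minimal sketch using the file-level helpers in `core/api.py` (the `data/xquad/xquad.tr.json` path below is only an assumed example of a squad-formatted file, and it assumes the default `turque-s1` checkpoint can be resolved):
+
+```python
+from core.api import TurQue
+
+# sketch only: run question answering / question generation over a whole
+# squad-formatted json file and collect the predictions
+turque = TurQue(use_cuda=False)
+qa_predictions = turque.qa_from_file("data/xquad/xquad.tr.json")
+qg_predictions = turque.qg_from_file("data/xquad/xquad.tr.json", use_answers=True)
+```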
+
+
+
+
+
+## neptune
+
+
+- install neptune:
+
+```bash
+pip install neptune-client
+```
+
+- download the [config](configs/default/config.yaml) file and set the neptune params:
+
+```yaml
+run_name: 'exp1'
+neptune_project: 'name/project'
+neptune_api_token: 'YOUR_API_TOKEN'
+```
+
+- start training:
+
+```bash
+python run.py config.yaml
+```
+
+
+
+
+
+## wandb
+
+
+- install wandb:
+
+```bash
+pip install wandb
+```
+
+- download the [config](configs/default/config.yaml) file and set the wandb params:
+
+```yaml
+run_name: 'exp1'
+wandb_project: 'turque'
+```
+
+- start training:
+
+```bash
+python run.py config.yaml
+```
+
+
+
+
+
+## finetuned checkpoints
+
+
+[model_url1]: https://drive.google.com/file/d/10hHFuavHCofDczGSzsH1xPHgTgAocOl1/view?usp=sharing
+[model_url2]: https://huggingface.co/google/mt5-small
+[data_url1]: https://github.com/okanvk/Turkish-Reading-Comprehension-Question-Answering-Dataset/blob/master/data/2018%20%2B%202020%20veri%20k%C3%BCmesi/final_train_data_v2.json
+[data_url2]: https://github.com/okanvk/Turkish-Reading-Comprehension-Question-Answering-Dataset/blob/master/data/2018%20%2B%202020%20veri%20k%C3%BCmesi/final_dev_data_v2.json
+[data_url3]: https://github.com/deepmind/xquad/blob/master/xquad.tr.json
+
+
+|Name |Model |train data |params (M) |model size (GB) |
+|--- |--- |--- |--- |--- |
+|[turque-s1][model_url1] |[mt5-small][model_url2] |[tquad-train][data_url1]+[tquad-val][data_url2]+[xquad.tr][data_url3] |300M |1.2GB |
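+
+- the finetuned checkpoint can be used through the `TurQue` pipeline in `core/api.py`; a minimal sketch, assuming the `turque-s1` identifier above resolves to a downloaded checkpoint:
+
+```python
+from core.api import TurQue
+
+# sketch only: question generation from a single context with the finetuned mt5-small model
+turque = TurQue(model_url_or_path="turque-s1", use_cuda=False)
+context = "Osman Bey 1258 yılında Söğüt’te doğdu."
+output = turque(task="question-generation", context=context)
+print(output)  # e.g. [{"answer": "...", "question": "..."}]
+```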
+
+
+
+
+
+## format
+
+
+- answer extraction:
+
+input:
+```
+" Osman Bey 1258 yılında Söğüt’te doğdu. Osman Bey 1 Ağustos 1326’da Bursa’da hayatını kaybetmiştir.1281 yılında Osman Bey 23 yaşında iken Ahi teşkilatından olan Şeyh Edebali’nin kızı Malhun Hatun ile evlendi."
+```
+
+target:
+```
+ 1258 Söğüt’te
+```
+
+- question answering:
+
+input:
+```
+"question: Osman Bey nerede doğmuştur? context: Osman Bey 1258 yılında Söğüt’te doğdu. Osman Bey 1 Ağustos 1326’da Bursa’da hayatını kaybetmiştir.1281 yılında Osman Bey 23 yaşında iken Ahi teşkilatından olan Şeyh Edebali’nin kızı Malhun Hatun ile evlendi."
+```
+
+target:
+```
+"Söğüt’te"
+```
+
+- question generation (prepend):
+
+input:
+```
+"answer: Söğüt’te context: Osman Bey 1258 yılında Söğüt’te doğdu. Osman Bey 1 Ağustos 1326’da Bursa’da hayatını kaybetmiştir.1281 yılında Osman Bey 23 yaşında iken Ahi teşkilatından olan Şeyh Edebali’nin kızı Malhun Hatun ile evlendi."
+```
+
+target:
+```
+"Osman Bey nerede doğmuştur?"
+```
+
+- question generation (highlight):
+
+input:
+```
+"generate question: Osman Bey 1258 yılında Söğüt’te doğdu. Osman Bey 1 Ağustos 1326’da Bursa’da hayatını kaybetmiştir.1281 yılında Osman Bey 23 yaşında iken Ahi teşkilatından olan Şeyh Edebali’nin kızı Malhun Hatun ile evlendi."
+```
+
+target:
+```
+"Osman Bey nerede doğmuştur?"
+```
+
+- question generation (both):
+
+input:
+```
+"answer: Söğüt’te context: Osman Bey 1258 yılında Söğüt’te doğdu. Osman Bey 1 Ağustos 1326’da Bursa’da hayatını kaybetmiştir.1281 yılında Osman Bey 23 yaşında iken Ahi teşkilatından olan Şeyh Edebali’nin kızı Malhun Hatun ile evlendi."
+```
+
+target:
+```
+"Osman Bey nerede doğmuştur?"
+```
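+
+- such source/target pairs can be generated programmatically with the dataset helpers; a minimal sketch, assuming the dataset names supported by `core/dataset_parsers.py`:
+
+```python
+from core.dataset_parsers import load_and_prepare_dataset
+
+# sketch only: build mt5 samples in the formats shown above
+samples = load_and_prepare_dataset(
+    "tquad2-valid",
+    target_format="mt5",
+    mt5_task_list=["qa", "qg", "ans_ext"],
+    mt5_qg_format="both",
+)
+print(samples[0]["source_text"], samples[0]["target_text"])
+```
+
+The `mt5_qg_format` argument selects between the prepend, highlight and both question-generation styles described above.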
+
+
+
+
+## paper configs
+
+
+You can find the config files used in the paper under [configs/paper](configs/paper).
+
+
+
+
+
+## contributing
+
+
+Before opening a PR:
+
+- Install required development packages:
+
+```bash
+pip install "black==21.7b0" "flake8==3.9.2" "isort==5.9.2"
+```
+
+- Reformat with black and isort:
+
+```bash
+black . --config pyproject.toml
+isort .
+```
+
+
diff --git a/configs/default/config.json b/configs/default/config.json
new file mode 100644
index 0000000..b367344
--- /dev/null
+++ b/configs/default/config.json
@@ -0,0 +1,48 @@
+{
+ "model_name_or_path": "google/mt5-small",
+ "tokenizer_path": "mt5_small_tokenizer",
+ "label_smoothing_factor": 0,
+ "freeze_embeddings": false,
+ "run_name": "exp1",
+ "wandb_project": null,
+ "neptune_project": null,
+ "neptune_api_token": null,
+ "train_dataset_list": [
+ "tquad2-train"
+ ],
+ "valid_dataset_list": [
+ "tquad2-valid"
+ ],
+ "eval_dataset_list": [
+ "tquad2-valid",
+ "xquad.tr"
+ ],
+ "train_file_path": "data/train_data_multitask_mt5.pt",
+ "valid_file_path": "data/valid_data_multitask_mt5.pt",
+ "max_source_length": 512,
+ "max_target_length": 80,
+ "prepare_data": true,
+ "mt5_task_list": [
+ "qa",
+ "qg",
+ "ans_ext"
+ ],
+ "mt5_qg_format": "highlight",
+ "output_dir": "runs/exp1",
+ "do_train": true,
+ "do_eval": true,
+ "evaluation_strategy": "steps",
+ "eval_steps": 2000,
+ "per_device_train_batch_size": 4,
+ "per_device_eval_batch_size": 4,
+ "gradient_accumulation_steps": 1,
+ "eval_accumulation_steps": 1,
+ "learning_rate": 1e-4,
+ "num_train_epochs": 10,
+ "save_total_limit": 1,
+ "no_cuda": false,
+ "seed": 42,
+ "fp16": false,
+ "fp16_full_eval": false,
+ "adafactor": true
+}
\ No newline at end of file
diff --git a/configs/default/config.yaml b/configs/default/config.yaml
new file mode 100644
index 0000000..c125fa8
--- /dev/null
+++ b/configs/default/config.yaml
@@ -0,0 +1,35 @@
+model_name_or_path: "google/mt5-small"
+tokenizer_path: "mt5_small_tokenizer"
+label_smoothing_factor: 0
+freeze_embeddings: false
+run_name: "exp1"
+wandb_project: null
+neptune_project: null
+neptune_api_token: null
+train_dataset_list: ["tquad2-train"]
+valid_dataset_list: ["tquad2-valid"]
+eval_dataset_list: ["tquad2-valid", "xquad.tr"]
+train_file_path: "data/train_data_multitask_mt5.pt"
+valid_file_path: "data/valid_data_multitask_mt5.pt"
+max_source_length: 512
+max_target_length: 80
+prepare_data: true
+mt5_task_list: ["qa", "qg", "ans_ext"]
+mt5_qg_format: "highlight"
+output_dir: "runs/exp1"
+do_train: true
+do_eval: true
+evaluation_strategy: "steps"
+eval_steps: 2000
+per_device_train_batch_size: 4
+per_device_eval_batch_size: 4
+gradient_accumulation_steps: 1
+eval_accumulation_steps: 1
+learning_rate: 1.0e-4
+num_train_epochs: 10
+save_total_limit: 1
+no_cuda: false
+seed: 42
+fp16: false
+fp16_full_eval: false
+adafactor: true
\ No newline at end of file
diff --git a/configs/paper/berturk-qatask-tquad2.yaml b/configs/paper/berturk-qatask-tquad2.yaml
new file mode 100644
index 0000000..dc5a610
--- /dev/null
+++ b/configs/paper/berturk-qatask-tquad2.yaml
@@ -0,0 +1,36 @@
+model_name_or_path: "dbmdz/bert-base-turkish-cased"
+tokenizer_path: "tokenizers/bert-base-turkish-cased_tokenizer"
+label_smoothing_factor: 0
+freeze_embeddings: false
+run_name: "turque-bertbase-adamw-1e4-3ep-tquad2train"
+wandb_project: null
+wandb_id: null
+neptune_project: null
+neptune_run: null
+neptune_api_token:
+train_dataset_list: ["tquad2-train"]
+valid_dataset_list: ["tquad2-valid"]
+eval_dataset_list: ["tquad2-valid", "xquad.tr"]
+train_file_path: "data/train_data.pt"
+valid_file_path: "data/valid_data.pt"
+max_source_length: 512
+max_target_length: 64
+prepare_data: true
+output_dir: "runs/bert-base/turque-bertbase-adamw-1e4-3ep-tquad2train"
+do_train: true
+do_eval: true
+evaluation_strategy: "steps"
+eval_steps: 300
+logging_steps: 300
+per_device_train_batch_size: 32
+per_device_eval_batch_size: 32
+gradient_accumulation_steps: 1
+eval_accumulation_steps: 8
+learning_rate: 1.0e-4
+num_train_epochs: 3
+save_total_limit: 1
+no_cuda: false
+seed: 42
+fp16: false
+fp16_full_eval: false
+adafactor: false
\ No newline at end of file
diff --git a/configs/paper/mt5base-3task-both-tquad2.yaml b/configs/paper/mt5base-3task-both-tquad2.yaml
new file mode 100644
index 0000000..4963108
--- /dev/null
+++ b/configs/paper/mt5base-3task-both-tquad2.yaml
@@ -0,0 +1,35 @@
+model_name_or_path: "google/mt5-base"
+tokenizer_path: "tokenizers/mt5-base"
+label_smoothing_factor: 0
+freeze_embeddings: false
+run_name: "turque-mt5base-3task-adamw-tquad2train"
+wandb_project: null
+neptune_project: null
+neptune_api_token: null
+train_dataset_list: ["tquad2-train"]
+valid_dataset_list: ["tquad2-valid"]
+eval_dataset_list: ["tquad2-valid", "xquad.tr"]
+train_file_path: "data/train_data.pt"
+valid_file_path: "data/valid_data.pt"
+max_source_length: 512
+max_target_length: 64
+prepare_data: true
+mt5_task_list: ["qa", "qg", "ans_ext"]
+mt5_qg_format: "both"
+output_dir: "runs/mt5-base/3task-adamw-tquad2train"
+do_train: true
+do_eval: true
+evaluation_strategy: "steps"
+eval_steps: 300
+per_device_train_batch_size: 32
+per_device_eval_batch_size: 32
+gradient_accumulation_steps: 8
+eval_accumulation_steps: 1
+learning_rate: 1.0e-3
+num_train_epochs: 15
+save_total_limit: 1
+no_cuda: false
+seed: 42
+fp16: false
+fp16_full_eval: false
+adafactor: false
\ No newline at end of file
diff --git a/configs/paper/mt5base-3task-highlight-tquad2.yaml b/configs/paper/mt5base-3task-highlight-tquad2.yaml
new file mode 100644
index 0000000..74def9d
--- /dev/null
+++ b/configs/paper/mt5base-3task-highlight-tquad2.yaml
@@ -0,0 +1,35 @@
+model_name_or_path: "google/mt5-base"
+tokenizer_path: "tokenizers/mt5-base"
+label_smoothing_factor: 0
+freeze_embeddings: false
+run_name: "turque-mt5base-3task-highlight-adamw-tquad2train"
+wandb_project: null
+neptune_project: null
+neptune_api_token: null
+train_dataset_list: ["tquad2-train"]
+valid_dataset_list: ["tquad2-valid"]
+eval_dataset_list: ["tquad2-valid", "xquad.tr"]
+train_file_path: "data/train_data.pt"
+valid_file_path: "data/valid_data.pt"
+max_source_length: 512
+max_target_length: 64
+prepare_data: true
+mt5_task_list: ["qa", "qg", "ans_ext"]
+mt5_qg_format: "highlight"
+output_dir: "runs/mt5-base/3task-highlight-adamw-tquad2train"
+do_train: true
+do_eval: true
+evaluation_strategy: "steps"
+eval_steps: 300
+per_device_train_batch_size: 32
+per_device_eval_batch_size: 32
+gradient_accumulation_steps: 8
+eval_accumulation_steps: 1
+learning_rate: 1.0e-3
+num_train_epochs: 15
+save_total_limit: 1
+no_cuda: false
+seed: 42
+fp16: false
+fp16_full_eval: false
+adafactor: false
\ No newline at end of file
diff --git a/configs/paper/mt5base-3task-prepend-tquad2.yaml b/configs/paper/mt5base-3task-prepend-tquad2.yaml
new file mode 100644
index 0000000..ec19cd3
--- /dev/null
+++ b/configs/paper/mt5base-3task-prepend-tquad2.yaml
@@ -0,0 +1,35 @@
+model_name_or_path: "google/mt5-base"
+tokenizer_path: "tokenizers/mt5-base"
+label_smoothing_factor: 0
+freeze_embeddings: false
+run_name: "turque-mt5base-3task-prepend-adamw-tquad2train"
+wandb_project: null
+neptune_project: null
+neptune_api_token: null
+train_dataset_list: ["tquad2-train"]
+valid_dataset_list: ["tquad2-valid"]
+eval_dataset_list: ["tquad2-valid", "xquad.tr"]
+train_file_path: "data/train_data.pt"
+valid_file_path: "data/valid_data.pt"
+max_source_length: 512
+max_target_length: 64
+prepare_data: true
+mt5_task_list: ["qa", "qg", "ans_ext"]
+mt5_qg_format: "prepend"
+output_dir: "runs/mt5-base/3task-prepend-adamw-tquad2train"
+do_train: true
+do_eval: true
+evaluation_strategy: "steps"
+eval_steps: 300
+per_device_train_batch_size: 32
+per_device_eval_batch_size: 32
+gradient_accumulation_steps: 8
+eval_accumulation_steps: 1
+learning_rate: 1.0e-3
+num_train_epochs: 15
+save_total_limit: 1
+no_cuda: false
+seed: 42
+fp16: false
+fp16_full_eval: false
+adafactor: false
\ No newline at end of file
diff --git a/configs/paper/mt5base-qatask-tquad2.yaml b/configs/paper/mt5base-qatask-tquad2.yaml
new file mode 100644
index 0000000..623b68e
--- /dev/null
+++ b/configs/paper/mt5base-qatask-tquad2.yaml
@@ -0,0 +1,35 @@
+model_name_or_path: "google/mt5-base"
+tokenizer_path: "tokenizers/mt5-base"
+label_smoothing_factor: 0
+freeze_embeddings: false
+run_name: "turque-mt5base-qa-adamw-tquad2train"
+wandb_project: null
+neptune_project: null
+neptune_api_token: null
+train_dataset_list: ["tquad2-train"]
+valid_dataset_list: ["tquad2-valid"]
+eval_dataset_list: ["tquad2-valid", "xquad.tr"]
+train_file_path: "data/train_data.pt"
+valid_file_path: "data/valid_data.pt"
+max_source_length: 512
+max_target_length: 64
+prepare_data: true
+mt5_task_list: ["qa"]
+mt5_qg_format: "both"
+output_dir: "runs/mt5-base/qa-adamw-tquad2train"
+do_train: true
+do_eval: true
+evaluation_strategy: "steps"
+eval_steps: 300
+per_device_train_batch_size: 32
+per_device_eval_batch_size: 32
+gradient_accumulation_steps: 8
+eval_accumulation_steps: 1
+learning_rate: 1.0e-3
+num_train_epochs: 15
+save_total_limit: 1
+no_cuda: false
+seed: 42
+fp16: false
+fp16_full_eval: false
+adafactor: false
\ No newline at end of file
diff --git a/configs/paper/mt5base-qgtask-both-tquad2.yaml b/configs/paper/mt5base-qgtask-both-tquad2.yaml
new file mode 100644
index 0000000..1c64949
--- /dev/null
+++ b/configs/paper/mt5base-qgtask-both-tquad2.yaml
@@ -0,0 +1,35 @@
+model_name_or_path: "google/mt5-base"
+tokenizer_path: "tokenizers/mt5-base"
+label_smoothing_factor: 0
+freeze_embeddings: false
+run_name: "turque-mt5base-qg-adamw-tquad2train"
+wandb_project: null
+neptune_project: null
+neptune_api_token: null
+train_dataset_list: ["tquad2-train"]
+valid_dataset_list: ["tquad2-valid"]
+eval_dataset_list: ["tquad2-valid", "xquad.tr"]
+train_file_path: "data/train_data.pt"
+valid_file_path: "data/valid_data.pt"
+max_source_length: 512
+max_target_length: 64
+prepare_data: true
+mt5_task_list: ["qg"]
+mt5_qg_format: "both"
+output_dir: "runs/mt5-base/qg-adamw-tquad2train"
+do_train: true
+do_eval: true
+evaluation_strategy: "steps"
+eval_steps: 300
+per_device_train_batch_size: 32
+per_device_eval_batch_size: 32
+gradient_accumulation_steps: 8
+eval_accumulation_steps: 1
+learning_rate: 1.0e-3
+num_train_epochs: 15
+save_total_limit: 1
+no_cuda: false
+seed: 42
+fp16: false
+fp16_full_eval: false
+adafactor: false
\ No newline at end of file
diff --git a/configs/paper/mt5base-qgtask-highlight-tquad2.yaml b/configs/paper/mt5base-qgtask-highlight-tquad2.yaml
new file mode 100644
index 0000000..06cfdbb
--- /dev/null
+++ b/configs/paper/mt5base-qgtask-highlight-tquad2.yaml
@@ -0,0 +1,35 @@
+model_name_or_path: "google/mt5-base"
+tokenizer_path: "tokenizers/mt5-base"
+label_smoothing_factor: 0
+freeze_embeddings: false
+run_name: "turque-mt5base-qg-highlight-adamw-tquad2train"
+wandb_project: null
+neptune_project: null
+neptune_api_token: null
+train_dataset_list: ["tquad2-train"]
+valid_dataset_list: ["tquad2-valid"]
+eval_dataset_list: ["tquad2-valid", "xquad.tr"]
+train_file_path: "data/train_data.pt"
+valid_file_path: "data/valid_data.pt"
+max_source_length: 512
+max_target_length: 64
+prepare_data: true
+mt5_task_list: ["qg"]
+mt5_qg_format: "highlight"
+output_dir: "runs/mt5-base/qg-highlight-adamw-tquad2train"
+do_train: true
+do_eval: true
+evaluation_strategy: "steps"
+eval_steps: 300
+per_device_train_batch_size: 32
+per_device_eval_batch_size: 32
+gradient_accumulation_steps: 8
+eval_accumulation_steps: 1
+learning_rate: 1.0e-3
+num_train_epochs: 15
+save_total_limit: 1
+no_cuda: false
+seed: 42
+fp16: false
+fp16_full_eval: false
+adafactor: false
\ No newline at end of file
diff --git a/configs/paper/mt5base-qgtask-prepend-tquad2.yaml b/configs/paper/mt5base-qgtask-prepend-tquad2.yaml
new file mode 100644
index 0000000..d40605a
--- /dev/null
+++ b/configs/paper/mt5base-qgtask-prepend-tquad2.yaml
@@ -0,0 +1,35 @@
+model_name_or_path: "google/mt5-base"
+tokenizer_path: "tokenizers/mt5-base"
+label_smoothing_factor: 0
+freeze_embeddings: false
+run_name: "turque-mt5base-qg-prepend-adamw-tquad2train"
+wandb_project: null
+neptune_project: null
+neptune_api_token: null
+train_dataset_list: ["tquad2-train"]
+valid_dataset_list: ["tquad2-valid"]
+eval_dataset_list: ["tquad2-valid", "xquad.tr"]
+train_file_path: "data/train_data.pt"
+valid_file_path: "data/valid_data.pt"
+max_source_length: 512
+max_target_length: 64
+prepare_data: true
+mt5_task_list: ["qg"]
+mt5_qg_format: "prepend"
+output_dir: "runs/mt5-base/qg-prepend-adamw-tquad2train"
+do_train: true
+do_eval: true
+evaluation_strategy: "steps"
+eval_steps: 300
+per_device_train_batch_size: 32
+per_device_eval_batch_size: 32
+gradient_accumulation_steps: 8
+eval_accumulation_steps: 1
+learning_rate: 1.0e-3
+num_train_epochs: 15
+save_total_limit: 1
+no_cuda: false
+seed: 42
+fp16: false
+fp16_full_eval: false
+adafactor: false
\ No newline at end of file
diff --git a/configs/paper/mt5small-3task-both-tquad2.yaml b/configs/paper/mt5small-3task-both-tquad2.yaml
new file mode 100644
index 0000000..f6f0f6b
--- /dev/null
+++ b/configs/paper/mt5small-3task-both-tquad2.yaml
@@ -0,0 +1,35 @@
+model_name_or_path: "google/mt5-small"
+tokenizer_path: "tokenizers/mt5-small"
+label_smoothing_factor: 0
+freeze_embeddings: false
+run_name: "turque-mt5small-adamw-1e3-15ep-tquad2train"
+wandb_project: null
+neptune_project: null
+neptune_api_token: null
+train_dataset_list: ["tquad2-train"]
+valid_dataset_list: ["tquad2-valid"]
+eval_dataset_list: ["tquad2-valid", "xquad.tr"]
+train_file_path: "data/train_data.pt"
+valid_file_path: "data/valid_data.pt"
+max_source_length: 512
+max_target_length: 64
+prepare_data: true
+mt5_task_list: ["qa", "qg", "ans_ext"]
+mt5_qg_format: "both"
+output_dir: "runs/mt5-small/3task/adamw-1e3-15ep-both-tquad2train"
+do_train: true
+do_eval: true
+evaluation_strategy: "steps"
+eval_steps: 300
+per_device_train_batch_size: 64
+per_device_eval_batch_size: 64
+gradient_accumulation_steps: 4
+eval_accumulation_steps: 1
+learning_rate: 1.0e-3
+num_train_epochs: 15
+save_total_limit: 1
+no_cuda: false
+seed: 42
+fp16: false
+fp16_full_eval: false
+adafactor: false
\ No newline at end of file
diff --git a/core/__init__.py b/core/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/core/api.py b/core/api.py
new file mode 100644
index 0000000..eba3199
--- /dev/null
+++ b/core/api.py
@@ -0,0 +1,297 @@
+import itertools
+import logging
+import os
+import re
+from collections import OrderedDict
+from typing import Dict, List, Optional, Union
+
+from tqdm import tqdm
+
+from hf.model import MT5Model
+from utils.file import load_json
+from utils.nlp import (
+ add_start_end_to_answer_list_per_sentence,
+ normalize_text,
+ postprocess_answer_extraction_output,
+ prepare_answer_extraction_samples,
+ prepare_qa_sample,
+ prepare_qg_samples,
+ sentence_tokenize,
+)
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=os.environ.get("LOGLEVEL", "INFO").upper(),
+)
+
+
+class TurQue:
+ def __init__(
+ self,
+ model_url_or_path: str = None,
+ use_cuda: bool = None,
+ max_source_length: int = 512,
+ max_target_length: int = 80,
+ generate_num_beams: int = 4,
+ top_k: int = None,
+ top_p: float = None,
+ qg_format: str = "highlight",
+ ):
+ model_url_or_path = "turque-s1" if model_url_or_path is None else model_url_or_path
+ mt5_model = MT5Model(model_url_or_path, use_cuda=use_cuda)
+ self.model = mt5_model.model
+ self.tokenizer = mt5_model.tokenizer
+ self.model_type = mt5_model.type
+
+ self.max_source_length = max_source_length
+ self.max_target_length = max_target_length
+ self.generate_num_beams = generate_num_beams
+ self.top_k = top_k
+ self.top_p = top_p
+ self.qg_format = qg_format
+
+ def __call__(
+ self,
+ task: str,
+ context: str,
+ question: Optional[str] = None,
+ answer_list: Optional[Union[List[str], List[Dict]]] = None,
+ ):
+ context = normalize_text(context)
+ if task == "answer-extraction":
+ # return answer list using context
+ answer_list = self._generate_answer_list_from_context(context)
+ output = [{"answer": answer["text"]} for answer in answer_list]
+ return output
+ elif task == "question-answering":
+ # return answer list using context and question
+ question = normalize_text(question)
+ answer = self._generate_answer_from_context_and_question(question, context)
+ output = [{"answer": answer}]
+ return output
+ elif task == "question-generation":
+ # return question list using context
+ if answer_list is None:
+ answer_list = self._generate_answer_list_from_context(context)
+
+ if not answer_list:
+ return [{"answer": None, "question": None}]
+ else:
+ _answer_list = []
+ for answer in answer_list:
+ answer["text"] = normalize_text(answer["text"])
+ _answer_list.append(answer)
+ answer_list = _answer_list
+
+ samples = prepare_qg_samples(context, answer_list, qg_format=self.qg_format)
+
+ if samples[0]["answer"] is None:
+ return [{"answer": None, "question": None}]
+ else:
+ inputs = [sample["source_text"] for sample in samples]
+
+ # single generation without padding is 5 times faster than padding + batch generation
+ question_list = []
+ for input in inputs:
+ question = self._generate_question_list([input], padding=False)[0]
+ question_list.append(question)
+
+ output = [
+ {"answer": sample["answer"], "question": question}
+ for sample, question in zip(samples, question_list)
+ ]
+ return output
+ else:
+ raise NameError(
+ f"{task} is not defined. 'task' must be one of ['answer-extraction', 'question-answering', 'question-generation']"
+ )
+
+ def _generate_question_list(self, inputs, padding=True, truncation=True):
+ inputs = self._tokenize(inputs, padding=padding, truncation=truncation)
+
+ outs = self.model.generate(
+ input_ids=inputs["input_ids"].to(self.model.device),
+ attention_mask=inputs["attention_mask"].to(self.model.device),
+ max_length=self.max_target_length,
+ num_beams=self.generate_num_beams,
+ top_k=self.top_k,
+ top_p=self.top_p,
+ )
+
+ question_list = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in outs]
+ return question_list
+
+ def _generate_answer_list_from_context(self, context):
+ samples = prepare_answer_extraction_samples(context=context)
+ inputs = [sample["source_text"] for sample in samples]
+
+ inputs = self._tokenize(inputs, padding=True, truncation=True)
+
+ outs = self.model.generate(
+ input_ids=inputs["input_ids"].to(self.model.device),
+ attention_mask=inputs["attention_mask"].to(self.model.device),
+ max_length=self.max_target_length,
+ )
+
+ output_list = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in outs]
+
+ answer_list_per_sentence = []
+ for output in output_list:
+ # postprocess answer extraction output list
+ answer_text_list = postprocess_answer_extraction_output(output)
+ answer_list = [{"text": normalize_text(answer_text)} for answer_text in answer_text_list]
+ answer_list_per_sentence.append(answer_list)
+
+ sentence_list = sentence_tokenize(context)
+ answer_list = add_start_end_to_answer_list_per_sentence(sentence_list, answer_list_per_sentence)
+
+ return answer_list
+
+ def _tokenize(self, inputs, padding=True, truncation=True, add_special_tokens=True):
+ inputs = self.tokenizer.batch_encode_plus(
+ inputs,
+ max_length=self.max_source_length,
+ add_special_tokens=add_special_tokens,
+ truncation=truncation,
+ padding="max_length" if padding else False,
+ pad_to_max_length=padding,
+ return_tensors="pt",
+ )
+ return inputs
+
+ def _generate_answer_from_context_and_question(self, question, context):
+ sample = prepare_qa_sample(context, question)
+ source_text = sample["source_text"]
+ inputs = self._tokenize([source_text], padding=False)
+
+ outs = self.model.generate(
+ input_ids=inputs["input_ids"].to(self.model.device),
+ attention_mask=inputs["attention_mask"].to(self.model.device),
+ max_length=self.max_target_length,
+ )
+
+ answer = self.tokenizer.decode(outs[0], skip_special_tokens=True)
+
+ return answer
+
+ def qg_from_file(self, path_or_dict: str, use_answers: bool = False, **kwargs):
+ """performs question-generation using the contexts and answers
+ from squad formatted json file or answers extracted by the model"""
+
+ for k, v in kwargs.items():
+ setattr(self, k, v)
+ task = "question-generation"
+
+ # read data from path or dict
+ if isinstance(path_or_dict, str):
+ data = load_json(path_or_dict)["data"]
+ else:
+ data = path_or_dict["data"]
+
+ out = {"data": []}
+ for article in tqdm(data, desc="Question Generation from articles"):
+ out_article = {"paragraphs": [], "title": article.get("title")}
+ # iterate over each paragraph
+ for paragraph in article["paragraphs"]:
+ context = paragraph["context"]
+ out_para = {"context": context, "qas": []}
+ if use_answers and paragraph["qas"]:
+ answer_list = []
+ paragraph["qas"] = sorted(paragraph["qas"], key=lambda k: k["answers"][0]["answer_start"])
+ for qa in paragraph["qas"]:
+ answer = qa["answers"][0]
+ answer_list.append(answer)
+ elif use_answers and not paragraph["qas"]: # pass if paragraph["qas"] is empty
+ continue
+ else:
+ answer_list = None
+ qg_out = self(task=task, context=context, answer_list=answer_list)
+ if qg_out[0]["question"] is None:
+ continue
+ for qa, gold_qa in zip(qg_out, paragraph["qas"]):
+ if use_answers:
+ qa["gold_question"] = gold_qa["question"]
+ out_para["qas"].append(qa)
+ out_article["paragraphs"].append(out_para)
+ out["data"].append(out_article)
+ return out
+
+ def qa_from_file(self, path_or_dict: str, **kwargs):
+ """performs question-answering using the contexts and questions
+ from squad formatted json file"""
+
+ for k, v in kwargs.items():
+ setattr(self, k, v)
+ task = "question-answering"
+
+ # read data from path or dict
+ if isinstance(path_or_dict, str):
+ data = load_json(path_or_dict)["data"]
+ else:
+ data = path_or_dict["data"]
+
+ out = {"data": []}
+ for article in tqdm(data, desc="Question Answering from articles"):
+ out_article = {"paragraphs": [], "title": article.get("title")}
+ # iterate over each paragraph
+ for paragraph in article["paragraphs"]:
+ context = paragraph["context"]
+ out_para = {"context": context, "qas": []}
+
+ # extract questions from dataset paragraph and get answer from model
+ for qa in paragraph["qas"]:
+ question = qa.get("question")
+ if question is not None:
+ qa_out = self(task=task, context=context, question=question)[0]
+ # append q&a pair into out_para
+ qa_out["gold_answer"] = qa["answers"][0]["text"]
+ qa_out["question"] = question
+ out_para["qas"].append(qa_out)
+ else:
+ logger.warning("skipping a paragraph without questions.")
+
+ out_article["paragraphs"].append(out_para)
+ out["data"].append(out_article)
+ return out
+
+ def ans_ext_from_file(self, path_or_dict: str, **kwargs):
+ """performs answer-extraction using the contexts from squad formatted json file"""
+
+ for k, v in kwargs.items():
+ setattr(self, k, v)
+ task = "answer-extraction"
+
+ # read data from path or dict
+ if isinstance(path_or_dict, str):
+ data = load_json(path_or_dict)["data"]
+ else:
+ data = path_or_dict["data"]
+
+ out = {"data": []}
+ for article in tqdm(data, desc="Answer Extraction from articles"):
+ out_article = {"paragraphs": [], "title": article.get("title")}
+ # iterate over each paragraph
+ for paragraph in article["paragraphs"]:
+ context = paragraph["context"]
+ out_para = {"context": context, "gold_answer_list": [], "predicted_answer_list": []}
+
+ if paragraph["qas"]:
+ gold_answer_list = []
+ paragraph["qas"] = sorted(paragraph["qas"], key=lambda k: k["answers"][0]["answer_start"])
+ for qa in paragraph["qas"]:
+ answer = qa["answers"][0]
+ gold_answer_list.append(answer["text"])
+ else:
+ logger.warning("skipping a paragraph without q/a's.")
+ # extract answers
+ ans_ext_out = self(task=task, context=context)
+ # add gold and predicted answers
+ predicted_answer_list = [output["answer"] for output in ans_ext_out]
+ out_para["gold_answer_list"] = gold_answer_list
+ out_para["predicted_answer_list"] = predicted_answer_list
+
+ out_article["paragraphs"].append(out_para)
+ out["data"].append(out_article)
+ return out
diff --git a/core/argument_parsers.py b/core/argument_parsers.py
new file mode 100644
index 0000000..f79089b
--- /dev/null
+++ b/core/argument_parsers.py
@@ -0,0 +1,162 @@
+import logging
+import os
+import sys
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Dict, List, Optional
+
+from transformers import HfArgumentParser, TrainingArguments
+
+from utils.file import read_yaml
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=os.environ.get("LOGLEVEL", "INFO").upper(),
+)
+
+
+@dataclass
+class ModelArguments:
+ """
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
+ """
+
+ model_name_or_path: str = field(
+ default="google/mt5-small",
+ metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"},
+ )
+ tokenizer_path: Optional[str] = field(
+ default=None, metadata={"help": "Pretrained tokenizer path to be saved/loaded"}
+ )
+ cache_dir: Optional[str] = field(
+ default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
+ )
+
+ wandb_project: Optional[str] = field(default=None, metadata={"help": "Wandb project for experiment tracking"})
+ wandb_id: Optional[str] = field(
+ default=None, metadata={"help": "Wandb run id, will be used when 'do_train=False', 'do_eval=True'"}
+ )
+ neptune_project: Optional[str] = field(default=None, metadata={"help": "Neptune project for experiment tracking"})
+ neptune_run: Optional[str] = field(
+ default=None, metadata={"help": "Neptune run id, will be used when 'do_train=False', 'do_eval=True'"}
+ )
+ neptune_api_token: Optional[str] = field(
+ default=None, metadata={"help": "Neptune api token for experiment tracking"}
+ )
+ model_type: str = field(default=None, metadata={"help": "'mt5' or 'bert'"})
+
+
+@dataclass
+class DataArguments:
+ """
+ Arguments pertaining to what data we are going to input our model for training and eval.
+ """
+
+ train_file_path: Optional[str] = field(
+ default="data/train_data_multitask_mt5.pt",
+ metadata={"help": "name for cached train dataset"},
+ )
+ valid_file_path: Optional[str] = field(
+ default="data/valid_data_multitask_mt5.pt",
+ metadata={"help": "name for cached valid dataset"},
+ )
+
+ prepare_data: bool = field(
+ default=True,
+ metadata={
+ "help": "Runs prepare_data.py before starting to train. Set if false if you alreade prepared data before."
+ },
+ )
+
+
+@dataclass
+class ExtendedTrainingArguments(TrainingArguments):
+ train_dataset_list: Optional[List[str]] = field(
+ default_factory=lambda: ["tquad-train"],
+ metadata={"help": "dataset name list of the training"},
+ )
+ valid_dataset_list: Optional[List[str]] = field(
+ default_factory=lambda: ["tquad-val"],
+ metadata={"help": "dataset name list of the validation"},
+ )
+ eval_dataset_list: Optional[List[str]] = field(
+ default_factory=lambda: ["tquad-val", "xquad.tr"],
+ metadata={"help": "dataset name list of the evaluation"},
+ )
+ freeze_embeddings: bool = field(
+ default=False,
+ metadata={"help": "Freeze token embeddings."},
+ )
+ max_source_length: Optional[int] = field(
+ default=512,
+ metadata={"help": "Max input length for the source text"},
+ )
+ max_target_length: Optional[int] = field(
+ default=80,
+ metadata={"help": "Max input length for the target text"},
+ )
+ mt5_task_list: Optional[List[str]] = field(
+ default_factory=lambda: ["qa", "qg", "ans_ext"],
+ metadata={"help": "task list for mt5"},
+ )
+ mt5_qg_format: Optional[str] = field(
+ default="highlight",
+ metadata={"help": 'mt5 qg format as "highlight", "prepend" or "both"'},
+ )
+
+
+def parser(args_file_path: Optional[str] = None):
+ # See all possible arguments in src/transformers/training_args.py
+ # or by passing the --help flag to this script.
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
+
+ parser = HfArgumentParser((ModelArguments, DataArguments, ExtendedTrainingArguments))
+
+ accepted_args_file_suffixes = (".json", ".yml", ".yaml")
+
+ # validate args file suffix
+ args_file_suffix: Optional[str] = None
+ if args_file_path:
+ # If we pass args_file_path to the script and it's the path to a json/yml/yaml file,
+ # let's parse it to get our arguments.
+ assert type(args_file_path) == str, TypeError(f"invalid 'args_file_path': {args_file_path}")
+ args_file_suffix = Path(args_file_path).suffix
+ assert args_file_suffix in accepted_args_file_suffixes, TypeError(
+ f"""args file should be one of: /
+ {accepted_args_file_suffixes}, invalid args file format: {args_file_suffix}"""
+ )
+ elif len(sys.argv) == 2:
+ # If we pass only one argument to the script and it's the path to a json/yml/yaml file,
+ # let's parse it to get our arguments.
+ args_file_path = sys.argv[1]
+ args_file_suffix = Path(args_file_path).suffix
+ assert args_file_suffix in accepted_args_file_suffixes, TypeError(
+ f"""args file should be one of: /
+ {accepted_args_file_suffixes}, invalid args file format: {args_file_suffix}"""
+ )
+
+ if args_file_suffix == ".json":
+ model_args, data_args, training_args = parser.parse_json_file(json_file=args_file_path)
+ elif args_file_suffix in (".yml", ".yaml"):
+ args_dict = read_yaml(args_file_path)
+ model_args, data_args, training_args = parser.parse_dict(args=args_dict)
+ else:
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+
+ if model_args.tokenizer_path is None:
+ model_name = (model_args.model_name_or_path).split("/")[-1]
+ model_args.tokenizer_path = model_name + "_tokenizer"
+
+ # overwrite model type
+ if "mt5" in model_args.model_name_or_path:
+ model_type = "mt5"
+ elif "bert" in model_args.model_name_or_path:
+ model_type = "bert"
+ else:
+ logger.info("couldnt infer model type from 'model_name_or_path', assuming its 'mt5'.")
+ model_type = "mt5"
+ model_args.model_type = model_type
+
+ return model_args, data_args, training_args
diff --git a/core/bert_api.py b/core/bert_api.py
new file mode 100644
index 0000000..8677388
--- /dev/null
+++ b/core/bert_api.py
@@ -0,0 +1,129 @@
+import logging
+import os
+from typing import Dict, List
+
+import numpy as np
+import torch
+from tqdm import tqdm
+
+from hf.model import BertModel
+from utils.file import load_json
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=os.environ.get("LOGLEVEL", "INFO").upper(),
+)
+
+
+def _postprocess_output(
+ model_output, top_n_best_answers: int = 20, max_answer_length: int = 64
+) -> List[Dict]:
+ """returns valid_answer_list (List[Dict]) with each elements having
+ score (float), answer_start (int), answer_end (int) keys"""
+
+ start_logits = model_output.start_logits[0].cpu().detach().numpy()
+ end_logits = model_output.end_logits[0].cpu().detach().numpy()
+ # Gather the indices of the best start/end logits:
+ start_indexes = np.argsort(start_logits)[-1 : -top_n_best_answers - 1 : -1].tolist()
+ end_indexes = np.argsort(end_logits)[-1 : -top_n_best_answers - 1 : -1].tolist()
+ valid_answers = []
+ for start_index in start_indexes:
+ for end_index in end_indexes:
+ # Don't consider answers with a length that is either < 0 or > max_answer_length.
+ if end_index < start_index or end_index - start_index + 1 > max_answer_length:
+ continue
+ valid_answers.append(
+ {
+ "score": start_logits[start_index] + end_logits[end_index],
+ "answer_start": start_index,
+ "answer_end": end_index,
+ }
+ )
+ valid_answer_list = sorted(valid_answers, key=lambda x: x["score"], reverse=True)
+ return valid_answer_list
+
+
+def qa_from_file(
+ path_or_dict: str,
+ model_url_or_path: str,
+ use_cuda: bool = True,
+ top_n_best_answers: int = 20,
+ max_source_length: int = 512,
+ max_target_length: int = 80,
+):
+ model = BertModel(model_url_or_path, use_cuda=use_cuda)
+
+ # read data from path or dict
+ if isinstance(path_or_dict, str):
+ data = load_json(path_or_dict)["data"]
+ else:
+ data = path_or_dict["data"]
+
+ out = {"data": []}
+ for article in tqdm(data, desc="Answer Generation using Bert from articles"):
+ out_article = {"paragraphs": [], "title": article.get("title")}
+ # iterate over each paragraph
+ for paragraph in article["paragraphs"]:
+ context = paragraph["context"]
+ out_para = {"context": context, "qas": []}
+
+ # extract questions from dataset paragraph and get answer from model
+ for qa in paragraph["qas"]:
+ question = qa.get("question")
+ if question is not None:
+
+ inputs = model.tokenizer.encode_plus(
+ context,
+ question,
+ max_length=max_source_length,
+ add_special_tokens=True,
+ padding=False,
+ pad_to_max_length=False,
+ return_tensors="pt",
+ truncation=True,
+ )
+ input_ids = inputs["input_ids"].tolist()[0]
+
+ model_output = model.model(torch.tensor([input_ids]).to(model.device))
+ # answer_dict_list = _postprocess_output(
+ # model_output, top_n_best_answers=top_n_best_answers, max_answer_length=max_target_length
+ # )
+ answer_start_scores, answer_end_scores = model_output["start_logits"], model_output["end_logits"]
+
+ answer_start = torch.argmax(
+ answer_start_scores
+ ) # Get the most likely beginning of answer with the argmax of the score
+ answer_end = (
+ torch.argmax(answer_end_scores) + 1
+ ) # Get the most likely end of answer with the argmax of the score
+
+ # answer_list = []
+ # for answer_dict in answer_dict_list:
+ # answer = model.tokenizer.convert_tokens_to_string(
+ # model.tokenizer.convert_ids_to_tokens(
+ # input_ids[answer_dict["answer_start"] : answer_dict["answer_end"]]
+ # )
+ # )
+ # answer_list.append(answer)
+ # append q&a pair into out_para
+ if answer_end > answer_start:
+ answer = model.tokenizer.convert_tokens_to_string(
+ model.tokenizer.convert_ids_to_tokens(input_ids[answer_start:answer_end])
+ )
+ else:
+ answer = ""
+ qa_out = {
+ "answer": answer,
+ # "alternative_answers": answer_list[1:],
+ "gold_answer": qa["answers"][0]["text"],
+ "question": question,
+ }
+ out_para["qas"].append(qa_out)
+ else:
+ logger.warning("skipping a paragraph without questions.")
+
+ out_article["paragraphs"].append(out_para)
+ out["data"].append(out_article)
+ return out
diff --git a/core/collator.py b/core/collator.py
new file mode 100644
index 0000000..cd7aee9
--- /dev/null
+++ b/core/collator.py
@@ -0,0 +1,81 @@
+from typing import Dict, List
+
+import torch
+
+# taken from https://github.com/patil-suraj/question_generation/blob/master/data_collator.py
+
+
+def trim_batch(
+ input_ids,
+ pad_token_id,
+ attention_mask=None,
+):
+ """Remove columns that are populated exclusively by pad_token_id"""
+ keep_column_mask = input_ids.ne(pad_token_id).any(dim=0)
+ if attention_mask is None:
+ return input_ids[:, keep_column_mask]
+ else:
+ return (input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask])
+
+
+# prepares lm_labels from target_ids, returns examples with keys as expected by the forward method
+# this is necessary because the trainer directly passes this dict as arguments to the model
+# so make sure the keys match the parameter names of the forward method
+class T2TDataCollator:
+ def __init__(self, tokenizer, mode="training", using_tpu=False):
+ self.tokenizer = tokenizer
+ self.using_tpu = using_tpu
+ self.mode = mode
+
+ def __call__(self, batch: List) -> Dict[str, torch.Tensor]:
+ """
+ Take a list of samples from a Dataset and collate them into a batch.
+ Returns:
+ A dictionary of tensors
+ """
+ input_ids = torch.stack([example["source_ids"] for example in batch])
+ target_ids = torch.stack([example["target_ids"] for example in batch])
+ attention_mask = torch.stack([example["attention_mask"] for example in batch])
+
+ pad_token_id = self.tokenizer.pad_token_id
+
+ # don't trim on tpu, for some reason trimming leads to slower training on TPU
+ if not self.using_tpu:
+ input_ids, attention_mask = trim_batch(input_ids, pad_token_id, attention_mask=attention_mask)
+ target_ids = trim_batch(target_ids, pad_token_id)
+
+ lm_labels = target_ids.clone()
+ decoder_input_ids = self._shift_right_t5(lm_labels)
+ if self.mode == "training":
+ lm_labels[lm_labels[:, :] == pad_token_id] = -100
+
+ params = {
+ "input_ids": input_ids,
+ "attention_mask": attention_mask,
+ "labels": lm_labels,
+ "decoder_input_ids": decoder_input_ids,
+ }
+
+ return params
+
+ def _shift_right_t5(self, input_ids):
+ decoder_start_token_id = self.tokenizer.pad_token_id
+ pad_token_id = self.tokenizer.pad_token_id
+
+ assert (
+ decoder_start_token_id is not None
+ ), """self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id.
+ See T5 docs for more information"""
+
+ # shift inputs to the right
+ shifted_input_ids = input_ids.new_zeros(input_ids.shape)
+ shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
+ shifted_input_ids[..., 0] = decoder_start_token_id
+
+ assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined."
+ # replace possible -100 values in labels by `pad_token_id`
+ shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
+
+ assert torch.all(shifted_input_ids >= 0).item(), "Verify that `labels` has only positive values and -100"
+
+ return shifted_input_ids
diff --git a/core/dataset_parsers.py b/core/dataset_parsers.py
new file mode 100644
index 0000000..9ee57df
--- /dev/null
+++ b/core/dataset_parsers.py
@@ -0,0 +1,195 @@
+import os
+from typing import Dict, List
+
+from datasets import Dataset
+
+from utils.file import load_json
+from utils.nlp import (
+ download_and_load_dataset,
+ normalize_text,
+ prepare_answer_extraction_samples,
+ prepare_qa_sample,
+ prepare_qg_samples,
+)
+
+# the former train data url was unfit for (pyarrow) datasets.Dataset (the file may be corrupt).
+
+_TQUAD_URL = "https://github.com/fcakyon/turkish-qa-datasets/releases/download/0.0.1/"
+_TQUAD1_DEV_FILE = "tquad_dev_data_v1.json"
+_TQUAD1_TRAINING_FILE = "tquad_train_data_v1.json"
+_TQUAD2_DEV_FILE = "tquad_dev_data_v2.json"
+_TQUAD2_TRAINING_FILE = "tquad_train_data_v2.json"
+_TQUAD_DEV_SMALL_FILE = "tquad_dev_small_data.json"
+_TQUAD_LOCAL_DIR = "data/tquad/"
+_XQUAD_TR_URL = "https://raw.githubusercontent.com/deepmind/xquad/master/"
+_XQUAD_LOCAL_DIR = "data/xquad"
+
+
+def load_tquad_data(data_name="tquad2-train") -> Dataset:
+ if "tquad1-train" in data_name:
+ tquad_url = os.path.join(_TQUAD_URL, _TQUAD1_TRAINING_FILE)
+ tquad_local_path = os.path.join(_TQUAD_LOCAL_DIR, _TQUAD1_TRAINING_FILE)
+ elif "tquad2-train" in data_name:
+ tquad_url = os.path.join(_TQUAD_URL, _TQUAD2_TRAINING_FILE)
+ tquad_local_path = os.path.join(_TQUAD_LOCAL_DIR, _TQUAD2_TRAINING_FILE)
+ elif "tquad1-valid" in data_name:
+ tquad_url = os.path.join(_TQUAD_URL, _TQUAD1_DEV_FILE)
+ tquad_local_path = os.path.join(_TQUAD_LOCAL_DIR, _TQUAD1_DEV_FILE)
+ elif "tquad2-valid" in data_name:
+ tquad_url = os.path.join(_TQUAD_URL, _TQUAD2_DEV_FILE)
+ tquad_local_path = os.path.join(_TQUAD_LOCAL_DIR, _TQUAD2_DEV_FILE)
+ elif "small" in data_name:
+ tquad_url = os.path.join(_TQUAD_URL, _TQUAD_DEV_SMALL_FILE)
+ tquad_local_path = os.path.join(_TQUAD_LOCAL_DIR, _TQUAD_DEV_SMALL_FILE)
+ else:
+ raise ValueError(
+ f"Unknown data_name {data_name}, must be one of ['tquad1-train', 'tquad2-train', 'tquad1-valid', 'tquad2-valid', 'tquad.small']"
+ )
+
+ return download_and_load_dataset(tquad_url, tquad_local_path)
+
+
+def load_xquad_data(data_name="xquad.tr"):
+ """XQuad dataset has only validation split."""
+ xquad_url = os.path.join(_XQUAD_TR_URL, data_name + ".json")
+ xquad_local_path = os.path.join(_XQUAD_LOCAL_DIR, data_name + ".json")
+ return download_and_load_dataset(xquad_url, xquad_local_path)
+
+
+def prepare_data_for_bert(data: Dataset) -> List[Dict]:
+ """
+ Args:
+ data: squad data
+
+ Returns: Processed samples as list in bert input format.
+ """
+ samples = []
+ data = data["data"]
+
+ for group in data:
+ for passage in group["paragraphs"]:
+ context = passage["context"]
+ if passage["qas"]:
+ for qa in passage["qas"]:
+ question = qa["question"]
+ # for answer in qa["answers"]:
+ answer = qa["answers"][0]
+
+ gold_text = answer["text"]
+ start_idx = answer["answer_start"]
+ end_idx = start_idx + len(gold_text)
+
+ # sometimes squad answers are off by a character or two – fix this
+ if context[start_idx:end_idx] == gold_text:
+ answer["answer_end"] = end_idx
+ elif context[start_idx - 1 : end_idx - 1] == gold_text:
+ answer["answer_start"] = start_idx - 1
+ answer["answer_end"] = end_idx - 1 # When the gold label is off by one character
+ elif context[start_idx - 2 : end_idx - 2] == gold_text:
+ answer["answer_start"] = start_idx - 2
+ answer["answer_end"] = end_idx - 2 # When the gold label is off by two characters
+ elif context[start_idx - 3 : end_idx - 3] == gold_text:
+ answer["answer_start"] = start_idx - 3
+ answer["answer_end"] = end_idx - 3 # When the gold label is off by three characters
+ else:
+ print(
+ f"skipping the answer|answer_start|context {answer['text']}|{answer['answer_start']}|{context} | for reason: 'answer indexes are off by a lot'"
+ )
+ continue
+
+ sample = {"context": context, "question": question, "answer": answer}
+ samples.append(sample)
+ return samples
+
+
+def prepare_data_for_mt5(
+ data: Dataset, task_list: List[str] = ["ans_ext", "qa", "qg"], qg_format="highlight"
+) -> List[Dict]:
+ """
+ Args:
+ data: squad data
+ task_list: list of tasks to data be prepared
+ qg_format: "highlight", "prepend" or "both"
+
+ Returns: Processed samples as list in mt5 input format.
+ """
+ samples = []
+ data = data["data"]
+
+ for article in data:
+ for paragraph in article["paragraphs"]:
+ context = paragraph["context"]
+ question_list = []
+ answer_list = []
+
+ if paragraph["qas"]: # pass if paragraph["qas"] is empty
+ for qa in paragraph["qas"]:
+ question = normalize_text(qa["question"])
+ answer = qa["answers"][0]
+ answer["text"] = answer["text"]
+ qa_sample = prepare_qa_sample(context=context, question=question, answer=answer["text"])
+ if qa_sample["target_text"] is not None:
+ if "qa" in task_list:
+ samples.append(qa_sample)
+ question_list.append(question)
+ answer_list.append(answer)
+
+ if answer_list and question_list:
+ qg_samples = prepare_qg_samples(
+ context=context, answer_list=answer_list, question_list=question_list, qg_format=qg_format
+ )
+ if qg_samples[0]["answer"] is not None:
+ if "qg" in task_list:
+ samples.extend(qg_samples)
+
+ answer_extraction_samples = prepare_answer_extraction_samples(context=context, answer_list=answer_list)
+ for answer_extraction_sample in answer_extraction_samples:
+ if answer_extraction_sample["target_text"] is not None:
+ if "ans_ext" in task_list:
+ samples.extend(answer_extraction_samples)
+ return samples
+
+
+def prepare_data(
+ data: Dict, target_format="mt5", mt5_task_list: List[str] = ["ans_ext", "qa", "qg"], mt5_qg_format="highlight"
+):
+ """
+ Args:
+ target_format (str): output format ('mt5' or 'bert')
+ mt5_task_list: list of tasks for mt5 data to be prepared
+ mt5_qg_format: "highlight", "prepend" or "both"
+ """
+ if target_format == "mt5":
+ samples = prepare_data_for_mt5(data, mt5_task_list, mt5_qg_format)
+ elif target_format == "bert":
+ samples = prepare_data_for_bert(data)
+ return samples
+
+
+def load_dataset(name_or_path: str):
+ if os.path.isfile(name_or_path):
+ data = load_json(name_or_path)
+ elif "tquad" in name_or_path:
+ data = load_tquad_data(name_or_path)
+ elif "xquad" in name_or_path:
+ data = load_xquad_data(name_or_path)
+ else:
+ raise ValueError(f"Unknown dataset {name_or_path}.")
+
+ return data
+
+
+def load_and_prepare_dataset(
+ name_or_path: str,
+ target_format="mt5",
+ mt5_task_list: List[str] = ["ans_ext", "qa", "qg"],
+ mt5_qg_format="highlight",
+):
+ """
+ Args:
+ target_format (str): output format ('mt5' or 'bert')
+ mt5_task_list: list of tasks for mt5 data to be prepared
+ mt5_qg_format: "highlight", "prepend" or "both"
+ """
+ data = load_dataset(name_or_path)
+ return prepare_data(data, target_format=target_format, mt5_task_list=mt5_task_list, mt5_qg_format=mt5_qg_format)
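For reference, a minimal usage sketch of the loading/preparation helpers above (illustrative only; it assumes the repository root is on PYTHONPATH and that the TQuAD files can be downloaded):

    from core.dataset_parsers import load_and_prepare_dataset

    # prepare the TQuAD2 validation split for mt5, question-generation task only
    samples = load_and_prepare_dataset(
        "tquad2-valid",
        target_format="mt5",
        mt5_task_list=["qg"],
        mt5_qg_format="highlight",
    )
    print(len(samples))
    print(samples[0]["source_text"])  # e.g. "generate question: ..."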
diff --git a/core/evaluate.py b/core/evaluate.py
new file mode 100644
index 0000000..8499936
--- /dev/null
+++ b/core/evaluate.py
@@ -0,0 +1,289 @@
+import logging
+import os
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+from jury import Jury
+from jury.metrics import load_metric
+
+from core.dataset_parsers import load_dataset
+from core.generate import generate
+from utils.file import load_json, save_json
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=os.environ.get("LOGLEVEL", "INFO").upper(),
+)
+
+
+TASK_TO_METRIC = {
+ "ans_ext": [
+ load_metric("f1", task="language-generation"),
+ load_metric("precision", task="language-generation"),
+ load_metric("bertscore", compute_kwargs={"lang": "tr"}),
+ ],
+ "qa": [load_metric("squad"), load_metric("bertscore", compute_kwargs={"lang": "tr"})],
+ "qg": [
+ load_metric("bleu", compute_kwargs={"max_order": 1}),
+ load_metric("bleu", compute_kwargs={"max_order": 2}),
+ load_metric("bleu", compute_kwargs={"max_order": 3}),
+ load_metric("bleu", compute_kwargs={"max_order": 4}),
+ load_metric("rouge"),
+ load_metric("bertscore", compute_kwargs={"lang": "tr"}),
+ ],
+}
+
+
+class Evaluation:
+ r"""
+ Simple evaluation pipeline for text-based metrics. It supports QA, QG and answer
+ extraction evaluation; the metrics computed for each task are taken from
+ ``TASK_TO_METRIC`` above.
+
+ Note:
+
+ If ``predictions`` and ``references`` are given as lists of strings, they are treated
+ as prediction & reference pairs and evaluated in that order.
+ """
+
+ def __init__(self, metrics: Optional[List[str]] = None):
+ self.metrics = metrics
+
+ @staticmethod
+ def task_to_metric(task: str) -> List[str]:
+ return TASK_TO_METRIC.get(task)
+
+ @staticmethod
+ def metric_to_task(metric: str) -> str:
+ for task, metrics in TASK_TO_METRIC.items():
+ if metric in metrics:
+ return task
+
+ @staticmethod
+ def get_tasks():
+ return list(TASK_TO_METRIC.keys())
+
+ @staticmethod
+ def _get_task_related_samples(
+ desired_task: str,
+ predictions: Union[List[str], List[Dict]],
+ references: Union[List[str], List[Dict]],
+ tasks: Optional[List[str]] = None,
+ ):
+ if tasks is None:
+ return predictions, references
+
+ selected_predictions = []
+ selected_references = []
+ for prediction, reference, task in zip(predictions, references, tasks):
+ if task == desired_task:
+ selected_predictions.append(prediction)
+ selected_references.append(reference)
+ return selected_predictions, selected_references
+
+ def run(
+ self,
+ predictions: Union[List[str], List[Dict]],
+ references: Union[List[str], List[Dict]],
+ tasks: List[str],
+ ) -> Dict[str, Any]:
+ scores = {}
+
+ for task in TASK_TO_METRIC:
+ metrics = self.task_to_metric(task)
+ scorer = Jury(metrics=metrics, run_concurrent=True)
+ selected_predictions, selected_references = self._get_task_related_samples(
+ desired_task=task, predictions=predictions, references=references, tasks=tasks
+ )
+ save_json(
+ {"predictions": selected_predictions, "references": selected_references},
+ task + "_outputs_during_training.json",
+ )
+ task_scores = scorer(predictions=selected_predictions, references=selected_references)
+ for key, value in task_scores.items():
+ scores[task + "_" + key] = value
+
+ return scores
+
+
+def _get_qa_predictions_references(data: Dict) -> Tuple[List, List]:
+ predictions = []
+ references = []
+ for article in data:
+ for paragraph in article["paragraphs"]:
+ for qa in paragraph["qas"]:
+ predictions.append(qa["answer"])
+ references.append(qa["gold_answer"])
+ return predictions, references
+
+
+def _get_qg_predictions_references(data: Dict) -> Tuple[List, List]:
+ predictions = []
+ references = []
+ for article in data:
+ for paragraph in article["paragraphs"]:
+ for i, qa in enumerate(paragraph["qas"]):
+ predictions.append(qa["question"])
+ references.append(qa["gold_question"])
+ return predictions, references
+
+
+def _get_ans_ext_predictions_references(data: Dict, dont_combine_answer_list: bool = False) -> Tuple[List, List]:
+ predictions = []
+ references = []
+ for article in data:
+ for paragraph in article["paragraphs"]:
+ if dont_combine_answer_list:
+ predictions.append(paragraph["predicted_answer_list"])
+ references.append(paragraph["gold_answer_list"])
+ else:
+ predictions.append(" ".join(paragraph["predicted_answer_list"]))
+ references.append(" ".join(paragraph["gold_answer_list"]))
+ return predictions, references
+
+
+def evaluate_from_file(path: str, task: str, output: str = None, dont_combine_answer_list: bool = False):
+ # prepare predictions & references
+ data = load_json(path)["data"]
+ if task == "qa":
+ predictions, references = _get_qa_predictions_references(data)
+ elif task == "qg":
+ predictions, references = _get_qg_predictions_references(data)
+ elif task == "ans_ext":
+ predictions, references = _get_ans_ext_predictions_references(data, dont_combine_answer_list)
+ else:
+ raise ValueError("Unknown task. Must be one of [qa, qg, ans_ext]")
+ # prepare metrics
+ metrics = TASK_TO_METRIC.get(task)
+
+ try:
+ # calculate scores
+ scorer = Jury(metrics=metrics, run_concurrent=False)
+ scores = scorer(predictions=predictions, references=references)
+ # log results
+ logger.info(scores)
+ # export result
+ if output is None:
+ export_path = str(Path(path).parent / (task + "_eval_scores.json"))
+ else:
+ export_path = output
+ save_json(
+ scores,
+ export_path,
+ )
+ return scores
+ except Exception as e:
+ logger.warning(e)
+ return None
+
+
+def evaluate_on_train_end(model_args, training_args):
+ logger.info("*** Evaluate on train end ***")
+
+ if model_args.model_type == "mt5":
+ eval_tasks = training_args.mt5_task_list
+ elif model_args.model_type == "bert":
+ eval_tasks = ["qa_for_bert"]
+
+ overall_results = {}
+ for eval_dataset in training_args.eval_dataset_list:
+ for eval_task in eval_tasks:
+ dataset_name = Path(eval_dataset).name
+ data = load_dataset(eval_dataset)
+ output_generation_file = os.path.join(
+ training_args.output_dir, dataset_name + "_" + eval_task + "_generation.json"
+ )
+ logger.info(f"Evaluating on {dataset_name} for {eval_task} task.")
+
+ if eval_task == "ans_ext":
+ generate(
+ path_or_dict=data,
+ output=output_generation_file,
+ model_url_or_path=training_args.output_dir,
+ use_cuda=not training_args.no_cuda,
+ task=eval_task,
+ max_source_length=training_args.max_source_length,
+ max_target_length=training_args.max_target_length,
+ seed=training_args.seed,
+ )
+ output_eval_file = os.path.join(
+ training_args.output_dir, dataset_name + "_" + eval_task + "_eval_result.json"
+ )
+ results = evaluate_from_file(
+ path=output_generation_file, task=eval_task, output=output_eval_file, dont_combine_answer_list=False
+ )
+ if results is not None:
+ for key, value in results.items():
+ overall_results["eval_" + dataset_name + "_" + eval_task + "_" + key] = value
+ elif eval_task == "qa":
+ generate(
+ path_or_dict=data,
+ output=output_generation_file,
+ model_url_or_path=training_args.output_dir,
+ use_cuda=not training_args.no_cuda,
+ task=eval_task,
+ max_source_length=training_args.max_source_length,
+ max_target_length=training_args.max_target_length,
+ seed=training_args.seed,
+ )
+ output_eval_file = os.path.join(
+ training_args.output_dir, dataset_name + "_" + eval_task + "_eval_result.json"
+ )
+ results = evaluate_from_file(
+ path=output_generation_file,
+ task=eval_task,
+ output=output_eval_file,
+ dont_combine_answer_list=False,
+ )
+ if results is not None:
+ for key, value in results.items():
+ overall_results["eval_" + dataset_name + "_" + eval_task + "_" + key] = value
+ elif eval_task == "qg":
+ generate(
+ path_or_dict=data,
+ output=output_generation_file,
+ model_url_or_path=training_args.output_dir,
+ use_cuda=not training_args.no_cuda,
+ task=eval_task,
+ use_answers=True,
+ max_source_length=training_args.max_source_length,
+ max_target_length=training_args.max_target_length,
+ qg_format=training_args.mt5_qg_format,
+ seed=training_args.seed,
+ )
+ output_eval_file = os.path.join(
+ training_args.output_dir, dataset_name + "_" + eval_task + "_eval_result.json"
+ )
+ results = evaluate_from_file(
+ path=output_generation_file,
+ task=eval_task,
+ output=output_eval_file,
+ dont_combine_answer_list=False,
+ )
+ if results is not None:
+ for key, value in results.items():
+ overall_results["eval_" + dataset_name + "_" + eval_task + "_" + key] = value
+ elif eval_task == "qa_for_bert":
+ generate(
+ path_or_dict=data,
+ output=output_generation_file,
+ model_url_or_path=training_args.output_dir,
+ use_cuda=not training_args.no_cuda,
+ task=eval_task,
+ max_source_length=training_args.max_source_length,
+ max_target_length=training_args.max_target_length,
+ seed=training_args.seed,
+ )
+ output_eval_file = os.path.join(
+ training_args.output_dir, dataset_name + "_" + eval_task + "_eval_result.json"
+ )
+ results = evaluate_from_file(
+ path=output_generation_file,
+ task="qa",
+ output=output_eval_file,
+ dont_combine_answer_list=False,
+ )
+ if results is not None:
+ for key, value in results.items():
+ overall_results["eval_" + dataset_name + "_" + "qa" + "_" + key] = value
+ return overall_results
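A minimal usage sketch of evaluate_from_file (illustrative; the paths are placeholders for a generation file produced by core.generate and an existing output directory):

    from core.evaluate import evaluate_from_file

    # score a question-generation output file and write the metrics next to it
    scores = evaluate_from_file(
        path="runs/exp1/tquad2-valid_qg_generation.json",   # placeholder path
        task="qg",
        output="runs/exp1/tquad2-valid_qg_eval_result.json",
    )
    if scores is not None:
        print(scores)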
diff --git a/core/generate.py b/core/generate.py
new file mode 100644
index 0000000..fab77a3
--- /dev/null
+++ b/core/generate.py
@@ -0,0 +1,163 @@
+import json
+from typing import Optional, Union
+
+from transformers import set_seed
+
+from core.api import TurQue
+from core.bert_api import qa_from_file
+from utils.file import save_json
+
+
+def read_config(path: str):
+ with open(path, "r", encoding="utf-8") as jf:
+ config = json.load(jf)
+ return config
+
+
+def generate_qg(
+ path_or_dict: str = None,
+ output: str = None,
+ model_url_or_path: str = None,
+ use_cuda: bool = True,
+ use_answers: Optional[bool] = None,
+ max_source_length: int = 512,
+ max_target_length: int = 80,
+ qg_format: str = "highlight",
+):
+ use_answers = False if use_answers is None else use_answers
+ turque = TurQue(
+ model_url_or_path=model_url_or_path,
+ use_cuda=use_cuda,
+ max_source_length=max_source_length,
+ max_target_length=max_target_length,
+ qg_format=qg_format,
+ )
+ result = turque.qg_from_file(path_or_dict=path_or_dict, use_answers=use_answers)
+ save_json(result, output)
+
+
+def generate_qa(
+ path_or_dict: str = None,
+ output: str = None,
+ model_url_or_path: str = None,
+ use_cuda: bool = True,
+ max_source_length: int = 512,
+ max_target_length: int = 80,
+):
+ turque = TurQue(
+ model_url_or_path=model_url_or_path,
+ use_cuda=use_cuda,
+ max_source_length=max_source_length,
+ max_target_length=max_target_length,
+ )
+ result = turque.qa_from_file(path_or_dict=path_or_dict)
+ save_json(result, output)
+
+
+def generate_ans_ext(
+ path_or_dict: str = None,
+ output: str = None,
+ model_url_or_path: str = None,
+ use_cuda: bool = True,
+ max_source_length: int = 512,
+ max_target_length: int = 80,
+):
+ turque = TurQue(
+ model_url_or_path=model_url_or_path,
+ use_cuda=use_cuda,
+ max_source_length=max_source_length,
+ max_target_length=max_target_length,
+ )
+ result = turque.ans_ext_from_file(path_or_dict=path_or_dict)
+ save_json(result, output)
+
+
+def generate_qa_for_bert(
+ path_or_dict: str,
+ output: str,
+ model_url_or_path: str = None,
+ use_cuda: bool = True,
+ max_source_length: int = 512,
+ max_target_length: int = 80,
+):
+ result = qa_from_file(
+ path_or_dict=path_or_dict,
+ model_url_or_path=model_url_or_path,
+ use_cuda=use_cuda,
+ max_source_lenght=max_source_length,
+ max_target_length=max_target_length,
+ )
+ save_json(result, output)
+
+
+def generate(
+ path_or_dict: str = None,
+ output: str = None,
+ model_url_or_path: str = None,
+ use_cuda: bool = True,
+ use_answers: Union[str, bool] = False,
+ task: str = "qa",
+ max_source_length: int = 512,
+ max_target_length: int = 80,
+ config: str = None,
+ qg_format: str = "highlight",
+ seed: Optional[int] = None,
+):
+ """
+ path_or_dict (str): path or dict for a squad formatted dataset
+ output (str): output path for generation json
+ use_cuda (bool): perform generation on cuda
+ use_answers (bool): use gold answer for qg
+ task (str): one of 'qa', 'qg', 'ans_ext', 'qa_for_bert'
+ config (str): path to a json file
+ qg_format (str): 'highlight', 'prepend' or 'both'
+ seed (int): seed for randomized operations
+ """
+ args = read_config(config) if config is not None else {}
+ path_or_dict = args.get("path_or_dict") if path_or_dict is None else path_or_dict
+ output = args.get("output") if output is None else output
+ model_url_or_path = args.get("model_url_or_path") if model_url_or_path is None else model_url_or_path
+ task = args.get("task") if task is None else task
+ use_answers = False if use_answers is None else use_answers
+ if seed is not None:
+ set_seed(seed)
+ if task == "qa":
+ generate_qa(
+ path_or_dict=path_or_dict,
+ output=output,
+ model_url_or_path=model_url_or_path,
+ use_cuda=use_cuda,
+ max_source_length=max_source_length,
+ max_target_length=max_target_length,
+ )
+ elif task == "qg":
+ generate_qg(
+ path_or_dict=path_or_dict,
+ output=output,
+ model_url_or_path=model_url_or_path,
+ use_cuda=use_cuda,
+ use_answers=use_answers,
+ max_source_length=max_source_length,
+ max_target_length=max_target_length,
+ qg_format=qg_format,
+ )
+ elif task == "ans_ext":
+ generate_ans_ext(
+ path_or_dict=path_or_dict,
+ output=output,
+ model_url_or_path=model_url_or_path,
+ use_cuda=use_cuda,
+ max_source_length=max_source_length,
+ max_target_length=max_target_length,
+ )
+ elif task == "qa_for_bert":
+ generate_qa_for_bert(
+ path_or_dict=path_or_dict,
+ output=output,
+ model_url_or_path=model_url_or_path,
+ use_cuda=use_cuda,
+ max_source_length=max_source_length,
+ max_target_length=max_target_length,
+ )
+ else:
+ raise ValueError(f"'task' should be one of ['qa', 'qg', 'ans_ext', 'qa_for_bert'] but given as {task}")
diff --git a/core/pipelines.py b/core/pipelines.py
new file mode 100644
index 0000000..62a1461
--- /dev/null
+++ b/core/pipelines.py
@@ -0,0 +1,242 @@
+import itertools
+import logging
+from typing import Dict, List, Optional, Union
+
+import torch
+from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, PreTrainedModel, PreTrainedTokenizer
+
+from utils.nlp import sentence_tokenize
+
+
+class QGPipeline:
+ def __init__(
+ self,
+ model: PreTrainedModel,
+ tokenizer: PreTrainedTokenizer,
+ ans_model: PreTrainedModel,
+ ans_tokenizer: PreTrainedTokenizer,
+ qg_format: str,
+ use_cuda: bool,
+ generate_max_length: int = 80,
+ generate_num_beams: int = 4,
+ ):
+ self.model = model
+ self.tokenizer = tokenizer
+
+ self.ans_model = ans_model
+ self.ans_tokenizer = ans_tokenizer
+
+ self.qg_format = qg_format
+
+ self.device = "cuda" if torch.cuda.is_available() and use_cuda else "cpu"
+ self.model.to(self.device)
+
+ if self.ans_model is not self.model:
+ self.ans_model.to(self.device)
+
+ assert self.model.__class__.__name__ in ["MT5ForConditionalGeneration"]
+
+ self.model_type = "mt5"
+
+ self.generate_max_length = generate_max_length
+ self.generate_num_beams = generate_num_beams
+
+ def __call__(self, inputs: str):
+ inputs = " ".join(inputs.split())
+ sents, answers = self._extract_answers(inputs)
+ flat_answers = list(itertools.chain(*answers))
+
+ if len(flat_answers) == 0:
+ return []
+
+ qg_examples = self._prepare_inputs_for_qg_from_answers_hl(sents, answers)
+
+ qg_inputs = [example["source_text"] for example in qg_examples]
+ questions = self._generate_questions(qg_inputs)
+ output = [{"answer": example["answer"], "question": que} for example, que in zip(qg_examples, questions)]
+ return output
+
+ def _generate_questions(self, inputs):
+ inputs = self._tokenize(inputs, padding=True, truncation=True)
+
+ outs = self.model.generate(
+ input_ids=inputs["input_ids"].to(self.device),
+ attention_mask=inputs["attention_mask"].to(self.device),
+ max_length=self.generate_max_length,
+ num_beams=self.generate_num_beams,
+ )
+
+ questions = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in outs]
+ return questions
+
+ def _extract_answers(self, context):
+ sents, inputs = self._prepare_inputs_for_ans_extraction(context)
+
+ inputs = self._tokenize(inputs, padding=True, truncation=True)
+
+ outs = self.ans_model.generate(
+ input_ids=inputs["input_ids"].to(self.device),
+ attention_mask=inputs["attention_mask"].to(self.device),
+ max_length=self.generate_max_length,
+ )
+
+ dec = [self.ans_tokenizer.decode(ids, skip_special_tokens=True) for ids in outs]
+
+ answers = [item.split("") for item in dec]
+
+ answers = [i[:-1] for i in answers]
+
+ return sents, answers
+
+ def _tokenize(self, inputs, padding=True, truncation=True, add_special_tokens=True, max_length=512):
+ inputs = self.tokenizer.batch_encode_plus(
+ inputs,
+ max_length=max_length,
+ add_special_tokens=add_special_tokens,
+ truncation=truncation,
+ padding="max_length" if padding else False,
+ pad_to_max_length=padding,
+ return_tensors="pt",
+ )
+ return inputs
+
+ def _prepare_inputs_for_ans_extraction(self, text):
+ sents = sentence_tokenize(text)
+
+ inputs = []
+ for i in range(len(sents)):
+ source_text = "extract answers:"
+ for j, sent in enumerate(sents):
+ if i == j:
+ sent = " %s " % sent
+ source_text = "%s %s" % (source_text, sent)
+ source_text = source_text.strip()
+
+ # if self.model_type == "mt5":
+ # source_text = source_text + " "
+ inputs.append(source_text)
+
+ return sents, inputs
+
+ def _prepare_inputs_for_qg_from_answers_hl(self, sents, answers):
+ inputs = []
+ for i, answer in enumerate(answers):
+ if len(answer) == 0:
+ continue
+ for answer_text in answer:
+ sent = sents[i]
+ sents_copy = sents[:]
+
+ answer_text = answer_text.strip()
+ # sometimes the extracted answer does not match the original text :/
+ try:
+ ans_start_idx = sent.index(answer_text)
+
+ sent = f"{sent[:ans_start_idx]} {answer_text} {sent[ans_start_idx + len(answer_text): ]}"
+ sents_copy[i] = sent
+
+ source_text = " ".join(sents_copy)
+ source_text = f"generate question: {source_text}"
+ # if self.model_type == "mt5":
+ # source_text = source_text + " "
+ except:
+ continue
+
+ inputs.append({"answer": answer_text, "source_text": source_text})
+
+ return inputs
+
+
+class MultiTaskQAQGPipeline(QGPipeline):
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+
+ def __call__(self, inputs: Union[Dict, str]):
+ if type(inputs) is str:
+ # do qg
+ return super().__call__(inputs)
+ else:
+ # do qa
+ return self._extract_answer(inputs["question"], inputs["context"])
+
+ def _prepare_inputs_for_qa(self, question, context):
+ source_text = f"question: {question} context: {context}"
+ # if self.model_type == "mt5":
+ # source_text = source_text + " "
+ return source_text
+
+ def _extract_answer(self, question, context):
+ source_text = self._prepare_inputs_for_qa(question, context)
+ inputs = self._tokenize([source_text], padding=False)
+
+ outs = self.model.generate(
+ input_ids=inputs["input_ids"].to(self.device),
+ attention_mask=inputs["attention_mask"].to(self.device),
+ max_length=self.generate_max_length,
+ )
+
+ answer = self.tokenizer.decode(outs[0], skip_special_tokens=True)
+
+ return answer
+
+
+SUPPORTED_TASKS = {
+ "multitask-qa-qg": {
+ "class": MultiTaskQAQGPipeline,
+ "default": {
+ "model": "obss/mt5-qa-qg",
+ },
+ },
+}
+
+
+def pipeline(
+ task: str,
+ model: Optional[PreTrainedModel] = None,
+ tokenizer: Optional[Union[str, PreTrainedTokenizer]] = None,
+ qg_format: Optional[str] = "highlight",
+ use_cuda: Optional[bool] = True,
+ **kwargs,
+):
+ # Retrieve the task
+ if task not in SUPPORTED_TASKS:
+ raise KeyError("Unknown task {}, available tasks are {}".format(task, list(SUPPORTED_TASKS.keys())))
+
+ targeted_task = SUPPORTED_TASKS[task]
+ task_class = targeted_task["class"]
+
+ # Use default model/config/tokenizer for the task if no model is provided
+ if model is None:
+ model = targeted_task["default"]["model"]
+
+ # Try to infer tokenizer from model or config name (if provided as str)
+ if tokenizer is None:
+ if isinstance(model, str):
+ tokenizer = model
+ else:
+ # Impossible to guess which tokenizer to use here
+ raise Exception(
+ "Impossible to guess which tokenizer to use. "
+ "Please provide a PreTrainedTokenizer class or a path/identifier to a pretrained tokenizer."
+ )
+
+ # Instantiate tokenizer if needed
+ if isinstance(tokenizer, (str, tuple)):
+ if isinstance(tokenizer, tuple):
+ # For tuple we have (tokenizer name, {kwargs})
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer[0], **tokenizer[1])
+ else:
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer)
+
+ # Instantiate model if needed
+ if isinstance(model, str):
+ model = AutoModelForSeq2SeqLM.from_pretrained(model)
+
+ return task_class(
+ model=model,
+ tokenizer=tokenizer,
+ ans_model=model,
+ ans_tokenizer=tokenizer,
+ qg_format=qg_format,
+ use_cuda=use_cuda,
+ )
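A minimal usage sketch of the pipeline factory above (illustrative; downloading the default "obss/mt5-qa-qg" checkpoint requires network access):

    from core.pipelines import pipeline

    nlp = pipeline("multitask-qa-qg", qg_format="highlight", use_cuda=False)

    # question generation from a raw context
    print(nlp("Türkiye'nin başkenti Ankara'dır."))

    # question answering from a question/context pair
    print(nlp({"question": "Türkiye'nin başkenti neresidir?",
               "context": "Türkiye'nin başkenti Ankara'dır."}))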
diff --git a/environment.yml b/environment.yml
new file mode 100644
index 0000000..ab3e6c0
--- /dev/null
+++ b/environment.yml
@@ -0,0 +1,26 @@
+name: turque
+#prefix: /your/custom/path/envs/turque
+channels:
+ - anaconda
+dependencies:
+ - anaconda::cudatoolkit=11
+ - pip
+ - python=3.8
+ - pytorch::pytorch
+ - pip:
+ - bert-score==0.3.10
+ - black==21.7b0
+ - datasets>=1.12.0,<2.0.0
+ - flake8==3.9.2
+ - gdown
+ - isort==5.9.2
+ - jupyterlab==3.0.14
+ - jury>=2.1.0,<3.0.0
+ - protobuf>=3.17.3
+ - pyyaml
+ - pysocks==1.5.6
+ - rouge-score==0.0.4
+ - sacrebleu==1.5.1
+ - sentencepiece==0.1.96
+ - transformers>=4.10.0,<5.0.0
+ - trtokenizer==0.0.3
diff --git a/hf/__init__.py b/hf/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/hf/model.py b/hf/model.py
new file mode 100644
index 0000000..0c8f39d
--- /dev/null
+++ b/hf/model.py
@@ -0,0 +1,106 @@
+import logging
+import os
+from typing import Optional, Tuple
+
+import torch
+from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, BertForQuestionAnswering, BertTokenizerFast
+from transformers.hf_argparser import DataClass
+
+from utils.file import download_from_gdrive_and_unzip
+from utils.torch import assert_not_all_frozen, freeze_embeds
+
+TURQUE_S1_GDRIVE_URL = "https://drive.google.com/uc?id=10hHFuavHCofDczGSzsH1xPHgTgAocOl1"
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=os.environ.get("LOGLEVEL", "INFO").upper(),
+)
+
+
+class MT5Model:
+ def __init__(
+ self,
+ model_name_or_path: str = "turque-s1",
+ tokenizer_name_or_path: str = None,
+ freeze_embeddings: bool = False,
+ cache_dir: Optional[str] = None,
+ use_cuda: bool = None,
+ ):
+ # try downloading pretrained files
+ if model_name_or_path == "turque-s1":
+ model_name_or_path = "data/pretrained/turque-s1/"
+ if not os.path.isfile("data/pretrained/turque-s1/pytorch_model.bin"):
+ download_from_gdrive_and_unzip(TURQUE_S1_GDRIVE_URL, model_name_or_path)
+ logger.info(f"pretrained model is downloaded to {model_name_or_path}")
+ else:
+ logger.info(f"using pretrained model at {model_name_or_path}")
+ model_name_or_path = "data/pretrained/turque-s1/"
+
+ model = AutoModelForSeq2SeqLM.from_pretrained(
+ model_name_or_path,
+ cache_dir=cache_dir,
+ )
+ if tokenizer_name_or_path is not None:
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, cache_dir=cache_dir)
+ else:
+ try:
+ tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, cache_dir=cache_dir)
+ except:
+ tokenizer = AutoTokenizer.from_pretrained(
+ os.path.join(model_name_or_path, "tokenizer_config.json"), cache_dir=cache_dir
+ )
+ assert model.__class__.__name__ in ["MT5ForConditionalGeneration"]
+ self.model = model
+ self.tokenizer = tokenizer
+ self.type = "mt5"
+
+ self.device = "cuda" if torch.cuda.is_available() and use_cuda else "cpu"
+ self.model.to(self.device)
+
+ if freeze_embeddings:
+ logger.info("freezing embeddings of the model")
+ freeze_embeds(self.model)
+ assert_not_all_frozen(self.model)
+
+ self.model.resize_token_embeddings(len(self.tokenizer))
+
+
+class BertModel:
+ def __init__(
+ self,
+ model_name_or_path: str = "dbmdz/bert-base-turkish-cased",
+ tokenizer_name_or_path: str = None,
+ freeze_embeddings: bool = False,
+ cache_dir: Optional[str] = None,
+ use_cuda: bool = None,
+ ):
+
+ model = BertForQuestionAnswering.from_pretrained(
+ model_name_or_path,
+ cache_dir=cache_dir,
+ )
+ if tokenizer_name_or_path is not None:
+ tokenizer = BertTokenizerFast.from_pretrained(tokenizer_name_or_path, cache_dir=cache_dir)
+ else:
+ try:
+ tokenizer = BertTokenizerFast.from_pretrained(model_name_or_path, cache_dir=cache_dir)
+ except:
+ tokenizer = BertTokenizerFast.from_pretrained(
+ os.path.join(model_name_or_path, "tokenizer_config.json"), cache_dir=cache_dir
+ )
+ assert model.__class__.__name__ in ["BertForQuestionAnswering"]
+ self.model = model
+ self.tokenizer: BertTokenizerFast = tokenizer
+ self.type = "bert"
+
+ self.device = "cuda" if torch.cuda.is_available() and use_cuda else "cpu"
+ self.model.to(self.device)
+
+ if freeze_embeddings:
+ logger.info("freezing embeddings of the model")
+ freeze_embeds(self.model)
+ assert_not_all_frozen(self.model)
+
+ self.model.resize_token_embeddings(len(self.tokenizer))
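A minimal usage sketch of the MT5Model wrapper (illustrative; "google/mt5-small" is used here only because it loads without the gdrive download, so the generated text is not meaningful until the model is fine-tuned):

    from hf.model import MT5Model

    mt5 = MT5Model(model_name_or_path="google/mt5-small", use_cuda=False)
    inputs = mt5.tokenizer(
        "question: Başkent neresidir? context: Ankara Türkiye'nin başkentidir.",
        return_tensors="pt",
    )
    output_ids = mt5.model.generate(**inputs, max_length=32)
    print(mt5.tokenizer.decode(output_ids[0], skip_special_tokens=True))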
diff --git a/prepare_data.py b/prepare_data.py
new file mode 100644
index 0000000..279eaa7
--- /dev/null
+++ b/prepare_data.py
@@ -0,0 +1,193 @@
+import logging
+import os
+from pathlib import Path
+from typing import List
+
+import pandas as pd
+import torch
+from datasets import Dataset
+
+from core.argument_parsers import parser
+from core.dataset_parsers import load_and_prepare_dataset
+from hf.model import BertModel, BertTokenizerFast, MT5Model
+
+logger = logging.getLogger(__name__)
+
+
+class MT5DataProcessor:
+ def __init__(self, tokenizer, max_source_length=512, max_target_length=80):
+ self.tokenizer = tokenizer
+ self.max_source_length = max_source_length
+ self.max_target_length = max_target_length
+
+ def process(self, dataset):
+ dataset = dataset.map(self._convert_to_features, batched=True)
+
+ return dataset
+
+ # tokenize the examples
+ def _convert_to_features(self, example_batch):
+ source_encoding = self.tokenizer.batch_encode_plus(
+ example_batch["source_text"],
+ max_length=self.max_source_length,
+ padding="max_length",
+ pad_to_max_length=True,
+ truncation=True,
+ )
+ target_encoding = self.tokenizer.batch_encode_plus(
+ example_batch["target_text"],
+ max_length=self.max_target_length,
+ padding="max_length",
+ pad_to_max_length=True,
+ truncation=True,
+ )
+
+ encodings = {
+ "source_ids": source_encoding["input_ids"],
+ "target_ids": target_encoding["input_ids"],
+ "attention_mask": source_encoding["attention_mask"],
+ }
+
+ return encodings
+
+
+class BertDataProcessor:
+ def __init__(self, tokenizer: BertTokenizerFast, max_source_length: int = 512):
+ self.tokenizer = tokenizer
+ self.max_source_length = max_source_length
+
+ def process(self, dataset):
+ dataset = dataset.map(self._convert_to_features, batched=True)
+
+ return dataset
+
+ def _add_token_positions(self, encodings, answers):
+ start_positions = []
+ end_positions = []
+ for i in range(len(answers)):
+ start_positions.append(encodings.char_to_token(i, answers[i]["answer_start"]))
+ end_positions.append(encodings.char_to_token(i, answers[i]["answer_end"] - 1))
+
+ # if start position is None, the answer passage has been truncated
+ if start_positions[-1] is None:
+ start_positions[-1] = self.max_source_length
+ if end_positions[-1] is None:
+ end_positions[-1] = self.max_source_length
+
+ encodings.update({"start_positions": start_positions, "end_positions": end_positions})
+
+ # tokenize the examples
+ def _convert_to_features(self, example_batch):
+ encodings = self.tokenizer(
+ example_batch["context"],
+ example_batch["question"],
+ max_length=self.max_source_length,
+ padding="max_length",
+ pad_to_max_length=True,
+ truncation=True,
+ )
+
+ self._add_token_positions(encodings, example_batch["answer"])
+
+ return encodings
+
+
+def _read_datasets(
+ names: List[str], target_format="mt5", mt5_task_list: List[str] = ["ans_ext", "qa", "qg"], mt5_qg_format="highlight"
+) -> Dataset:
+ """
+ Args:
+ names: list of dataset subset names or paths
+ target_format (str): output format ('mt5' or 'bert')
+ mt5_task_list: list of tasks for mt5 data to be prepared
+ mt5_qg_format: "highlight", "prepend" or "both"
+ """
+ data = []
+ for name in names:
+ data.extend(
+ load_and_prepare_dataset(
+ name, target_format=target_format, mt5_task_list=mt5_task_list, mt5_qg_format=mt5_qg_format
+ )
+ )
+ data = Dataset.from_pandas(pd.DataFrame(data))
+ return data
+
+
+def main(args_file_path: str = None):
+ model_args, data_args, train_args = parser(args_file_path)
+
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=os.environ.get("LOGLEVEL", "INFO").upper(),
+ )
+
+ # set datasets
+ train_dataset = _read_datasets(
+ names=train_args.train_dataset_list,
+ target_format=model_args.model_type,
+ mt5_task_list=train_args.mt5_task_list,
+ mt5_qg_format=train_args.mt5_qg_format,
+ )
+ valid_dataset = _read_datasets(
+ names=train_args.valid_dataset_list,
+ target_format=model_args.model_type,
+ mt5_task_list=train_args.mt5_task_list,
+ mt5_qg_format=train_args.mt5_qg_format,
+ )
+
+ # set tokenizer
+ if model_args.model_type == "mt5":
+ model = MT5Model(model_args.model_name_or_path)
+ tokenizer = model.tokenizer
+ tokenizer.add_tokens(["", ""])
+ elif model_args.model_type == "bert":
+ model = BertModel(model_args.model_name_or_path)
+ tokenizer = model.tokenizer
+
+ # set processor
+ if model_args.model_type == "mt5":
+ processor = MT5DataProcessor(
+ tokenizer, max_source_length=train_args.max_source_length, max_target_length=train_args.max_target_length
+ )
+ elif model_args.model_type == "bert":
+ processor = BertDataProcessor(tokenizer, max_source_length=train_args.max_source_length)
+
+ # process datasets
+ train_dataset = processor.process(train_dataset)
+ valid_dataset = processor.process(valid_dataset)
+
+ if model_args.model_type == "mt5":
+ columns = ["source_ids", "target_ids", "attention_mask"]
+ train_dataset.set_format(type="torch", columns=columns)
+ valid_dataset.set_format(type="torch", columns=columns)
+ elif model_args.model_type == "bert":
+ columns = ["start_positions", "end_positions", "input_ids", "attention_mask"]
+ train_dataset.set_format(type="torch")
+ valid_dataset.set_format(type="torch")
+
+ # create train/valid file dirs
+ train_file_path = Path(str(data_args.train_file_path).strip())
+ if not train_file_path.parent.exists():
+ train_file_path.parent.mkdir(parents=True, exist_ok=True)
+ valid_file_path = Path(str(data_args.valid_file_path).strip())
+ if not valid_file_path.parent.exists():
+ valid_file_path.parent.mkdir(parents=True, exist_ok=True)
+
+ # save train/valid files
+ torch.save(train_dataset, train_file_path)
+ logger.info(f"saved train dataset at {train_file_path}")
+
+ torch.save(valid_dataset, valid_file_path)
+ logger.info(f"saved validation dataset at {valid_file_path}")
+
+ # save tokenizer
+ tokenizer_path = model_args.tokenizer_path
+ if not os.path.exists(tokenizer_path):
+ os.mkdir(tokenizer_path)
+ tokenizer.save_pretrained(tokenizer_path)
+ logger.info(f"saved tokenizer at {tokenizer_path}")
+
+
+if __name__ == "__main__":
+ main()
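A minimal sketch of the mt5 processor in isolation (illustrative; the tokenizer name and sample texts are placeholders, and the extra tokens follow the assumption made in main() above):

    from datasets import Dataset
    from transformers import AutoTokenizer

    from prepare_data import MT5DataProcessor

    tokenizer = AutoTokenizer.from_pretrained("google/mt5-small")
    tokenizer.add_tokens(["<sep>", "<hl>"])  # assumed special tokens, as in main()

    dataset = Dataset.from_dict({
        "source_text": ["generate question: <hl> Ankara <hl> Türkiye'nin başkentidir."],
        "target_text": ["Türkiye'nin başkenti neresidir?"],
    })
    processor = MT5DataProcessor(tokenizer, max_source_length=64, max_target_length=32)
    encoded = processor.process(dataset)
    print(encoded.column_names)  # includes source_ids, target_ids, attention_mask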
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..3f398cf
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,17 @@
+[tool.black]
+line-length = 120
+exclude = '''
+(
+ /(
+ .git
+ | .vscode
+ | .venv
+ | runs
+ | data
+ )/
+)
+'''
+
+[tool.isort]
+line_length = 120
+profile = "black"
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..90a5dd5
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,13 @@
+bert-score==0.3.10
+datasets>=1.12.0,<2.0.0
+gdown
+jury>=2.1.0,<3.0.0
+trtokenizer==0.0.3
+protobuf>=3.17.3
+pysocks==1.5.6
+pyyaml
+rouge-score==0.0.4
+sacrebleu==1.5.1
+sentencepiece==0.1.96
+torch==1.10.0
+transformers>=4.10.0,<5.0.0
diff --git a/run.py b/run.py
new file mode 100644
index 0000000..29c8bfc
--- /dev/null
+++ b/run.py
@@ -0,0 +1,230 @@
+import json
+import logging
+import os
+from typing import Tuple
+
+import torch
+import transformers
+from transformers import Trainer as HFTrainer
+from transformers import set_seed
+from transformers.hf_argparser import DataClass
+from transformers.optimization import Adafactor, AdamW
+from transformers.trainer import Trainer
+
+from core.argument_parsers import parser
+from core.collator import T2TDataCollator
+from core.evaluate import evaluate_on_train_end
+from hf.model import BertModel, MT5Model
+from prepare_data import main as prepare_data
+from utils.file import save_experiment_config
+from utils.neptune import init_neptune, log_to_neptune
+from utils.wandb import init_wandb, log_to_wandb
+
+
+def setup_logger(args: DataClass) -> logging.Logger:
+ logger = logging.getLogger(__name__)
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=os.environ.get("LOGLEVEL", "INFO").upper() if args.local_rank in [-1, 0] else logging.WARN,
+ )
+ logger.warning(
+ "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
+ args.local_rank,
+ args.device,
+ args.n_gpu,
+ bool(args.local_rank != -1),
+ args.fp16,
+ )
+ logger.info("Training/evaluation parameters %s", args)
+ return logger
+
+
+def check_output(args: DataClass, logger: logging.Logger = None) -> None:
+ if (
+ os.path.exists(args.output_dir)
+ and os.listdir(args.output_dir)
+ and args.do_train
+ and not args.overwrite_output_dir
+ ):
+ raise ValueError(
+ f"Output directory ({args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
+ )
+
+
+def load_datasets(args: DataClass, train: bool, eval: bool, logger: logging.Logger) -> Tuple:
+ logger.info("loading dataset")
+ train_dataset = torch.load(args.train_file_path) if train else None
+ valid_dataset = torch.load(args.valid_file_path) if eval else None
+ logger.info("finished loading dataset")
+
+ return train_dataset, valid_dataset
+
+
+def main(args_file_path: str = None):
+
+ model_args, data_args, training_args = parser(args_file_path)
+
+ # check for output_dir with given arguments.
+ check_output(training_args)
+
+ logger = setup_logger(training_args)
+
+ # set seed
+ set_seed(training_args.seed)
+
+ # initialize experiment tracking
+ report_to = []
+
+ if training_args.do_train:
+ wandb_status, wandb = init_wandb(project=model_args.wandb_project, name=training_args.run_name)
+ else:
+ wandb_status, wandb = init_wandb(
+ project=model_args.wandb_project, name=training_args.run_name, id=model_args.wandb_id
+ )
+ neptune_status, neptune = init_neptune(
+ project=model_args.neptune_project, api_token=model_args.neptune_api_token, name=training_args.run_name
+ )
+
+ if wandb_status:
+ report_to.append("wandb")
+ if neptune_status:
+ report_to.append("neptune")
+
+ training_args.report_to = report_to
+
+ # disable wandb console logs
+ logging.getLogger("wandb.run_manager").setLevel(logging.WARNING)
+
+ # prepare data()
+ if data_args.prepare_data:
+ prepare_data(args_file_path)
+
+ # load model
+ if model_args.model_type == "mt5":
+ model = MT5Model(
+ model_name_or_path=model_args.model_name_or_path,
+ tokenizer_name_or_path=model_args.tokenizer_path,
+ freeze_embeddings=training_args.freeze_embeddings,
+ cache_dir=model_args.cache_dir,
+ use_cuda=True,
+ )
+ elif model_args.model_type == "bert":
+ model = BertModel(
+ model_name_or_path=model_args.model_name_or_path,
+ tokenizer_name_or_path=model_args.tokenizer_path,
+ freeze_embeddings=training_args.freeze_embeddings,
+ cache_dir=model_args.cache_dir,
+ use_cuda=True,
+ )
+ train_dataset, valid_dataset = load_datasets(
+ data_args, train=training_args.do_train, eval=training_args.do_eval, logger=logger
+ )
+
+ # set optimizer
+ if training_args.adafactor:
+ # as advised in https://huggingface.co/transformers/main_classes/optimizer_schedules.html#adafactor-pytorch
+ optimizer = Adafactor(
+ model.model.parameters(),
+ scale_parameter=False,
+ relative_step=False,
+ warmup_init=False,
+ weight_decay=training_args.weight_decay,
+ lr=training_args.learning_rate,
+ )
+ else:
+ optimizer = AdamW(
+ model.model.parameters(), weight_decay=training_args.weight_decay, lr=training_args.learning_rate
+ )
+
+ if model_args.model_type == "mt5":
+ # initialize data_collator
+ data_collator = T2TDataCollator(
+ tokenizer=model.tokenizer, mode="training", using_tpu=training_args.tpu_num_cores is not None
+ )
+
+ # fix https://discuss.huggingface.co/t/mt5-fine-tuning-keyerror-source-ids/5257/2
+ training_args.remove_unused_columns = False if model_args.model_type == "mt5" else True
+
+ # export experiment config
+ save_experiment_config(model_args, data_args, training_args)
+
+ # start training
+ if training_args.do_train:
+ # init model
+ if model_args.model_type == "mt5":
+ trainer: Trainer = HFTrainer(
+ model=model.model,
+ args=training_args,
+ train_dataset=train_dataset,
+ eval_dataset=valid_dataset,
+ data_collator=data_collator,
+ optimizers=(optimizer, None),
+ )
+ elif model_args.model_type == "bert":
+ trainer: Trainer = HFTrainer(
+ model=model.model,
+ args=training_args,
+ train_dataset=train_dataset,
+ eval_dataset=valid_dataset,
+ optimizers=(optimizer, None),
+ )
+
+ # perform training
+ trainer.train(
+ resume_from_checkpoint=model_args.model_name_or_path
+ if os.path.isdir(model_args.model_name_or_path)
+ else None
+ )
+ trainer.save_model()
+ # For convenience, we also re-save the tokenizer to the same directory,
+ # so that you can share your model easily on huggingface.co/models =)
+
+ model.tokenizer.save_pretrained(training_args.output_dir)
+
+ # start evaluation
+ if training_args.do_eval and training_args.local_rank in [-1, 0]:
+ # arrange neptune/wandb loggers
+ if training_args.do_train:
+ for callback in trainer.callback_handler.callbacks:
+ if isinstance(callback, transformers.integrations.WandbCallback):
+ wandb = callback._wandb
+ for callback in trainer.callback_handler.callbacks:
+ if isinstance(callback, transformers.integrations.NeptuneCallback):
+ neptune_run = callback._neptune_run
+ if not training_args.do_train:
+ if "neptune" in report_to:
+ neptune_run = neptune.init(
+ project=os.getenv("NEPTUNE_PROJECT"),
+ api_token=os.getenv("NEPTUNE_API_TOKEN"),
+ mode=os.getenv("NEPTUNE_CONNECTION_MODE", "async"),
+ name=os.getenv("NEPTUNE_RUN_NAME", None),
+ run=model_args.neptune_run,
+ )
+ elif "wandb" in report_to:
+ wandb.init(project=model_args.wandb_project, name=model_args.run_name, id=model_args.wandb_id)
+
+ # calculate evaluation results
+ overall_results = evaluate_on_train_end(model_args, training_args)
+
+ # log to neptune/wandb
+ if "neptune" in report_to:
+ log_to_neptune(neptune_run, overall_results)
+ if "wandb" in report_to:
+ log_to_wandb(wandb, overall_results)
+
+
+def _mp_fn(index):
+ # For xla_spawn (TPUs)
+ main()
+
+
+def run_multi(args_dict):
+ with open("args.json", "w") as f:
+ json.dump(args_dict, f)
+
+ main(args_file="args.json")
+
+
+if __name__ == "__main__":
+ main()
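For a quick end-to-end check, the whole pipeline is driven from a single config file; assuming the argument parser accepts the bundled yaml (as tests/test_config.yaml suggests), a CPU-only smoke run looks like:

    # from a shell: python run.py tests/test_config.yaml
    from run import main

    main("tests/test_config.yaml")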
diff --git a/setup.cfg b/setup.cfg
new file mode 100644
index 0000000..a3efd0f
--- /dev/null
+++ b/setup.cfg
@@ -0,0 +1,5 @@
+[flake8]
+max-line-length = 120
+exclude = .git
+ignore = I101,I201,F401,F403,S001,D100,D101,D102,D103,D104,D105,D106,D107,D200,D205,D400,W504,D202,E203,E722,W503,B006
+inline-quotes = "
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/test_config.yaml b/tests/test_config.yaml
new file mode 100644
index 0000000..545a432
--- /dev/null
+++ b/tests/test_config.yaml
@@ -0,0 +1,40 @@
+model_name_or_path: "google/mt5-small"
+tokenizer_path: "mt5_small_tokenizer"
+label_smoothing_factor: 0
+freeze_embeddings: false
+run_name: null
+wandb_project: null
+wandb_id: null
+neptune_project: null
+neptune_run: null
+neptune_api_token: null
+train_dataset_list: ["tquad.small"]
+valid_dataset_list: ["tquad.small"]
+eval_dataset_list: ["tquad.small"]
+train_file_path: "data/train_data_multitask_mt5.pt"
+valid_file_path: "data/valid_data_multitask_mt5.pt"
+max_source_length: 512
+max_target_length: 64
+prepare_data: true
+mt5_task_list: [
+ "qa",
+ "qg",
+ "ans_ext"
+]
+mt5_qg_format: "highlight"
+output_dir: "runs/exp1"
+do_train: true
+do_eval: true
+evaluation_strategy: "steps"
+eval_steps: 1
+eval_accumulation_steps: 1
+per_device_train_batch_size: 1
+per_device_eval_batch_size: 1
+gradient_accumulation_steps: 1
+learning_rate: 1.0e-4
+num_train_epochs: 1
+save_total_limit: 1
+no_cuda: true
+seed: 42
+max_steps: 2
+overwrite_output_dir: true
\ No newline at end of file
diff --git a/tr_non_suffixes b/tr_non_suffixes
new file mode 100644
index 0000000..a6bfd9e
--- /dev/null
+++ b/tr_non_suffixes
@@ -0,0 +1,246 @@
+# This is a non-breaking prefix list for the Turkish language.
+# The file is used for sentence tokenization (tr_tokenizer/tokenizer/SentenceTokenizer class).
+#
+#ASSUMPTION:
+#
+# Anything in this file, followed by a period (and an upper-case word), does NOT
+# indicate an end-of-sentence marker.
+# Special cases are included for prefixes that ONLY appear before 0-9 numbers.
+# Any single upper case letter followed by a period is not a sentence ender
+# Usually upper case letters are initials in a name.
+A
+B
+C
+D
+E
+F
+G
+H
+I
+J
+K
+L
+M
+N
+O
+P
+Q
+R
+S
+T
+U
+V
+W
+X
+Y
+Z
+# Usually upper case letters are initials in a name (Turkish alphabet)
+Ç
+Ğ
+İ
+Ö
+Ş
+Ü
+# Roman Numerals
+I
+II
+III
+IV
+V
+VI
+VII
+VIII
+IX
+X
+XI
+XII
+XIII
+XIV
+XV
+XVI
+XVII
+XVIII
+XIX
+XX
+
+# English -- but these work globally for all languages:
+Mr
+Mrs
+No
+pp
+St
+no
+Sr
+Jr
+Bros
+etc
+vs
+esp
+Fig
+fig
+Jan
+Feb
+Mar
+Apr
+Jun
+Jul
+Aug
+Sep
+Sept
+Oct
+Okt
+Nov
+Dec
+Ph.D
+PhD
+# in "et al."
+al
+cf
+Inc
+Ms
+Gen
+Sen
+Prof
+Dr
+Corp
+Co
+# http://en.wiktionary.org/wiki/Category:Turkish_abbreviations
+Av
+no
+Başk.yard
+Bşk.yrd
+Akad
+Alb
+Alm
+anat
+ant
+Apt
+Ar
+Ar. Gör
+ark
+As
+Asb
+Asist
+astr
+astrol
+Atğm
+atm
+Av
+bağ
+Bçvş
+B.E
+bitb
+biy
+bk
+Bl
+Bn
+Bnb
+bot
+Böl
+bs
+Bşk
+Bul
+Bulg
+C
+Cad
+coğ
+çev
+Çvş
+D
+dam
+db
+dbl
+Doç
+doğ
+Dr
+Dz. Kuv. K
+dzş
+e
+Ecz
+ed
+ekon
+Ens
+Erm
+F
+f
+Fak
+Far
+fel
+fil
+fiz
+Fr
+Gen
+geom
+gn
+Gnkur
+Gön
+H.O
+Hv. Kuv
+Hv. Kuv. K
+Hz
+İbr
+is
+İsp
+İt
+Mah
+Sok
+LTD
+ŞTİ
+TİC
+PAZ
+SAN
+CAD
+SK
+ORG
+SAN
+BEL
+İST
+JAN
+EĞ
+ÜNİV
+TUG
+SNR
+M.E.B
+KDZ
+A.Ş
+MD
+GN
+ŞB
+BŞ
+K.K
+Yzb
+
+# Number indicators
+# add #NUMERIC_ONLY# after the word if it should ONLY be non-breaking when a 0-9 digit follows it
+hayır
+
+# Ordinals are (apparently) done with . in Turkish - "1." = "1st" in English
+# Ordinals 1 to 10000 added to self.__non_breaking_suffixes, thus they were deleted from here
+# Numbers that may be used in date were added
+01
+02
+03
+04
+05
+06
+07
+08
+09
+
+# (d. 998 - ö. 1068)
+d
+ö
+(d
+- ö
+# Ömer b. Abdülazīz
+b
+# Ord. Prof. Dr.
+Ord
+# 14-17. Yüzyıllarda ??
+#(Çerkez, Abaza, vs.) Musluaman ??
+# Prof.Dr. Ernst
+Prof.Dr
+# Ss. Cyril and Methodius University
+Ss
+# Bulgar Bilimler Akademisi Başkanı Acad. Stefan Vodenicharov'dan aldı.
+Acad
\ No newline at end of file
diff --git a/utils/__init__.py b/utils/__init__.py
new file mode 100644
index 0000000..b2de5a9
--- /dev/null
+++ b/utils/__init__.py
@@ -0,0 +1,5 @@
+import os.path
+from pathlib import Path
+
+PROJECT_ROOT = Path(os.path.dirname(__file__)).parent.resolve()
+SOURCE_ROOT = PROJECT_ROOT
diff --git a/utils/file.py b/utils/file.py
new file mode 100644
index 0000000..f75e008
--- /dev/null
+++ b/utils/file.py
@@ -0,0 +1,141 @@
+import json
+import os
+import urllib.request
+import zipfile
+from dataclasses import asdict
+from pathlib import Path
+from typing import Dict
+
+import yaml
+
+
+def safe_download(target_path: str, source_url: str, source_url2=None, min_bytes=1e0, error_msg="") -> None:
+ """Attempts to download file from source_url or source_url2, checks and removes incomplete downloads < min_bytes"""
+ file = Path(target_path)
+ assert_msg = f"Downloaded file '{file}' does not exist or size is < min_bytes={min_bytes}"
+ try: # url1
+ print(f"Downloading {source_url} to {file}...")
+ urllib.request.urlretrieve(
+ source_url,
+ target_path,
+ )
+ assert file.exists() and file.stat().st_size > min_bytes, assert_msg # check
+ except Exception as e: # url2
+ file.unlink(missing_ok=True) # remove partial downloads
+ print(f"ERROR: {e}\nRe-attempting {source_url2 or source_url} to {file}...")
+ urllib.request.urlretrieve(
+ source_url2 or source_url,
+ target_path,
+ )
+ finally:
+ if not file.exists() or file.stat().st_size < min_bytes: # check
+ file.unlink(missing_ok=True) # remove partial downloads
+ print(f"ERROR: {assert_msg}\n{error_msg}")
+ print("")
+
+
+def attempt_download(source_url: str, target_path: str) -> None:
+ target_path = Path(str(target_path).strip().replace("'", ""))
+ if not target_path.exists():
+ target_path.parent.mkdir(parents=True, exist_ok=True)
+ safe_download(target_path=str(target_path), source_url=source_url)
+
+
+def load_json(load_path, object_hook=None):
+ """
+ Loads json formatted data (given as "data") from load_path
+ Example inputs:
+ load_path: "dirname/squad.json"
+ """
+ # read from path
+ with open(load_path, encoding="utf-8") as json_file:
+ data = json.load(json_file, object_hook=object_hook)
+ return data
+
+
+def save_json(obj: Dict, path: str, encoding: str = "utf-8", indent: int = 4):
+ """
+ Save dict as json file.
+ """
+ with open(path, "w", encoding=encoding) as jf:
+ json.dump(obj, jf, indent=indent, default=str, ensure_ascii=False)
+
+
+def read_yaml(yaml_path):
+ """
+ Reads yaml file as dict.
+ """
+ with open(yaml_path) as f:
+ yaml_data = yaml.load(f, Loader=yaml.FullLoader)
+
+ return yaml_data
+
+
+def save_yaml(dict_file, yaml_path):
+ """
+ Saves dict as yaml file.
+ """
+
+ Path(yaml_path).parent.mkdir(parents=True, exist_ok=True)
+
+ with open(yaml_path, "w") as file:
+ yaml.dump(dict_file, file)
+
+
+def unzip(file_path: str, dest_dir: str):
+ """
+ Unzips compressed .zip file.
+ Example inputs:
+ file_path: 'data/01_alb_id.zip'
+ dest_dir: 'data/'
+ """
+
+ # unzip file
+ with zipfile.ZipFile(file_path) as zf:
+ zf.extractall(dest_dir)
+
+
+def download_from_url(from_url: str, to_path: str):
+
+ Path(to_path).parent.mkdir(parents=True, exist_ok=True)
+
+ if not os.path.exists(to_path):
+ urllib.request.urlretrieve(
+ from_url,
+ to_path,
+ )
+
+
+def save_experiment_config(model_args, data_args, training_args):
+ experiment_config = {}
+ experiment_config.update(asdict(model_args))
+ experiment_config.update(asdict(data_args))
+ experiment_config.update(asdict(training_args))
+ yaml_path = Path(training_args.output_dir) / "experiment_config.yaml"
+ save_yaml(experiment_config, yaml_path)
+
+
+def download_from_gdrive(url: str, save_dir: str) -> str:
+ """
+ Downloads file from gdrive, shows progress.
+ Example inputs:
+ url: 'https://drive.google.com/uc?id=10hHFuavHCofDczGSzsH1xPHgTgAocOl1'
+ save_dir: 'data/'
+ """
+ import gdown
+
+ # create save_dir if not present
+ Path(save_dir).mkdir(parents=True, exist_ok=True)
+ # download file
+ filepath = gdown.download(url, save_dir, quiet=False)
+ return filepath
+
+
+def download_from_gdrive_and_unzip(url: str, save_dir: str) -> str:
+ save_dir = save_dir + os.path.sep
+ # download zip file
+ filepath = download_from_gdrive(url, save_dir)
+ # extract zip file
+ unzip(filepath, str(Path(filepath).parent))
+ # remove zip file
+ os.remove(filepath)
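A small usage sketch of the json/yaml helpers above (illustrative; it round-trips the bundled test config through json in the working directory):

    from utils.file import load_json, read_yaml, save_json

    config = read_yaml("tests/test_config.yaml")
    save_json(config, "config_copy.json")
    restored = load_json("config_copy.json")
    assert restored["model_name_or_path"] == config["model_name_or_path"]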
diff --git a/utils/neptune.py b/utils/neptune.py
new file mode 100644
index 0000000..0f092ea
--- /dev/null
+++ b/utils/neptune.py
@@ -0,0 +1,27 @@
+import os
+
+
+def init_neptune(project: str = None, api_token: str = None, name: str = None):
+ status = False
+ neptune = None
+ if project is not None and api_token is not None:
+ try:
+ import neptune.new as neptune
+
+ os.environ["NEPTUNE_PROJECT"] = project
+ os.environ["NEPTUNE_API_TOKEN"] = api_token
+ if name is not None:
+ os.environ["NEPTUNE_RUN_NAME"] = name
+
+ status = True
+ except ImportError:
+ print("neptune not installed, skipping neptune logging. 'pip install neptune-client' for neptune logging.")
+ except Exception as e:
+ print(e)
+
+ return status, neptune
+
+
+def log_to_neptune(neptune_run, dict):
+ for k, v in dict.items():
+ neptune_run[k].log(v)
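A minimal sketch of the neptune helpers (illustrative; the project name and API token are placeholders, and nothing is logged unless both are provided):

    from utils.neptune import init_neptune, log_to_neptune

    status, neptune = init_neptune(project="my-workspace/turque", api_token="<token>", name="exp1")
    if status:
        run = neptune.init(project="my-workspace/turque", api_token="<token>")
        log_to_neptune(run, {"eval_score": 0.0})  # placeholder metric
        run.stop()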
diff --git a/utils/nlp.py b/utils/nlp.py
new file mode 100644
index 0000000..7c0079d
--- /dev/null
+++ b/utils/nlp.py
@@ -0,0 +1,481 @@
+import itertools
+import logging
+import os
+import re
+import string
+from typing import Dict, List, Union
+
+from datasets import Dataset
+from trtokenizer.tr_tokenizer import SentenceTokenizer
+
+from utils import SOURCE_ROOT
+from utils.file import attempt_download, load_json
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=os.environ.get("LOGLEVEL", "INFO").upper(),
+)
+
+
+def sentence_tokenize(context: str) -> List[str]:
+ non_prefix_file_path = str(SOURCE_ROOT / "tr_non_suffixes")
+ sentence_tokenizer = SentenceTokenizer(non_breaking_prefix_file=non_prefix_file_path)
+ context = context.replace("\xa0", " ").replace("\ufeff", " ").replace("\t", " ")
+
+ sentence_list = []
+ for trtok_sentence in sentence_tokenizer.tokenize(context):
+
+ pattern = re.escape(trtok_sentence)
+ pattern = " +".join(pattern.split())
+ pattern += " {0,1}" # handle space between sentences
+ pattern = r"%s" % pattern
+ pattern += "\r{0,1}" # handle \r between sentences
+ pattern = r"%s" % pattern
+ pattern += "\n{0,1}" # handle \n between sentences
+ pattern = r"%s" % pattern
+ match_str = re.search(pattern, context)
+ start_idx, end_idx = match_str.span()
+ sentence = context[start_idx:end_idx]
+ sentence_list.append(sentence)
+ return sentence_list
+
+
+def _get_correct_alignement(context, answer):
+ """Some original examples in SQuAD have indices wrong by 1 or 2 character. We test and fix this here."""
+ gold_text = answer["text"]
+ start_idx = answer["answer_start"]
+ end_idx = start_idx + len(gold_text)
+ if context[start_idx:end_idx] == gold_text:
+ return start_idx, end_idx # When the gold label position is good
+ elif context[start_idx - 1 : end_idx - 1] == gold_text:
+ return start_idx - 1, end_idx - 1 # When the gold label is off by one character
+ elif context[start_idx - 2 : end_idx - 2] == gold_text:
+ return start_idx - 2, end_idx - 2 # When the gold label is off by two characters
+ elif context[start_idx + 1 : end_idx + 1] == gold_text:
+ return start_idx + 1, end_idx + 1 # When the gold label is off by one character
+ elif context[start_idx + 2 : end_idx + 2] == gold_text:
+ return start_idx + 2, end_idx + 2 # When the gold label is off by two characters
+ else:
+ raise ValueError()
+
+
+def get_answer(qa: Dict):
+ """qa: each element of 'qas' field of squad formatted dataset paragraph"""
+ if qa.get("answer") is not None:
+ return qa["answer"]
+ try:
+ answer = qa["answers"][0]
+ except IndexError:
+ answer = None
+ return answer
+
+
+def get_answer_indices_from_sentence(answer_text: str, sentence: str, loose_match: bool = False):
+ """
+ first match is returned in match case
+
+ Args:
+ answer_text (str)
+ sentence (str)
+ loose_match (bool) : if True, regex also matches substrings
+
+ Returns:
+ None, if no match
+ {"text": answer_text, "answer_start": ans_start_idx, "answer_end": ans_end_idx}, if match
+ answer_start and answer_end are sentence-wise indexes, not context-wise
+ """
+ # sometimes the extracted answer does not match the original text :/
+ try:
+ pattern = r"(? ".join(answer_list) + " "
+ else:
+ target_text = None
+
+ for sentence_ind2, sentence in enumerate(sentence_list):
+ if sentence_ind == sentence_ind2:
+ sentence = f" {sentence} "
+ source_text = f"{source_text} {sentence}"
+ source_text = source_text.strip()
+
+ sample = {"source_text": source_text, "target_text": target_text, "answer_list": answer_list}
+ if sample["target_text"] is None:
+ sample
+ samples.append(sample)
+
+ return samples
+
+
+def prepare_qa_sample(context: str, question: str, answer: str = None):
+ """
+ Args:
+ context (str) (assumed to be normalized via normalize_text)
+ question (str)
+ answer (str)
+ """
+ prepare_target = True if answer else False
+
+ source_text = f"question: {question} context: {context}"
+ if prepare_target:
+ target_text = f"{answer}"
+ else:
+ target_text = None
+ return {"source_text": source_text, "target_text": target_text}
+
+
+def prepare_qg_samples(context, answer_list: List[Dict], question_list: List[str] = None, qg_format: str = "highlight"):
+ """
+ Args:
+ context (str)
+ question_list (List[str])
+ answer_list: [
+ {'text': str, 'answer_start': int},
+ {'text': str, 'answer_start': int},
+ ...
+ ]
+ qg_format: 'highlight', 'prepend' or 'both'
+ """
+ # split into sentences
+
+ try:
+ samples = []
+ for ind, answer in enumerate(answer_list):
+ start_pos, end_pos = _get_correct_alignement(context, answer)
+ answer_text = answer["text"]
+
+ if qg_format == "prepend":
+ source_text = f"answer: {answer_text} context: {context}"
+ elif qg_format == "highlight":
+ source_text = f"generate question: {context[:start_pos]} {answer_text} {context[end_pos:]}"
+ elif qg_format == "both":
+ source_text = (
+ f"answer: {answer_text} context: {context[:start_pos]} {answer_text} {context[end_pos:]}"
+ )
+ else:
+ raise ValueError(f"unsupported qg format: {qg_format}")
+
+ if question_list:
+ question = question_list[ind]
+ else:
+ question = None
+
+ samples.append({"answer": answer_text, "source_text": source_text, "target_text": question})
+
+ except ValueError:
+ sentence_list = sentence_tokenize(normalize_text(context))
+
+ if question_list:
+ answer_list_per_sentence, question_list_per_sentence = get_answer_list_per_sentence(
+ sentence_list, answer_list, question_list
+ )
+ else:
+ answer_list_per_sentence = get_answer_list_per_sentence(sentence_list, answer_list)
+
+ samples = []
+ for sentence_ind, answer_list in enumerate(answer_list_per_sentence):
+ if not answer_list:
+ continue
+ for answer_ind, answer in enumerate(answer_list):
+ sentence = sentence_list[sentence_ind]
+ sentence_list_copy = sentence_list[:]
+
+ answer_start = answer["answer_start"]
+ answer_text = answer["text"]
+ answer_end = answer["answer_end"]
+
+ sentence = f"{sentence[:answer_start]} {answer_text} {sentence[answer_end:]}"
+ sentence_list_copy[sentence_ind] = sentence
+ highlighted_context = " ".join(sentence_list_copy)
+
+ if qg_format == "prepend":
+ source_text = f"answer: {answer_text} context: {context}"
+ elif qg_format == "highlight":
+ source_text = f"generate question: {highlighted_context}"
+ elif qg_format == "both":
+ source_text = f"answer: {answer_text} context: {highlighted_context}"
+ else:
+ raise ValueError(f"unsupported qg format: {qg_format}")
+
+ if question_list:
+ question_list = question_list_per_sentence[sentence_ind]
+ question = question_list[answer_ind]
+ else:
+ question = None
+
+ samples.append({"answer": answer_text, "source_text": source_text, "target_text": question})
+
+ if not samples:
+ samples.append({"answer": None, "source_text": None, "target_text": None})
+
+ return samples
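+
+
+# illustrative example (the answer span is hypothetical; "<hl>" is the highlight token
+# used by the "highlight" format):
+#   samples = prepare_qg_samples(
+#       context="Ankara, Türkiye'nin başkentidir.",
+#       answer_list=[{"text": "Ankara", "answer_start": 0}],
+#       qg_format="highlight",
+#   )
+#   samples[0]["source_text"] -> "generate question:  <hl> Ankara <hl> , Türkiye'nin başkentidir."
+#   samples[0]["target_text"] -> None  (no question_list given)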
+
+
+def postprocess_answer_extraction_output(answer_extraction_output: str):
+ """
+ Args:
+ answer_extraction_output (str): decoded answer extraction output
+
+ Returns:
+ answer_text_list (List[str])
+ """
+ # parse answers
+    answers = answer_extraction_output.split("<sep>")[:-1]
+    # normalize and append unique answers
+    answer_text_list = []
+    for answer_text in answers:
+        answer_text = answer_text.strip()
+        # append if not already present
+        if answer_text and (answer_text not in answer_text_list):
+            answer_text_list.append(answer_text)
+ return answer_text_list
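+
+
+# illustrative example (assumes the model output uses the same "<sep>" separator as the
+# answer extraction targets):
+#   postprocess_answer_extraction_output("Ankara <sep> 1923 <sep> Ankara <sep>")
+#   -> ["Ankara", "1923"]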
+
+
+def download_and_load_dataset(source_url: str, target_path: str) -> Dataset:
+ attempt_download(source_url=source_url, target_path=target_path)
+ data = load_json(target_path)
+ return data
+
+
+def remove_punctuations(text: str) -> str:
+ regex = re.compile("[%s]" % re.escape(string.punctuation))
+ text = regex.sub(" ", text)
+ return " ".join(text.split())
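+
+
+# illustrative example:
+#   remove_punctuations("merhaba, dünya!")  ->  "merhaba dünya"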
+
+
+def replace_multiple_spaces_with_single_whitespace(text: str) -> str:
+    return re.sub(r"\s+", " ", text)
+
+
+def remove_citations(text: str) -> str:
+    """
+    Removes citation markers: a pair of square brackets containing either a single
+    letter (e.g. "[a]") or a substring with at least one digit (e.g. "[3]", "[not 5]").
+
+    Args:
+        text (str)
+
+    Returns:
+        str: text with citation markers removed
+    """
+    text = re.sub(r"\[[a-zA-Z]\]", "", text)
+    return re.sub(r"\[(\s|\w)*\d+(\s|\w)*\]", "", text)
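+
+
+# illustrative example:
+#   remove_citations("Ankara[1] başkenttir[a].")  ->  "Ankara başkenttir."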
+
+
+def handle_ugly_case(text: str) -> str:
+    # replace a "/" squeezed between numbers (e.g. "1990/1995") with "-"
+    pattern = r"(?<=\d.)(/)(?=\d.)"
+    return re.sub(pattern, "-", text)
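+
+
+# illustrative example:
+#   handle_ugly_case("1990/1995 sezonu")  ->  "1990-1995 sezonu"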
+
+
+def normalize_text(text: str) -> str:
+ text = text.strip()
+ text = replace_multiple_spaces_with_single_whitespace(text)
+ text = remove_citations(text)
+ text = handle_ugly_case(text)
+ return text
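+
+
+# illustrative example:
+#   normalize_text("  Ankara[1]   çok  eski  bir şehirdir.  ")
+#   -> "Ankara çok eski bir şehirdir."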
diff --git a/utils/torch.py b/utils/torch.py
new file mode 100644
index 0000000..dfddc65
--- /dev/null
+++ b/utils/torch.py
@@ -0,0 +1,56 @@
+from typing import Iterable, List
+
+from torch import nn
+
+# taken from https://github.com/patil-suraj/question_generation/blob/master/utils.py
+
+
+def grad_status(model: nn.Module) -> Iterable:
+ return (par.requires_grad for par in model.parameters())
+
+
+def freeze_params(model: nn.Module):
+ for par in model.parameters():
+ par.requires_grad = False
+
+
+def freeze_embeds(model: nn.Module):
+    """Freeze token and positional embeddings for BART-style models; only token embeddings for T5-style models."""
+ try:
+ freeze_params(model.model.shared)
+ for d in [model.model.encoder, model.model.decoder]:
+ freeze_params(d.embed_positions)
+ freeze_params(d.embed_tokens)
+ except AttributeError:
+ freeze_params(model.shared)
+ for d in [model.encoder, model.decoder]:
+ freeze_params(d.embed_tokens)
+
+
+def assert_not_all_frozen(model):
+ model_grads: List[bool] = list(grad_status(model))
+ npars = len(model_grads)
+ assert any(model_grads), f"none of {npars} weights require grad"
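+
+
+# illustrative usage on a HF seq2seq model (the checkpoint name is only an example):
+#   from transformers import MT5ForConditionalGeneration
+#   model = MT5ForConditionalGeneration.from_pretrained("google/mt5-small")
+#   freeze_embeds(model)           # T5-style models take the except branch above
+#   assert_not_all_frozen(model)   # passes: only the embedding weights are frozen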
+
+
+def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100):
+ """From fairseq"""
+ if target.dim() == lprobs.dim() - 1:
+ target = target.unsqueeze(-1)
+ nll_loss = -lprobs.gather(dim=-1, index=target)
+ smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
+ if ignore_index is not None:
+ pad_mask = target.eq(ignore_index)
+ nll_loss.masked_fill_(pad_mask, 0.0)
+ smooth_loss.masked_fill_(pad_mask, 0.0)
+ bs = pad_mask.long().sum()
+ else:
+ nll_loss = nll_loss.squeeze(-1)
+ smooth_loss = smooth_loss.squeeze(-1)
+ bs = lprobs.shape[0]
+
+ nll_loss = nll_loss.sum() # mean()? Scared to break other math.
+ smooth_loss = smooth_loss.sum()
+ eps_i = epsilon / lprobs.size(-1)
+ loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss
+ return loss / bs, nll_loss / bs
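+
+
+# illustrative usage (shapes, vocabulary size and smoothing value are arbitrary;
+# ignore_index is set to the pad id used in this toy example):
+#   import torch
+#   lprobs = torch.log_softmax(torch.randn(4, 7, 32000), dim=-1)
+#   labels = torch.randint(1, 32000, (4, 7))
+#   labels[:, -2:] = 0  # padded positions
+#   loss, nll = label_smoothed_nll_loss(lprobs, labels, epsilon=0.1, ignore_index=0)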
diff --git a/utils/wandb.py b/utils/wandb.py
new file mode 100644
index 0000000..9a211bd
--- /dev/null
+++ b/utils/wandb.py
@@ -0,0 +1,22 @@
+def init_wandb(project: str = None, name: str = None, id: str = None):
+ status = False
+ wandb = None
+ if project is not None and name is not None:
+ try:
+ import wandb
+
+ if not id:
+ wandb.init(project=project, name=name)
+ else:
+ wandb.init(project=project, name=name, id=id)
+ status = True
+ except ImportError:
+            print("wandb is not installed, skipping wandb logging. Run 'pip install wandb' to enable it.")
+ except Exception as e:
+ print(e)
+
+ return status, wandb
+
+
+def log_to_wandb(wandb, metrics: dict):
+    wandb.log({**metrics})
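+
+
+# illustrative usage (project/run names are hypothetical):
+#   status, wandb = init_wandb(project="turkish-question-generation", name="mt5-small-run")
+#   if status:
+#       log_to_wandb(wandb, {"val/loss": 0.42})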