Implementation of Optimal Completion Distillation (OCD) for Sequence Labeling
Source: https://arxiv.org/abs/1810.01398
Requirements: python3, pytorch 1.0.0
python3 -m venv env
source env/bin/activate
pip3 install .
For more detail, look at the package source at
https://github.com/SaeedNajafi/pytorch-ocd/blob/master/ocd/__init__.py#L50
and the test at
https://github.com/SaeedNajafi/pytorch-ocd/blob/master/tests/test_ocd.py#L132
from ocd import OCD
ocd_trainer = OCD(vocab_size=10, end_symbol_id=9)
... # model defines scores for each step and each possible output token.
ocd_loss = ocd_trainer(model_scores, gold_output_sequence)
... # backprop with ocd_loss
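
To give a feel for what the loss is trained against, here is a minimal pure-Python sketch of the optimal-completion idea from the paper: for a generated prefix, the target tokens are those that follow the gold prefixes with the smallest edit distance to it. This is an illustration only and is not the code in this package; the function name `optimal_next_tokens` and its arguments are hypothetical, and the gold sequence is assumed to be passed without the end symbol.

```python
def optimal_next_tokens(prefix, gold, end_symbol_id):
    """Return the set of next tokens that keep the edit distance to `gold` minimal.

    Hypothetical helper for illustration; `gold` is assumed to exclude the end symbol.
    """
    m = len(gold)
    # row[j] = edit distance between the generated prefix so far and gold[:j]
    row = list(range(m + 1))
    for tok in prefix:
        new_row = [row[0] + 1]  # generated token has nothing to align with
        for j in range(1, m + 1):
            cost = 0 if gold[j - 1] == tok else 1
            new_row.append(min(row[j] + 1,          # extra generated token
                               new_row[j - 1] + 1,  # skipped gold token
                               row[j - 1] + cost))  # match or substitution
        row = new_row
    best = min(row)
    # The token following gold[:j] is gold[j]; after the full gold it is the end symbol.
    extended = list(gold) + [end_symbol_id]
    return {extended[j] for j in range(m + 1) if row[j] == best}


if __name__ == "__main__":
    gold = [1, 2, 3]
    print(sorted(optimal_next_tokens([], gold, end_symbol_id=9)))         # [1]
    print(sorted(optimal_next_tokens([2], gold, end_symbol_id=9)))        # [1, 2, 3]
    print(sorted(optimal_next_tokens([1, 2, 3], gold, end_symbol_id=9)))  # [9]
```

After the wrong first token 2, three continuations are equally good, so a target distribution built from this set would spread its mass over 1, 2 and 3 rather than forcing a single gold token.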