Optimal Completion Distillation (OCD) Training

A PyTorch implementation of Optimal Completion Distillation (OCD) for sequence learning.
Source: https://arxiv.org/abs/1810.01398
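The core idea of the paper is to train on prefixes sampled from the model itself. At every step, the targets are the tokens that can still reach the minimum edit distance to the gold sequence. Following the paper, those optimal next tokens are target[j] for every j whose target prefix target[:j] has minimal edit distance to the sampled prefix, plus the end symbol when the whole target is among those prefixes. Below is a minimal pure-Python sketch of that rule. optimal_next_tokens is a hypothetical helper, not this package's API, and it assumes the gold sequence does not itself contain the end symbol.

```python
from typing import List, Set


def optimal_next_tokens(prefix: List[int], target: List[int], end_symbol_id: int) -> Set[int]:
    """Return the next tokens that keep the minimum achievable edit distance
    between the sampled `prefix` and the gold `target`.

    Hypothetical helper for illustration; not part of the ocd package.
    Assumes `target` does not include the end symbol.
    """
    m = len(target)
    # row[j] = edit distance between the prefix processed so far and target[:j].
    # For the empty prefix, reaching target[:j] takes j insertions.
    row = list(range(m + 1))
    for tok in prefix:
        # Empty target prefix vs. one more prefix token: one more deletion.
        new_row = [row[0] + 1]
        for j in range(1, m + 1):
            substitute = row[j - 1] + (0 if tok == target[j - 1] else 1)
            new_row.append(min(row[j] + 1, new_row[j - 1] + 1, substitute))
        row = new_row

    best = min(row)
    optimal = set()
    for j, dist in enumerate(row):
        if dist == best:
            # Continue with target[j]; if the whole target is already matched, end.
            optimal.add(target[j] if j < m else end_symbol_id)
    return optimal
```

For example, with gold sequence [1, 2, 3] and sampled prefix [1, 5], both 2 (treating 5 as an extra token) and 3 (treating 5 as a substitute for 2) are optimal next tokens.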

Requirements

Python 3, PyTorch 1.0.0

Install

python3 -m venv env
source env/bin/activate
pip3 install .

Usage

For details, see https://github.com/SaeedNajafi/pytorch-ocd/blob/master/ocd/__init__.py#L50 and the test at
https://github.com/SaeedNajafi/pytorch-ocd/blob/master/tests/test_ocd.py#L132. A minimal outline:

from ocd import OCD

ocd_trainer = OCD(vocab_size=10, end_symbol_id=9)  # 10 output tokens; id 9 marks the end of sequence
...  # the model produces model_scores: a score for every possible output token at each decoding step
ocd_loss = ocd_trainer(model_scores, gold_output_sequence)
...  # backpropagate through ocd_loss
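For intuition about what a loss like ocd_loss measures, here is a minimal pure-Python sketch of the per-step objective described in the paper. It assumes the low-temperature limit, where the target distribution is uniform over the optimal next tokens, so each step's loss is the KL divergence from that uniform target to the softmax of the model's scores. The function names are hypothetical, summing over steps is an assumption, and this is not necessarily how OCD.__call__ computes its loss.

```python
import math
from typing import List, Set


def log_softmax(scores: List[float]) -> List[float]:
    """Numerically stable log-softmax over one step's token scores."""
    peak = max(scores)
    log_norm = peak + math.log(sum(math.exp(s - peak) for s in scores))
    return [s - log_norm for s in scores]


def ocd_step_loss(scores: List[float], optimal: Set[int]) -> float:
    """KL(target || model) for one step, where the target is uniform over the
    optimal token ids (assumed low-temperature limit of the OCD target)."""
    log_probs = log_softmax(scores)
    weight = 1.0 / len(optimal)
    return sum(weight * (math.log(weight) - log_probs[a]) for a in optimal)


def ocd_sequence_loss(step_scores: List[List[float]], step_optimal: List[Set[int]]) -> float:
    """Sum of per-step losses over one sampled sequence (summing is an assumption)."""
    return sum(ocd_step_loss(s, o) for s, o in zip(step_scores, step_optimal))
```

The per-step optimal sets would come from an edit-distance rule like the one in the paper; spreading the target mass evenly over all of them rewards the model for any token that can still reach the best achievable edit distance.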