Lstm benchmark example fix (explosion#625)
* bugfix

* type annotation

* formatting

* reformat import statements

Co-authored-by: Kádár Ákos <[email protected]>
Co-authored-by: svlandeg <[email protected]>
3 people authored Mar 31, 2022
1 parent 58f50aa commit e9a25ca
Showing 1 changed file with 13 additions and 13 deletions.

examples/benchmarks/lstm_tagger.py
@@ -16,10 +16,11 @@
 import tqdm
 import numpy.random
 from timeit import default_timer as timer
-from thinc.api import Model, Config, registry, chain, list2padded, with_array
+from thinc.api import Model, Config, registry
+from thinc.api import chain, list2padded, with_array, with_padded
 from thinc.api import to_categorical, set_current_ops
-from thinc.api import NumpyOps, CupyOps, fix_random_seed, require_gpu
-from thinc.types import Array2d, Padded
+from thinc.api import Ops, NumpyOps, CupyOps, fix_random_seed, require_gpu
+from thinc.types import Ints1d, Ints2d, Floats2d, Padded
 
 CONFIG = """
 [data]
@@ -56,15 +57,12 @@
 
 @registry.layers("LSTMTagger.v1")
 def build_tagger(
-    embed: Model[Array2d, Array2d],
+    embed: Model[Ints2d, Floats2d],
     encode: Model[Padded, Padded],
-    predict: Model[Array2d, Array2d],
-) -> Model[List[Array2d], Padded]:
+    predict: Model[Floats2d, Floats2d],
+) -> Model[List[Ints1d], Padded]:
     model = chain(
-        list2padded(),
-        with_array(embed),
-        encode,
-        with_array(predict),
+        with_array(embed), with_padded(encode), with_array(predict), list2padded()
    )
    model.set_ref("embed", embed)
    model.set_ref("encode", encode)
Expand Down Expand Up @@ -101,7 +99,7 @@ def run_forward_backward(model, batches, n_times=1):
for _ in range(n_times):
for X, Y in tqdm.tqdm(batches):
Yh, get_dX = model.begin_update(X)
dX = get_dX(Yh)
get_dX(Yh)
total += Yh.data.sum()
return float(total)
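
An aside on the one-line change above: begin_update runs the forward pass and returns the output together with a backprop callback, whose return value is the gradient with respect to the inputs. The benchmark only needs the backward pass to run, so the unused dX binding is dropped. A sketch, assuming a built model and a batch X as in the loop above:

    Yh, get_dX = model.begin_update(X)  # forward pass: output plus backprop callback
    get_dX(Yh)  # backward pass, timed; the returned input gradient goes unused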

@@ -117,12 +115,14 @@ def set_backend(name, gpu_id):
             set_current_ops(CupyOps())
     if name == "pytorch":
         import torch
+
         torch.set_num_threads(1)
         CONFIG = CONFIG.replace("LSTM.v1", "PyTorchLSTM.v1")
 
 
-def main(numpy: bool=False, pytorch: bool = False,
-        generic: bool=False, gpu_id: int = -1):
+def main(
+    numpy: bool = False, pytorch: bool = False, generic: bool = False, gpu_id: int = -1
+):
     global CONFIG
     fix_random_seed(0)
     if gpu_id >= 0:
