Skip to content

Commit

Permalink
fix hidden state bug
Browse files Browse the repository at this point in the history
fix hidden state bug
  • Loading branch information
goru001 authored Dec 14, 2019
2 parents 261eea2 + ba0ab0d commit 7c55095
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 3 deletions.
7 changes: 5 additions & 2 deletions inltk/inltk.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def get_embedding_vectors(input: str, language_code: str):
path = Path(__file__).parent
learn = load_learner(path / 'models' / f'{language_code}')
encoder = get_model(learn.model)[0]
encoder.reset()
embeddings = encoder.state_dict()['encoder.weight']
embeddings = np.array(embeddings)
embedding_vectors = []
Expand All @@ -105,8 +106,9 @@ def get_sentence_encoding(input: str, language_code: str):
defaults.device = torch.device('cpu')
path = Path(__file__).parent
learn = load_learner(path / 'models' / f'{language_code}')
m = learn.model
kk0 = m[0](Tensor([token_ids]).to(torch.int64))
encoder = learn.model[0]
encoder.reset()
kk0 = encoder(Tensor([token_ids]).to(torch.int64))
return np.array(kk0[0][-1][0][-1])


Expand All @@ -128,6 +130,7 @@ def get_similar_sentences(sen: str, no_of_variations: int, language_code: str):
path = Path(__file__).parent
learn = load_learner(path / 'models' / f'{language_code}')
encoder = get_model(learn.model)[0]
encoder.reset()
embeddings = encoder.state_dict()['encoder.weight']
embeddings = np.array(embeddings)
# cos similarity of vectors
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setuptools.setup(
name="inltk",
version="0.7.1",
version="0.7.2",
author="Gaurav",
author_email="[email protected]",
description="Natural Language Toolkit for Indian Languages (iNLTK)",
Expand Down

0 comments on commit 7c55095

Please sign in to comment.