Skip to content

Commit

Permalink
add tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
shibing624 committed Feb 11, 2022
1 parent ffcfb53 commit da089fe
Show file tree
Hide file tree
Showing 10 changed files with 16,451 additions and 93 deletions.
38 changes: 31 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,10 @@ code-autocomplete, a code completion plugin for Python.
- [Reference](#reference)

# Feature


# Demo

http://42.193.145.218/product/short_text_sim/
- GPT2-based code completion
- Code completion for Python, other language is coming soon
- Line and block completion
- Train(Fine-tune GPT2) and predict model with your own data

# Install
```
Expand All @@ -46,7 +45,7 @@ python3 setup.py install
### Code Completion


开源项目:[code-autocomplete](https://github.com/shibing624/code-autocomplete),可支持GPT2模型,通过如下命令调用:
基于GPT2模型预测补全代码,通过如下命令调用:

```python
from autocomplete.gpt2 import Infer
Expand All @@ -55,6 +54,10 @@ i = m.predict('import torch.nn as')
print(i)
```

output:
```shell
import torch.nn as nn
```
当然,你也可使用官方的huggingface/transformers调用:

*Please use 'GPT2' related functions to load this model!*
Expand All @@ -81,7 +84,7 @@ prompts = [
"""import numpy as np
import torch
import torch.nn as""",
"import java.util.ArrayList",
"import java.util.ArrayList;",
"def factorial(n):",
]
for prompt in prompts:
Expand All @@ -101,6 +104,27 @@ for prompt in prompts:
print("=" * 20)
```

output:
```python
from torch import nn
class LSTM(Module):
def __init__(self, *,
n_tokens: int,
embedding_size: int,
hidden_size: int,
n_layers: int):
self.hidden_size = hidden_size
self.embedding_size = embedding_size

====================

import numpy as np
import torch
import torch.nn as nn

====================
...
```


# Contact
Expand Down
8 changes: 5 additions & 3 deletions autocomplete/gpt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
import torch
from simpletransformers.language_generation import LanguageGenerationModel
from simpletransformers.language_modeling import LanguageModelingModel
import transformers

transformers.logging.set_verbosity_error()
use_cuda = torch.cuda.is_available()


Expand Down Expand Up @@ -55,13 +57,13 @@ def __init__(self, model_name="gpt2", model_dir="outputs/fine-tuned", use_cuda=u
# cache_dir: None means use default cache dir: ~/.cache/huggingface/transformers/
self.model = LanguageGenerationModel(model_name, model_dir, args=args, use_cuda=use_cuda)

def predict(self, query):
def predict(self, prompt):
"""
Generate text using the model. Verbose set to False to prevent logging generated sequences.
:param query: str, input string
:param prompt: str, input string
:return: str
"""
generated = self.model.generate(query, verbose=False)
generated = self.model.generate(prompt, verbose=False)
generated = generated[0]
return generated

Expand Down
24 changes: 3 additions & 21 deletions examples/base_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,31 +3,13 @@
@author:XuMing([email protected])
@description:
"""
import argparse
import sys
import torch

sys.path.append('..')
from autocomplete.gpt2 import Infer

use_cuda = torch.cuda.is_available()
if __name__ == '__main__':
prompts = [
"""from torch import nn
class LSTM(Module):
def __init__(self, *,
n_tokens: int,
embedding_size: int,
hidden_size: int,
n_layers: int):""",
"""import numpy as np
import torch
import torch.nn as""",
"import java.util.ArrayList",
]
infer = Infer(model_name="gpt2", model_dir="shibing624/code-autocomplete-gpt2-base", use_cuda=use_cuda)
for prompt in prompts:
res = infer.predict(prompt)
print("Query:", prompt)
print("Result:", res)
print("=" * 20)
m = Infer(model_name="gpt2", model_dir="shibing624/code-autocomplete-gpt2-base", use_cuda=use_cuda)
i = m.predict('import torch.nn as')
print(i)
2 changes: 1 addition & 1 deletion examples/gpt2_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def __init__(self, *,
"""import numpy as np
import torch
import torch.nn as""",
"import java.util.ArrayList",
"import java.util.ArrayList;",
]
predict_with_original_gpt2(prompts)
if args.do_train:
Expand Down
Loading

0 comments on commit da089fe

Please sign in to comment.