Skip to content

Commit

Permalink
Refactor openai api.
Browse files Browse the repository at this point in the history
  • Loading branch information
lxuechen committed Dec 10, 2022
1 parent 66bd0e6 commit 36a8c2b
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 4 deletions.
14 changes: 12 additions & 2 deletions ml_swissknife/openai_utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
"""Light wrapper around OpenAI API.
Should not rewrite these multiple times for different projects...
For reference:
https://beta.openai.com/docs/api-reference/completions/create
"""
import dataclasses
import logging
import math
import sys
import time
from typing import Union, Optional, Tuple
from typing import Optional, Tuple, Union

import openai
import tqdm
Expand All @@ -30,7 +33,8 @@ def _openai_completion(
model_name, prompts: Union[str, list, tuple], decoding_args, sleep_time=2, batch_size=1,
max_batches=sys.maxsize, # This should only be used during testing.
):
if isinstance(prompts, str):
is_single_prompt = isinstance(prompts, str)
if is_single_prompt:
prompts = [prompts]

num_prompts = len(prompts)
Expand Down Expand Up @@ -66,4 +70,10 @@ def _openai_completion(
logging.warning('Hit request rate limit; retrying...')
time.sleep(sleep_time) # Annoying rate limit on requests.

if is_single_prompt and decoding_args.n == 1:
completions, = completions # Return non-tuple if only 1 input and 1 generation.
return completions


# Keep the private function for backwards compat.
openai_completion = _openai_completion
5 changes: 3 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import setuptools

version = "0.1.8"
version = "0.1.9"

extras_require = {
"latex": ("bibtexparser",)
Expand All @@ -17,7 +17,8 @@
url="https://github.com/lxuechen/ml-swissknife",
install_requires=[
'torch', 'torchvision', 'spacy', 'tqdm', 'numpy', 'scipy', 'gputil', 'fire', 'requests', 'nltk', 'transformers',
'datasets', 'gdown>=4.4.0', 'pandas', 'pytest', 'matplotlib', 'seaborn', 'cvxpy', 'imageio', 'wandb', 'openai', 'numba'
'datasets', 'gdown>=4.4.0', 'pandas', 'pytest', 'matplotlib', 'seaborn', 'cvxpy', 'imageio', 'wandb', 'openai',
'numba'
],
extras_require=extras_require,
python_requires='~=3.7',
Expand Down

0 comments on commit 36a8c2b

Please sign in to comment.