Skip to content

Commit

Permalink
Optimize imports
Browse files Browse the repository at this point in the history
  • Loading branch information
CompilerCrash committed Apr 20, 2023
1 parent d7dbf65 commit 2c4b4c6
Show file tree
Hide file tree
Showing 10 changed files with 61 additions and 67 deletions.
19 changes: 8 additions & 11 deletions datasets/apps_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,20 @@
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
#

import torch
import glob
import logging
import random
import fnmatch
import numpy as np
import gc
import json
import os
from tqdm import tqdm
from collections import Counter
import pdb
import pickle as pkl
import json, pdb
import random
from collections import Counter

from multiprocessing import Manager
import transformers
import numpy as np
import torch
from tqdm import tqdm

import datasets.utils as dsutils
import transformers


class APPSBaseDataset(torch.utils.data.Dataset):
Expand Down
9 changes: 5 additions & 4 deletions datasets/reindent.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
"""

from __future__ import print_function
import sys
import getopt

import codecs
import tempfile
import shutil
import getopt
import os
import shutil
import sys
import tempfile


def _find_indentation(line, config):
Expand Down
6 changes: 0 additions & 6 deletions datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,6 @@
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
#

import sys
import getopt
import codecs
import tempfile
import shutil
import os
import io

from datasets.reindent import run as run_reindent
Expand Down
13 changes: 7 additions & 6 deletions generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,19 @@
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
#

import glob
import json
import os
import pickle as pkl
import pprint
from collections import Counter

import numpy as np
import torch
import pdb
import glob
from tqdm import tqdm
import pickle as pkl
import numpy as np
from collections import Counter
from transformers import RobertaTokenizer, T5ForConditionalGeneration

import datasets.utils as dsutils
from transformers import RobertaTokenizer, T5ForConditionalGeneration


def generate_prompt(args, test_case_path, prompt_path, solutions_path, tokenizer,
Expand Down
20 changes: 19 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,21 @@
torch==1.11.0
pyext==0.7
deepspeed
deepspeed
numpy
tqdm
transformers
packaging
filelock
ipython
requests
sacremoses
tokenizers
regex
PyYAML
six
pydantic
psutil
attrs
black
setuptools
joblib
12 changes: 5 additions & 7 deletions test_one_solution.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,15 @@
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
#

import glob
import json
import numpy as np
import os
import os.path
import pprint
import glob
from tqdm import tqdm
import pdb
import traceback
import pickle as pkl
from typing import List
import traceback

import numpy as np
from tqdm import tqdm

from utils.testing_util import run_test

Expand Down
15 changes: 3 additions & 12 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,27 +5,18 @@
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
#

import io
import logging
import math
import json
import os
import pprint
import sys
import time
import json
import pdb
from tqdm import tqdm
from datetime import datetime

import transformers
import torch
import torch.multiprocessing

import transformers
from datasets.apps_dataset import APPSBaseDataset
from trainers.trainer_rl import Trainer_RL
from transformers import Trainer

import torch.multiprocessing

torch.multiprocessing.set_sharing_strategy('file_system')


Expand Down
3 changes: 2 additions & 1 deletion trainers/trainer_rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@

from tqdm.auto import tqdm

import pdb

# Integrations must be imported before ML frameworks:
from transformers.integrations import ( # isort: split
Expand Down Expand Up @@ -121,6 +120,7 @@
from transformers.training_args import ParallelMode, TrainingArguments
from transformers.utils import logging


_is_torch_generator_available = False
_is_native_amp_available = False

Expand Down Expand Up @@ -168,6 +168,7 @@

from .trainer_pt_utils import smp_forward_backward, smp_forward_only, smp_gather, smp_nested_concat


if TYPE_CHECKING:
import optuna

Expand Down
9 changes: 5 additions & 4 deletions utils/reindent.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,13 @@
"""

from __future__ import print_function
import sys
import getopt

import codecs
import tempfile
import shutil
import getopt
import os
import shutil
import sys
import tempfile


def _find_indentation(line, config):
Expand Down
22 changes: 7 additions & 15 deletions utils/testing_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,34 +6,26 @@
#

import argparse
import faulthandler
import gc
import json
import os
# to run the solution files we're using a timing based approach
import signal
import sys
import io
import faulthandler

# used for debugging to time steps
from datetime import datetime

# to run the solution files we're using a timing based approach
import signal

import numpy as np
from enum import Enum
# for capturing the stdout
from io import StringIO
from typing import get_type_hints
from typing import List, Tuple
from typing import List
# used for testing the code that reads from input
from unittest.mock import patch, mock_open

import numpy as np
from pyext import RuntimeModule
import gc
from enum import Enum
import traceback
from tqdm import tqdm

import pdb


class CODE_TYPE(Enum):
call_based = 0
Expand Down

0 comments on commit 2c4b4c6

Please sign in to comment.