diff --git a/runbenchmark.py b/runbenchmark.py index ad5eac7e3..6662d2f4c 100644 --- a/runbenchmark.py +++ b/runbenchmark.py @@ -4,6 +4,7 @@ import re import shutil import sys +from pathlib import Path # prevent asap other modules from defining the root logger using basicConfig import amlb.logger @@ -11,7 +12,7 @@ import openml import amlb -from amlb.utils import Namespace as ns, config_load, datetime_iso, str2bool, str_sanitize, zip_path +from amlb.utils import Namespace as ns, config_load, datetime_iso, str2bool, str_sanitize, zip_path, run_cmd from amlb import log, AutoMLError from amlb.defaults import default_dirs @@ -99,6 +100,18 @@ # help="The region on which to run the benchmark when using AWS.") args = parser.parse_args() + +GIT_PATTERN = re.compile(r"https://(?:www.)?\w+.\w+/([a-zA-Z0-9_\-\.]+)/([a-zA-Z0-9_\-\.]+).git") +if args.userdir and (match := GIT_PATTERN.match(args.userdir)): + user, repo = match.groups() + DOWNLOAD_DIRECTORY = Path(__file__).parent / "downloads" + download_path = DOWNLOAD_DIRECTORY / user / repo + if not download_path.exists(): + download_path.mkdir(parents=True) + cmd = f"clone {args.userdir}" + run_cmd(f"git {cmd} {download_path}", _log_level_=logging.DEBUG)[0].strip() + args.userdir = str(download_path) + script_name = os.path.splitext(os.path.basename(__file__))[0] extras = {t[0]: t[1] if len(t) > 1 else True for t in [x.split('=', 1) for x in args.extra]}