From 51e6f43a048ea92d606a11bd6e6cc1f5283f6564 Mon Sep 17 00:00:00 2001 From: James Bornholt Date: Fri, 8 Dec 2023 00:37:07 +0000 Subject: [PATCH] PR feedback Signed-off-by: James Bornholt --- examples/pytorch/resnet.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/examples/pytorch/resnet.py b/examples/pytorch/resnet.py index 202ea947c..365e981b3 100644 --- a/examples/pytorch/resnet.py +++ b/examples/pytorch/resnet.py @@ -1,5 +1,5 @@ """This is a simple example of how to use Mountpoint as a PyTorch data loader via the torchdata -library. By default, it trains a ResNet-18 model for a few epochs using PyTorch Lightning, with +library. By default, it trains a ResNet-50 model for a few epochs using PyTorch Lightning, with synthetic ImageNet-sized training data, stored in S3 as shards in WebDataset format or as individual images. @@ -200,16 +200,16 @@ def make_s3_datapipe(args: argparse.Namespace) -> (torchdata.datapipes.iter.Iter the local path of Mountpoint if args.source_kind == mountpoint, or None otherwise.""" if args.source_kind == "mountpoint" or args.source_kind == "local": if args.source_kind == "mountpoint": - local_path = make_mountpoint(args.s3url, args.mountpoint_path, args.mountpoint_args) + local_path = make_mountpoint(args.path, args.mountpoint_path, args.mountpoint_args) sentinel_path = local_path elif args.source_kind == "local": - local_path, sentinel_path = args.s3url, None + local_path, sentinel_path = args.path, None lister = torchdata.datapipes.iter.FileLister([local_path], recursive=True) lister = lister.sharding_filter() return torchdata.datapipes.iter.FileOpener(lister, mode="rb"), sentinel_path elif args.source_kind == "fsspec": # Load from S3 using the FSSpec/S3FS libraries - lister = torchdata.datapipes.iter.FSSpecFileLister([args.s3url]) + lister = torchdata.datapipes.iter.FSSpecFileLister([args.path]) if args.dataset_format == "single": # fsspec lists directories rather than recursively, so need a second-level list lister = torchdata.datapipes.iter.FSSpecFileLister(lister) @@ -219,7 +219,7 @@ def make_s3_datapipe(args: argparse.Namespace) -> (torchdata.datapipes.iter.Iter # Load from S3 using the S3-specific IO datapipe (requires a BUILD_S3=1 version of torchdata) if args.region is None: raise Exception("region must be specified for s3io") - lister = torchdata.datapipes.iter.S3FileLister([args.s3url], region=args.region) + lister = torchdata.datapipes.iter.S3FileLister([args.path], region=args.region) lister = lister.sharding_filter() return torchdata.datapipes.iter.S3FileLoader(lister, region=args.region), None else: @@ -231,10 +231,10 @@ def make_dataset(args: argparse.Namespace) -> (torch.utils.data.Dataset, Optiona args.source_kind == mountpoint, or None otherwise.""" if args.dataset_format == "imagefolder": if args.source_kind == "mountpoint": - local_path = make_mountpoint(args.s3url, args.mountpoint_path, args.mountpoint_args) + local_path = make_mountpoint(args.path, args.mountpoint_path, args.mountpoint_args) return torchvision.datasets.ImageFolder(local_path), local_path elif args.source_kind == "local": - return torchvision.datasets.ImageFolder(args.s3url), None + return torchvision.datasets.ImageFolder(args.path), None else: raise Exception(f"imagefolder dataset only supports mountpoint and local sources") @@ -242,7 +242,7 @@ def make_dataset(args: argparse.Namespace) -> (torch.utils.data.Dataset, Optiona if args.dataset_format == "webdataset": return pipe.load_from_tar().webdataset().map(load_image), local_path elif args.dataset_format == "single": - return pipe.map(lambda x: {".jpg": x[1], ".cls": extract_class(x[0], args.s3url)}).map(load_image), local_path + return pipe.map(lambda x: {".jpg": x[1], ".cls": extract_class(x[0], args.path)}).map(load_image), local_path else: raise Exception(f"unknown dataset format {args.dataset_format}") @@ -298,7 +298,7 @@ def run_training(dataset: torch.utils.data.Dataset, args: argparse.Namespace, lo "--model", default="resnet50", help="model name to train (must be a model from `torchvision.models`)" ) p_train.add_argument( - "s3url", help="S3 URL for sharded training data directory (starts with 's3://', ends with '/')" + "path", help="path to sharded training data directory (can be an S3 URI like 's3://bucket/prefix/' or a local folder)" ) p_train.add_argument( "--source-kind",