Skip to content

Commit

Permalink
PR feedback
Browse files Browse the repository at this point in the history
Signed-off-by: James Bornholt <[email protected]>
  • Loading branch information
jamesbornholt committed Dec 8, 2023
1 parent 07fd195 commit 51e6f43
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions examples/pytorch/resnet.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""This is a simple example of how to use Mountpoint as a PyTorch data loader via the torchdata
library. By default, it trains a ResNet-18 model for a few epochs using PyTorch Lightning, with
library. By default, it trains a ResNet-50 model for a few epochs using PyTorch Lightning, with
synthetic ImageNet-sized training data, stored in S3 as shards in WebDataset format or as individual
images.
Expand Down Expand Up @@ -200,16 +200,16 @@ def make_s3_datapipe(args: argparse.Namespace) -> (torchdata.datapipes.iter.Iter
the local path of Mountpoint if args.source_kind == mountpoint, or None otherwise."""
if args.source_kind == "mountpoint" or args.source_kind == "local":
if args.source_kind == "mountpoint":
local_path = make_mountpoint(args.s3url, args.mountpoint_path, args.mountpoint_args)
local_path = make_mountpoint(args.path, args.mountpoint_path, args.mountpoint_args)
sentinel_path = local_path
elif args.source_kind == "local":
local_path, sentinel_path = args.s3url, None
local_path, sentinel_path = args.path, None
lister = torchdata.datapipes.iter.FileLister([local_path], recursive=True)
lister = lister.sharding_filter()
return torchdata.datapipes.iter.FileOpener(lister, mode="rb"), sentinel_path
elif args.source_kind == "fsspec":
# Load from S3 using the FSSpec/S3FS libraries
lister = torchdata.datapipes.iter.FSSpecFileLister([args.s3url])
lister = torchdata.datapipes.iter.FSSpecFileLister([args.path])
if args.dataset_format == "single":
# fsspec lists directories rather than recursively, so need a second-level list
lister = torchdata.datapipes.iter.FSSpecFileLister(lister)
Expand All @@ -219,7 +219,7 @@ def make_s3_datapipe(args: argparse.Namespace) -> (torchdata.datapipes.iter.Iter
# Load from S3 using the S3-specific IO datapipe (requires a BUILD_S3=1 version of torchdata)
if args.region is None:
raise Exception("region must be specified for s3io")
lister = torchdata.datapipes.iter.S3FileLister([args.s3url], region=args.region)
lister = torchdata.datapipes.iter.S3FileLister([args.path], region=args.region)
lister = lister.sharding_filter()
return torchdata.datapipes.iter.S3FileLoader(lister, region=args.region), None
else:
Expand All @@ -231,18 +231,18 @@ def make_dataset(args: argparse.Namespace) -> (torch.utils.data.Dataset, Optiona
args.source_kind == mountpoint, or None otherwise."""
if args.dataset_format == "imagefolder":
if args.source_kind == "mountpoint":
local_path = make_mountpoint(args.s3url, args.mountpoint_path, args.mountpoint_args)
local_path = make_mountpoint(args.path, args.mountpoint_path, args.mountpoint_args)
return torchvision.datasets.ImageFolder(local_path), local_path
elif args.source_kind == "local":
return torchvision.datasets.ImageFolder(args.s3url), None
return torchvision.datasets.ImageFolder(args.path), None
else:
raise Exception(f"imagefolder dataset only supports mountpoint and local sources")

pipe, local_path = make_s3_datapipe(args)
if args.dataset_format == "webdataset":
return pipe.load_from_tar().webdataset().map(load_image), local_path
elif args.dataset_format == "single":
return pipe.map(lambda x: {".jpg": x[1], ".cls": extract_class(x[0], args.s3url)}).map(load_image), local_path
return pipe.map(lambda x: {".jpg": x[1], ".cls": extract_class(x[0], args.path)}).map(load_image), local_path
else:
raise Exception(f"unknown dataset format {args.dataset_format}")

Expand Down Expand Up @@ -298,7 +298,7 @@ def run_training(dataset: torch.utils.data.Dataset, args: argparse.Namespace, lo
"--model", default="resnet50", help="model name to train (must be a model from `torchvision.models`)"
)
p_train.add_argument(
"s3url", help="S3 URL for sharded training data directory (starts with 's3://', ends with '/')"
"path", help="path to sharded training data directory (can be an S3 URI like 's3://bucket/prefix/' or a local folder)"
)
p_train.add_argument(
"--source-kind",
Expand Down

0 comments on commit 51e6f43

Please sign in to comment.